From 4dc083d52b988943dbc0f3681891cbd757431378 Mon Sep 17 00:00:00 2001 From: dapeng Date: Thu, 2 Jan 2025 20:26:13 +0800 Subject: [PATCH] feat: update code of process chan --- goner/gin/responer.go | 46 ++++++++++++++++++++++++------------------- goner/gin/sse.go | 6 +----- goner/gin/sse_test.go | 1 - 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/goner/gin/responer.go b/goner/gin/responer.go index c5f4a55..94669c3 100644 --- a/goner/gin/responer.go +++ b/goner/gin/responer.go @@ -139,6 +139,13 @@ func (r *responser) ProcessResults(context XContext, writer gin.ResponseWriter, return } + of := reflect.TypeOf(result) + if of.Kind() == reflect.Chan { + isNotEnd = true + r.dealChan(result, writer) + return + } + switch result.(type) { case error: r.Failed(context, result.(error)) @@ -148,9 +155,9 @@ func (r *responser) ProcessResults(context XContext, writer gin.ResponseWriter, if err != nil { r.Warnf("copy data to writer failed, err: %v", err) } - case chan any: - isNotEnd = true - r.dealChan(result.(chan any), writer) + //case chan any: + // isNotEnd = true + // r.dealChan(result.(chan any), writer) default: r.Success(context, result) } @@ -161,31 +168,30 @@ func (r *responser) ProcessResults(context XContext, writer gin.ResponseWriter, } } -func (r *responser) dealChan(ch <-chan any, writer gin.ResponseWriter) { +func (r *responser) dealChan(ch any, writer gin.ResponseWriter) { sse := NewSSE(writer) sse.Start() - for { - data, ok := <-ch + of := reflect.ValueOf(ch) - if !ok { + for { + if data, ok := of.Recv(); !ok { err := sse.End() if err != nil { r.Errorf("write 'end' error: %v", err) } - return - } - var err error - switch data.(type) { - case error: - err = sse.WriteError(ToError(data.(error))) - default: - err = sse.Write(data) - } - - if err != nil { - r.Errorf("write data error: %v", err) - return + break + } else { + var err error + i := data.Interface() + if e, y := i.(error); y { + err = sse.WriteError(ToError(e)) + } else { + err = sse.Write(i) + } + if err != nil { + r.Errorf("write data error: %v", err) + } } } } diff --git a/goner/gin/sse.go b/goner/gin/sse.go index fdfc042..bfd7372 100644 --- a/goner/gin/sse.go +++ b/goner/gin/sse.go @@ -37,10 +37,6 @@ func (s *Sse) Write(delta any) error { return err } - _, err = io.WriteString(s.Writer, "event: data\n") - if err != nil { - return err - } _, err = io.WriteString(s.Writer, fmt.Sprintf("data: %s\n\n", jsonStr)) if err != nil { return err @@ -50,7 +46,7 @@ func (s *Sse) Write(delta any) error { } func (s *Sse) End() error { - _, err := io.WriteString(s.Writer, "event: done\n") + _, err := io.WriteString(s.Writer, "event: done\ndata: \n\ndata: [DONE]\n") if err != nil { return err } diff --git a/goner/gin/sse_test.go b/goner/gin/sse_test.go index beebd5b..270b9a3 100644 --- a/goner/gin/sse_test.go +++ b/goner/gin/sse_test.go @@ -81,7 +81,6 @@ func TestSSE(t *testing.T) { writer := NewMockResponseWriter(controller) writer.EXPECT().Header().Return(http.Header{}).AnyTimes() writer.EXPECT().Flush().AnyTimes() - writer.EXPECT().WriteString(gomock.Any()).Return(100, nil) writer.EXPECT().WriteString(gomock.Any()).Return(0, errors.New("error")) sse := NewSSE(writer)