diff --git a/cli/serve.go b/cli/serve.go index 36d988f5b..eb3559908 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -22,8 +22,10 @@ import ( "github.com/spf13/cobra" "github.com/yomorun/yomo" + "github.com/yomorun/yomo/core/auth" "github.com/yomorun/yomo/core/ylog" pkgconfig "github.com/yomorun/yomo/pkg/config" + "github.com/yomorun/yomo/pkg/listener/mem" "github.com/yomorun/yomo/pkg/log" "github.com/yomorun/yomo/pkg/trace" @@ -68,6 +70,9 @@ var serveCmd = &cobra.Command{ // listening address. listenAddr := fmt.Sprintf("%s:%d", conf.Host, conf.Port) + // memory listener + var listener *mem.Listener + options := []yomo.ZipperOption{} tokenString := "" if _, ok := conf.Auth["type"]; ok { @@ -75,8 +80,8 @@ var serveCmd = &cobra.Command{ options = append(options, yomo.WithAuth("token", tokenString)) } } - // check llm bridge server config - // parse the llm bridge config + + // check and parse the llm bridge server config bridgeConf := conf.Bridge aiConfig, err := ai.ParseConfig(bridgeConf) if err != nil { @@ -88,8 +93,10 @@ var serveCmd = &cobra.Command{ } } if aiConfig != nil { + listener = mem.Listen() // add AI connection middleware options = append(options, yomo.WithZipperConnMiddleware(ai.RegisterFunctionMW())) + options = append(options, yomo.WithFrameListener(listener)) } // new zipper zipper, err := yomo.NewZipper( @@ -108,7 +115,13 @@ var serveCmd = &cobra.Command{ registerAIProvider(aiConfig) // start the llm api server go func() { - err := ai.Serve(aiConfig, listenAddr, fmt.Sprintf("token:%s", tokenString), ylog.Default()) + conn, _ := listener.Dial() + source := ai.NewSource(conn, auth.NewCredential(fmt.Sprintf("token:%s", tokenString))) + + conn2, _ := listener.Dial() + reducer := ai.NewReducer(conn2, auth.NewCredential(fmt.Sprintf("token:%s", tokenString))) + + err := ai.Serve(aiConfig, ylog.Default(), source, reducer) if err != nil { log.FailureStatusEvent(os.Stdout, err.Error()) return diff --git a/cli/serverless/golang/serverless.go b/cli/serverless/golang/serverless.go index 5da589f26..e06d1e482 100644 --- a/cli/serverless/golang/serverless.go +++ b/cli/serverless/golang/serverless.go @@ -127,7 +127,6 @@ func (s *GolangServerless) Init(opts *serverless.Options) error { // Build compiles the serverless to executable func (s *GolangServerless) Build(clean bool) error { - log.PendingStatusEvent(os.Stdout, "Building YoMo Stream Function instance...") // check if the file exists appPath := s.source if _, err := os.Stat(appPath); os.IsNotExist(err) { @@ -203,7 +202,6 @@ func (s *GolangServerless) Build(clean bool) error { if clean { file.Remove(s.tempDir) } - log.SuccessStatusEvent(os.Stdout, "YoMo Stream Function build successful!") return nil } diff --git a/core/server.go b/core/server.go index 2942f8e25..35244bbea 100644 --- a/core/server.go +++ b/core/server.go @@ -143,20 +143,32 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { defer closeServer(s.downstreams, s.connector, s.listener, s.router) - errCount := 0 - for { - fconn, err := s.listener.Accept(s.ctx) - if err != nil { - if err == s.ctx.Err() { - return ErrServerClosed + listeners := append(s.opts.listeners, s.listener) + + var wg sync.WaitGroup + for _, l := range listeners { + wg.Add(1) + go func(l frame.Listener) { + errCount := 0 + for { + fconn, err := l.Accept(s.ctx) + if err != nil { + if err == s.ctx.Err() { + wg.Done() + return + } + errCount++ + s.logger.Error("accepted an error when accepting a connection", "err", err, "err_count", errCount) + continue + } + + go s.handleFrameConn(fconn, s.logger) } - errCount++ - s.logger.Error("accepted an error when accepting a connection", "err", err, "err_count", errCount) - continue - } - - go s.handleFrameConn(fconn, s.logger) + }(l) } + + wg.Wait() + return ErrServerClosed } func (s *Server) handleFrameConn(fconn frame.Conn, logger *slog.Logger) { @@ -380,7 +392,11 @@ func (s *Server) routingDataFrame(c *Context) error { // dispatch every DataFrames to all downstreams func (s *Server) dispatchToDownstreams(c *Context) error { - dataFrame := c.Frame + dataFrame := &frame.DataFrame{ + Tag: c.Frame.Tag, + Payload: c.Frame.Payload, + Metadata: c.Frame.Metadata, + } if c.Connection.ClientType() == ClientTypeUpstreamZipper { c.Logger.Debug("ignored client", "client_type", c.Connection.ClientType().String()) // loop protection diff --git a/core/server_options.go b/core/server_options.go index efb1da019..de5f64814 100644 --- a/core/server_options.go +++ b/core/server_options.go @@ -7,6 +7,7 @@ import ( "github.com/quic-go/quic-go" "github.com/yomorun/yomo/core/auth" + "github.com/yomorun/yomo/core/frame" "github.com/yomorun/yomo/core/router" "github.com/yomorun/yomo/core/ylog" ) @@ -38,6 +39,7 @@ type serverOptions struct { router router.Router connMiddlewares []ConnMiddleware frameMiddlewares []FrameMiddleware + listeners []frame.Listener } func defaultServerOptions() *serverOptions { @@ -120,3 +122,10 @@ func WithConnMiddleware(mws ...ConnMiddleware) ServerOption { o.connMiddlewares = append(o.connMiddlewares, mws...) } } + +// WithFrameListener adds a Listener other than a quic.Listener. +func WithFrameListener(l ...frame.Listener) ServerOption { + return func(o *serverOptions) { + o.listeners = append(o.listeners, l...) + } +} diff --git a/options.go b/options.go index c9b957b84..d06bd80ba 100644 --- a/options.go +++ b/options.go @@ -6,6 +6,7 @@ import ( "github.com/quic-go/quic-go" "github.com/yomorun/yomo/core" + "github.com/yomorun/yomo/core/frame" "github.com/yomorun/yomo/core/router" ) @@ -147,4 +148,11 @@ var ( o.serverOption = append(o.serverOption, core.WithFrameMiddleware(mw...)) } } + + // WithFrameListener adds a Listener other than a quic.Listener. + WithFrameListener = func(l ...frame.Listener) ZipperOption { + return func(o *zipperOptions) { + o.serverOption = append(o.serverOption, core.WithFrameListener(l...)) + } + } ) diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index 52caac829..e1554b98b 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -13,6 +13,7 @@ import ( "time" openai "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo" "github.com/yomorun/yomo/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider" "github.com/yomorun/yomo/pkg/bridge/ai/register" @@ -34,23 +35,20 @@ const ( // BasicAPIServer provides restful service for end user type BasicAPIServer struct { - zipperAddr string - credential string httpHandler http.Handler } // Serve starts the Basic API Server -func Serve(config *Config, zipperListenAddr string, credential string, logger *slog.Logger) error { +func Serve(config *Config, logger *slog.Logger, source yomo.Source, reducer yomo.StreamFunction) error { provider, err := provider.GetProvider(config.Server.Provider) if err != nil { return err } - srv, err := NewBasicAPIServer(config, zipperListenAddr, credential, provider, logger) + srv, err := NewBasicAPIServer(config, provider, source, reducer, logger) if err != nil { return err } - logger.Info("start AI Bridge service", "addr", config.Server.Addr, "provider", provider.Name()) return http.ListenAndServe(config.Server.Addr, srv.httpHandler) } @@ -80,24 +78,23 @@ func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) htt } // NewBasicAPIServer creates a new restful service -func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider provider.LLMProvider, logger *slog.Logger) (*BasicAPIServer, error) { - zipperAddr = parseZipperAddr(zipperAddr) - +func NewBasicAPIServer(config *Config, provider provider.LLMProvider, source yomo.Source, reducer yomo.StreamFunction, logger *slog.Logger) (*BasicAPIServer, error) { logger = logger.With("service", "llm-bridge") - service := NewService(zipperAddr, provider, &ServiceOptions{ + opts := &ServiceOptions{ Logger: logger, - CredentialFunc: func(r *http.Request) (string, error) { return credential, nil }, - }) + SourceBuilder: func() yomo.Source { return source }, + ReducerBuilder: func() yomo.StreamFunction { return reducer }, + } + service := NewService(provider, opts) mux := NewServeMux(service) server := &BasicAPIServer{ - zipperAddr: zipperAddr, - credential: credential, httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)), } + logger.Info("start AI Bridge service", "addr", config.Server.Addr, "provider", provider.Name()) return server, nil } diff --git a/pkg/bridge/ai/api_server_test.go b/pkg/bridge/ai/api_server_test.go index 7d3ce0bfa..181a2fb62 100644 --- a/pkg/bridge/ai/api_server_test.go +++ b/pkg/bridge/ai/api_server_test.go @@ -46,9 +46,9 @@ func TestServer(t *testing.T) { return mockCaller(nil), err } - service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{ - SourceBuilder: func(_, _ string) yomo.Source { return flow }, - ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, + service := newService(pd, newCaller, &ServiceOptions{ + SourceBuilder: func() yomo.Source { return flow }, + ReducerBuilder: func() yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, }) diff --git a/pkg/bridge/ai/reducer.go b/pkg/bridge/ai/reducer.go new file mode 100644 index 000000000..3dc97c5bf --- /dev/null +++ b/pkg/bridge/ai/reducer.go @@ -0,0 +1,129 @@ +package ai + +import ( + "github.com/yomorun/yomo" + "github.com/yomorun/yomo/core" + "github.com/yomorun/yomo/core/auth" + "github.com/yomorun/yomo/core/frame" + "github.com/yomorun/yomo/core/metadata" + "github.com/yomorun/yomo/core/serverless" + "github.com/yomorun/yomo/pkg/id" + "github.com/yomorun/yomo/pkg/listener/mem" +) + +var _ yomo.Source = &memSource{} + +type memSource struct { + cred *auth.Credential + conn *mem.FrameConn +} + +func NewSource(conn *mem.FrameConn, cred *auth.Credential) yomo.Source { + return &memSource{ + conn: conn, + cred: cred, + } +} + +func (m *memSource) Connect() error { + hf := &frame.HandshakeFrame{ + Name: "fc-source", + ID: id.New(), + ClientType: byte(core.ClientTypeSource), + AuthName: m.cred.Name(), + AuthPayload: m.cred.Payload(), + Version: core.Version, + } + + return m.conn.Handshake(hf) +} + +func (m *memSource) Write(tag uint32, data []byte) error { + df := &frame.DataFrame{ + Tag: tag, + Payload: data, + } + return m.conn.WriteFrame(df) +} + +func (m *memSource) Close() error { return nil } +func (m *memSource) SetErrorHandler(_ func(_ error)) {} +func (m *memSource) WriteWithTarget(_ uint32, _ []byte, _ string) error { return nil } + +type memStreamFunction struct { + observedTags []uint32 + handler core.AsyncHandler + cred *auth.Credential + conn *mem.FrameConn +} + +// NewReducer creates a new instance of memory StreamFunction. +func NewReducer(conn *mem.FrameConn, cred *auth.Credential) yomo.StreamFunction { + return &memStreamFunction{ + conn: conn, + cred: cred, + } +} + +func (m *memStreamFunction) Close() error { + return nil +} + +func (m *memStreamFunction) Connect() error { + hf := &frame.HandshakeFrame{ + Name: "fc-reducer", + ID: id.New(), + ClientType: byte(core.ClientTypeStreamFunction), + AuthName: m.cred.Name(), + AuthPayload: m.cred.Payload(), + ObserveDataTags: m.observedTags, + Version: core.Version, + } + + if err := m.conn.Handshake(hf); err != nil { + return nil + } + + go func() { + for { + f, err := m.conn.ReadFrame() + if err != nil { + return + } + + switch ff := f.(type) { + case *frame.DataFrame: + go m.onDataFrame(ff) + default: + return + } + } + }() + + return nil +} + +func (m *memStreamFunction) onDataFrame(dataFrame *frame.DataFrame) { + md, err := metadata.Decode(dataFrame.Metadata) + if err != nil { + return + } + + serverlessCtx := serverless.NewContext(m.conn, dataFrame.Tag, md, dataFrame.Payload) + m.handler(serverlessCtx) +} + +func (m *memStreamFunction) SetHandler(fn core.AsyncHandler) error { + m.handler = fn + return nil +} + +func (m *memStreamFunction) Init(_ func() error) error { return nil } +func (m *memStreamFunction) SetCronHandler(_ string, _ core.CronHandler) error { return nil } +func (m *memStreamFunction) SetErrorHandler(_ func(err error)) {} +func (m *memStreamFunction) SetObserveDataTags(tags ...uint32) { m.observedTags = tags } +func (m *memStreamFunction) SetPipeHandler(fn core.PipeHandler) error { return nil } +func (m *memStreamFunction) SetWantedTarget(string) {} +func (m *memStreamFunction) Wait() {} + +var _ yomo.StreamFunction = &memStreamFunction{} diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index b46834473..638db5835 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -26,7 +26,6 @@ import ( // Service is the service layer for llm bridge server. // service is responsible for handling the logic from handler layer. type Service struct { - zipperAddr string provider provider.LLMProvider newCallerFunc newCallerFunc callers *expirable.LRU[string, *Caller] @@ -47,16 +46,16 @@ type ServiceOptions struct { // CallerCallTimeout is the timeout for awaiting the function response. CallerCallTimeout time.Duration // SourceBuilder should builds an unconnected source. - SourceBuilder func(zipperAddr, credential string) yomo.Source + SourceBuilder func() yomo.Source // ReducerBuilder should builds an unconnected reducer. - ReducerBuilder func(zipperAddr, credential string) yomo.StreamFunction + ReducerBuilder func() yomo.StreamFunction // MetadataExchanger exchanges metadata from the credential. MetadataExchanger func(credential string) (metadata.M, error) } // NewService creates a new service for handling the logic from handler layer. -func NewService(zipperAddr string, provider provider.LLMProvider, opt *ServiceOptions) *Service { - return newService(zipperAddr, provider, NewCaller, opt) +func NewService(provider provider.LLMProvider, opt *ServiceOptions) *Service { + return newService(provider, NewCaller, opt) } func initOption(opt *ServiceOptions) *ServiceOptions { @@ -67,7 +66,7 @@ func initOption(opt *ServiceOptions) *ServiceOptions { opt.Logger = ylog.Default() } if opt.CredentialFunc == nil { - opt.CredentialFunc = func(_ *http.Request) (string, error) { return "", nil } + opt.CredentialFunc = func(_ *http.Request) (string, error) { return "token", nil } } if opt.CallerCacheSize == 0 { opt.CallerCacheSize = 1 @@ -75,22 +74,6 @@ func initOption(opt *ServiceOptions) *ServiceOptions { if opt.CallerCallTimeout == 0 { opt.CallerCallTimeout = 60 * time.Second } - if opt.SourceBuilder == nil { - opt.SourceBuilder = func(zipperAddr, credential string) yomo.Source { - return yomo.NewSource( - "fc-source", - zipperAddr, - yomo.WithSourceReConnect(), yomo.WithCredential(credential)) - } - } - if opt.ReducerBuilder == nil { - opt.ReducerBuilder = func(zipperAddr, credential string) yomo.StreamFunction { - return yomo.NewStreamFunction( - "fc-reducer", - zipperAddr, - yomo.WithSfnReConnect(), yomo.WithSfnCredential(credential), yomo.DisableOtelTrace()) - } - } if opt.MetadataExchanger == nil { opt.MetadataExchanger = func(credential string) (metadata.M, error) { return metadata.New(), nil @@ -100,7 +83,7 @@ func initOption(opt *ServiceOptions) *ServiceOptions { return opt } -func newService(zipperAddr string, provider provider.LLMProvider, ncf newCallerFunc, opt *ServiceOptions) *Service { +func newService(provider provider.LLMProvider, ncf newCallerFunc, opt *ServiceOptions) *Service { var onEvict = func(_ string, caller *Caller) { caller.Close() } @@ -108,7 +91,6 @@ func newService(zipperAddr string, provider provider.LLMProvider, ncf newCallerF opt = initOption(opt) service := &Service{ - zipperAddr: zipperAddr, provider: provider, newCallerFunc: ncf, callers: expirable.NewLRU(opt.CallerCacheSize, onEvict, opt.CallerCacheTTL), @@ -464,8 +446,8 @@ func (srv *Service) loadOrCreateCaller(credential string) (*Caller, error) { return nil, err } caller, err = srv.newCallerFunc( - srv.option.SourceBuilder(srv.zipperAddr, credential), - srv.option.ReducerBuilder(srv.zipperAddr, credential), + srv.option.SourceBuilder(), + srv.option.ReducerBuilder(), md, srv.option.CallerCallTimeout, ) diff --git a/pkg/bridge/ai/service_test.go b/pkg/bridge/ai/service_test.go index e50d106d5..3d763c2b7 100644 --- a/pkg/bridge/ai/service_test.go +++ b/pkg/bridge/ai/service_test.go @@ -220,9 +220,9 @@ func TestServiceInvoke(t *testing.T) { return mockCaller(tt.args.mockCallReqResp), err } - service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{ - SourceBuilder: func(_, _ string) yomo.Source { return flow }, - ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, + service := newService(pd, newCaller, &ServiceOptions{ + SourceBuilder: func() yomo.Source { return flow }, + ReducerBuilder: func() yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, }) @@ -389,9 +389,9 @@ func TestServiceChatCompletion(t *testing.T) { return mockCaller(tt.args.mockCallReqResp), err } - service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{ - SourceBuilder: func(_, _ string) yomo.Source { return flow }, - ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, + service := newService(pd, newCaller, &ServiceOptions{ + SourceBuilder: func() yomo.Source { return flow }, + ReducerBuilder: func() yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, }) diff --git a/pkg/listener/mem/mem.go b/pkg/listener/mem/mem.go new file mode 100644 index 000000000..089011b5b --- /dev/null +++ b/pkg/listener/mem/mem.go @@ -0,0 +1,192 @@ +// Package mem provides a memory implementation of yomo.FrameConn. +package mem + +import ( + "context" + "errors" + "fmt" + "net" + + "github.com/yomorun/yomo/core/frame" +) + +// FrameConn is an implements of FrameConn, +// It transmits frames upon the golang channel. +type FrameConn struct { + ctx context.Context + cancel context.CancelCauseFunc + rCh chan frame.Frame + wCh chan frame.Frame +} + +var _ frame.Conn = &FrameConn{} + +// NewFrameConn creates FrameConn from read write channel. +func NewFrameConn(ctx context.Context) *FrameConn { + return newFrameConn(ctx, make(chan frame.Frame), make(chan frame.Frame)) +} + +// newFrameConn creates FrameConn from read write channel. +func newFrameConn(ctx context.Context, rCh, wCh chan frame.Frame) *FrameConn { + ctx, cancel := context.WithCancelCause(ctx) + + conn := &FrameConn{ + ctx: ctx, + cancel: cancel, + rCh: rCh, + wCh: wCh, + } + + return conn +} + +// Handshake sends a HandshakeFrame to the connection. +// This function should be called before ReadFrame or WriteFrame. +func (p *FrameConn) Handshake(hf *frame.HandshakeFrame) error { + if err := p.WriteFrame(hf); err != nil { + return err + } + + first, err := p.ReadFrame() + if err != nil { + return err + } + + switch f := first.(type) { + case *frame.HandshakeAckFrame: + return nil + case *frame.RejectedFrame: + return errors.New(f.Message) + default: + return errors.New("unexpected frame") + } +} + +// Context returns the context of the connection. +func (p *FrameConn) Context() context.Context { + return p.ctx +} + +type memAddr struct { + remote bool +} + +func (m *memAddr) Network() string { + return "mem" +} +func (m *memAddr) String() string { + rs := "local" + if m.remote { + rs = "remote" + } + return fmt.Sprintf("mem://%s", rs) +} + +// RemoteAddr returns the remote address of connection. +func (p *FrameConn) RemoteAddr() net.Addr { + addr := &memAddr{ + remote: true, + } + return addr +} + +// LocalAddr returns the local address of connection. +func (p *FrameConn) LocalAddr() net.Addr { + addr := &memAddr{ + remote: false, + } + return addr +} + +// CloseWithError closes the connection. +// After calling CloseWithError, ReadFrame and WriteFrame will return frame.ErrConnClosed error. +func (p *FrameConn) CloseWithError(errString string) error { + select { + case <-p.ctx.Done(): + return nil + default: + p.cancel(frame.NewErrConnClosed(false, errString)) + } + return nil +} + +// ReadFrame reads a frame. it usually be called in a for-loop. +func (p *FrameConn) ReadFrame() (frame.Frame, error) { + select { + case f := <-p.rCh: + return f, nil + case <-p.ctx.Done(): + return nil, context.Cause(p.ctx) + } +} + +// WriteFrame writes a frame to connection. +func (p *FrameConn) WriteFrame(f frame.Frame) error { + select { + case p.wCh <- f: + return nil + case <-p.ctx.Done(): + return context.Cause(p.ctx) + } +} + +// Listener listens a net.PacketConn and accepts connections. +type Listener struct { + ctx context.Context + cancel context.CancelFunc + conns chan *FrameConn +} + +// Listen returns a Listener that can accept connections. +func Listen() *Listener { + ctx, cancel := context.WithCancel(context.Background()) + + l := &Listener{ + ctx: ctx, + cancel: cancel, + conns: make(chan *FrameConn, 10), + } + + return l +} + +func (l *Listener) Dial() (*FrameConn, error) { + var ( + rCh = make(chan frame.Frame) + wCh = make(chan frame.Frame) + ) + conn := newFrameConn(l.ctx, rCh, wCh) + + select { + case <-l.ctx.Done(): + return nil, l.ctx.Err() + case l.conns <- conn: + return conn, nil + } +} + +// Accept accepts FrameConns. +func (l *Listener) Accept(ctx context.Context) (frame.Conn, error) { + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case c := <-l.conns: + conn := &FrameConn{ + ctx: c.ctx, + cancel: c.cancel, + // swap rCh and wCh for bidirectional + rCh: c.wCh, + wCh: c.rCh, + } + return conn, nil + } + } +} + +// Close closes listener. +// If listener be closed, all connection receive quic application error that code=0, message="". +func (l *Listener) Close() error { + l.cancel() + return nil +} diff --git a/pkg/listener/mem/mem_test.go b/pkg/listener/mem/mem_test.go new file mode 100644 index 000000000..6e8c4201f --- /dev/null +++ b/pkg/listener/mem/mem_test.go @@ -0,0 +1,92 @@ +package mem + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/yomorun/yomo/core/frame" +) + +const ( + handshakeName = "hello yomo" + streamContent = "hello stream" + CloseMessage = "bye!" +) + +func TestMemAddr(t *testing.T) { + fconn := NewFrameConn(context.TODO()) + + assert.Equal(t, "mem", fconn.LocalAddr().Network()) + assert.Equal(t, "mem://local", fconn.LocalAddr().String()) + + assert.Equal(t, "mem", fconn.RemoteAddr().Network()) + assert.Equal(t, "mem://remote", fconn.RemoteAddr().String()) +} + +func TestListener(t *testing.T) { + listener := Listen() + + go func() { + if err := runListener(t, listener); err != nil { + panic(err) + } + }() + + fconn, err := listener.Dial() + assert.NoError(t, err) + + time.AfterFunc(time.Second, func() { + err := fconn.CloseWithError(CloseMessage) + assert.NoError(t, err) + }) + + err = fconn.Handshake(&frame.HandshakeFrame{Name: handshakeName}) + assert.NoError(t, err) + + for { + f, err := fconn.ReadFrame() + if err != nil { + se := new(frame.ErrConnClosed) + assert.True(t, errors.As(err, &se)) + assert.Equal(t, frame.NewErrConnClosed(false, CloseMessage), err) + return + } + df, ok := f.(*frame.DataFrame) + if !ok { + t.Fatalf("unexpected frame: %v", f) + } + assert.Equal(t, streamContent, string(df.Payload)) + } +} + +func runListener(t *testing.T, l *Listener) error { + fconn, err := l.Accept(context.TODO()) + if err != nil { + return err + } + + f, err := fconn.ReadFrame() + assert.NoError(t, err) + assert.Equal(t, f.Type(), frame.TypeHandshakeFrame) + + if err := fconn.WriteFrame(&frame.HandshakeAckFrame{}); err != nil { + return err + } + + time.AfterFunc(time.Second, func() { + fconn.CloseWithError(CloseMessage) + l.Close() + }) + + for range 10 { + fconn.WriteFrame(&frame.DataFrame{ + Tag: 0x34, + Payload: []byte(streamContent), + }) + } + + return nil +}