From 10ad39da3df8fdf46fbddbea9bdbb4958195d7f8 Mon Sep 17 00:00:00 2001 From: actatum Date: Sat, 18 Jun 2022 21:06:08 -0400 Subject: [PATCH] initial mvp --- .gitignore | 2 + client.go | 58 ++++++++ client_test.go | 237 +++++++++++++++++++++++++++++++++ errors.go | 68 ++++++++++ errors_test.go | 181 +++++++++++++++++++++++++ examples/simple/client/main.go | 38 ++++++ examples/simple/server/main.go | 67 ++++++++++ go.mod | 23 ++++ go.sum | 52 ++++++++ headers.go | 56 ++++++++ headers_test.go | 117 ++++++++++++++++ prototest/test.pb.go | 142 ++++++++++++++++++++ prototest/test.proto | 7 + request.go | 126 ++++++++++++++++++ request_test.go | 207 ++++++++++++++++++++++++++++ response.go | 101 ++++++++++++++ response_test.go | 235 ++++++++++++++++++++++++++++++++ server.go | 150 +++++++++++++++++++++ server_test.go | 170 +++++++++++++++++++++++ 19 files changed, 2037 insertions(+) create mode 100644 .gitignore create mode 100644 client.go create mode 100644 client_test.go create mode 100644 errors.go create mode 100644 errors_test.go create mode 100644 examples/simple/client/main.go create mode 100644 examples/simple/server/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 headers.go create mode 100644 headers_test.go create mode 100644 prototest/test.pb.go create mode 100644 prototest/test.proto create mode 100644 request.go create mode 100644 request_test.go create mode 100644 response.go create mode 100644 response_test.go create mode 100644 server.go create mode 100644 server_test.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b9268f8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +cmd +.idea \ No newline at end of file diff --git a/client.go b/client.go new file mode 100644 index 0000000..ee6ebd8 --- /dev/null +++ b/client.go @@ -0,0 +1,58 @@ +package stormrpc + +import ( + "errors" + + "github.com/nats-io/nats.go" +) + +type Client struct { + nc *nats.Conn +} + +func NewClient(natsURL string, opts ...ClientOption) (*Client, error) { + nc, err := nats.Connect(natsURL) + if err != nil { + return nil, err + } + + return &Client{ + nc: nc, + }, nil +} + +type clientOptions struct{} + +type ClientOption interface { + apply(*clientOptions) +} + +func (c *Client) Do(r *Request) Response { + msg, err := c.nc.RequestMsgWithContext(r.Context, r.Msg) + if errors.Is(err, nats.ErrNoResponders) { + return Response{ + Msg: msg, + Err: Errorf(ErrorCodeInternal, "no servers available for subject: %s", r.Subject()), + } + } + if err != nil { + return Response{ + Msg: msg, + Err: err, // TODO: probably use errorf and inspect different error types from nats. + } + } + + // Inspect headers and set error if appropriate + rpcErr := parseErrorHeader(msg.Header) + if rpcErr != nil { + return Response{ + Msg: msg, + Err: rpcErr, + } + } + + return Response{ + Msg: msg, + Err: nil, + } +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..a59c5c9 --- /dev/null +++ b/client_test.go @@ -0,0 +1,237 @@ +package stormrpc + +import ( + "context" + "errors" + "fmt" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" +) + +func TestNewClient(t *testing.T) { + t.Run("no nats server running", func(t *testing.T) { + _, err := NewClient(nats.DefaultURL) + if err == nil { + t.Fatal("expected error got nil") + } + }) + + t.Run("nats server running", func(t *testing.T) { + ns, err := server.NewServer(&server.Options{ + Port: 41397, + }) + if err != nil { + t.Fatal(err) + } + go ns.Start() + t.Cleanup(func() { + ns.Shutdown() + ns.WaitForShutdown() + }) + + if !ns.ReadyForConnections(1 * time.Second) { + t.Error("timeout waiting for nats server") + return + } + + _, err = NewClient(ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + }) +} + +func TestClient_Do(t *testing.T) { + t.Parallel() + + rand.Seed(time.Now().UnixNano()) + ns, err := server.NewServer(&server.Options{ + Port: 41397, + }) + if err != nil { + t.Fatal(err) + } + go ns.Start() + t.Cleanup(func() { + ns.Shutdown() + ns.WaitForShutdown() + }) + + if !ns.ReadyForConnections(1 * time.Second) { + t.Error("timeout waiting for nats server") + return + } + + t.Run("deadline exceeded", func(t *testing.T) { + t.Parallel() + + timeout := 50 * time.Millisecond + subject := strconv.Itoa(rand.Int()) + srv, err := NewServer("test", ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + srv.Handle(subject, func(r Request) Response { + time.Sleep(timeout + 10*time.Millisecond) + return Response{Msg: &nats.Msg{Subject: r.Reply}} + }) + go srv.Run() + t.Cleanup(func() { + _ = srv.Shutdown(context.Background()) + }) + + client, err := NewClient(ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + r, err := NewRequest(ctx, subject, map[string]string{"howdy": "partner"}) + if err != nil { + t.Fatal(err) + } + + resp := client.Do(r) + if resp.Err == nil { + t.Fatal("expected error got nil") + } + + if !errors.Is(resp.Err, context.DeadlineExceeded) { + t.Fatalf("got = %v, want %v", resp.Err, context.DeadlineExceeded) + } + }) + + t.Run("rpc error", func(t *testing.T) { + t.Parallel() + + timeout := 50 * time.Millisecond + subject := strconv.Itoa(rand.Int()) + srv, err := NewServer("test", ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + srv.Handle(subject, func(r Request) Response { + return NewErrorResponse(r.Reply, Errorf(ErrorCodeNotFound, "thingy not found")) + }) + go srv.Run() + t.Cleanup(func() { + _ = srv.Shutdown(context.Background()) + }) + + client, err := NewClient(ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + r, err := NewRequest(ctx, subject, map[string]string{"howdy": "partner"}) + if err != nil { + t.Fatal(err) + } + + resp := client.Do(r) + if resp.Err == nil { + t.Fatal("expected error got nil") + } + + code := CodeFromErr(resp.Err) + if code != ErrorCodeNotFound { + t.Fatalf("got = %v, want %v", code, ErrorCodeNotFound) + } + msg := MessageFromErr(resp.Err) + if msg != "thingy not found" { + t.Fatalf("got = %v, want %v", msg, "thingy not found") + } + }) + + t.Run("no servers", func(t *testing.T) { + t.Parallel() + + subject := strconv.Itoa(rand.Int()) + + client, err := NewClient(ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + req, err := NewRequest(ctx, subject, map[string]string{"x": "D"}) + if err != nil { + t.Fatal(err) + } + resp := client.Do(req) + if resp.Err == nil { + t.Fatal("expected error got nil") + } + + code := CodeFromErr(resp.Err) + if code != ErrorCodeInternal { + t.Fatalf("got = %v, want %v", code, ErrorCodeInternal) + } + msg := MessageFromErr(resp.Err) + if msg != fmt.Sprintf("no servers available for subject: %s", subject) { + t.Fatalf( + "got = %v, want %v", + msg, + fmt.Sprintf("no servers available for subject: %s", subject), + ) + } + }) + + t.Run("successful request", func(t *testing.T) { + t.Parallel() + + timeout := 50 * time.Millisecond + subject := strconv.Itoa(rand.Int()) + srv, err := NewServer("test", ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + srv.Handle(subject, func(r Request) Response { + resp, err := NewResponse(r.Reply, map[string]string{"hello": "world"}) + if err != nil { + return NewErrorResponse(r.Reply, err) + } + return resp + }) + go srv.Run() + t.Cleanup(func() { + _ = srv.Shutdown(context.Background()) + }) + + client, err := NewClient(ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + r, err := NewRequest(ctx, subject, map[string]string{"howdy": "partner"}) + if err != nil { + t.Fatal(err) + } + + resp := client.Do(r) + if resp.Err != nil { + t.Fatal(resp.Err) + } + + var result map[string]string + if err = resp.Decode(&result); err != nil { + t.Fatal(err) + } + + if result["hello"] != "world" { + t.Fatalf("got = %v, want %v", result["hello"], "world") + } + }) +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..55ab892 --- /dev/null +++ b/errors.go @@ -0,0 +1,68 @@ +package stormrpc + +import ( + "errors" + "fmt" +) + +type ErrorCode int + +const ( + ErrorCodeUnknown ErrorCode = 0 + ErrorCodeInternal ErrorCode = 1 + ErrorCodeNotFound ErrorCode = 2 +) + +func (c ErrorCode) String() string { + switch c { + case ErrorCodeInternal: + return "STORMRPC_CODE_INTERNAL" + case ErrorCodeNotFound: + return "STORMRPC_CODE_NOT_FOUND" + default: + return "STORMRPC_CODE_UNKNOWN" + } +} + +type Error struct { + Code ErrorCode + Message string +} + +func (e Error) Error() string { + return fmt.Sprintf("%s: %s", e.Code.String(), e.Message) +} + +func Errorf(code ErrorCode, format string, args ...any) *Error { + return &Error{ + Code: code, + Message: fmt.Sprintf(format, args...), + } +} + +func CodeFromErr(err error) ErrorCode { + var e *Error + if errors.As(err, &e) { + return e.Code + } + return ErrorCodeUnknown +} + +func MessageFromErr(err error) string { + var e *Error + if errors.As(err, &e) { + return e.Message + } + return "unknown error" +} + +func codeFromString(s string) ErrorCode { + switch s { + case "STORMRPC_CODE_INTERNAL": + return ErrorCodeInternal + case "STORMRPC_CODE_NOT_FOUND": + return ErrorCodeNotFound + default: + return ErrorCodeUnknown + } +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..183d02b --- /dev/null +++ b/errors_test.go @@ -0,0 +1,181 @@ +package stormrpc + +import ( + "fmt" + "testing" +) + +func TestErrorCode_String(t *testing.T) { + tests := []struct { + name string + c ErrorCode + want string + }{ + { + name: "unknown", + c: ErrorCodeUnknown, + want: "STORMRPC_CODE_UNKNOWN", + }, + { + name: "internal", + c: ErrorCodeInternal, + want: "STORMRPC_CODE_INTERNAL", + }, + { + name: "not found", + c: ErrorCodeNotFound, + want: "STORMRPC_CODE_NOT_FOUND", + }, + { + name: "default", + c: 10000, + want: "STORMRPC_CODE_UNKNOWN", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.c.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestError_Error(t *testing.T) { + type fields struct { + Code ErrorCode + Message string + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "print error message", + fields: fields{ + Code: ErrorCodeNotFound, + Message: "thing not found", + }, + want: "STORMRPC_CODE_NOT_FOUND: thing not found", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := Error{ + Code: tt.fields.Code, + Message: tt.fields.Message, + } + if got := e.Error(); got != tt.want { + t.Errorf("Error() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCodeFromErr(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + want ErrorCode + }{ + { + name: "non stormrpc error", + args: args{ + err: fmt.Errorf("howdy"), + }, + want: ErrorCodeUnknown, + }, + { + name: "stormrpc error", + args: args{ + err: Errorf(ErrorCodeNotFound, "hi"), + }, + want: ErrorCodeNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CodeFromErr(tt.args.err); got != tt.want { + t.Errorf("CodeFromErr() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMessageFromErr(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + want string + }{ + // TODO: Add test cases. + { + name: "non stormrpc error", + args: args{ + err: fmt.Errorf("hi"), + }, + want: "unknown error", + }, + { + name: "stormrpc error", + args: args{ + err: Errorf(ErrorCodeNotFound, "err"), + }, + want: "err", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := MessageFromErr(tt.args.err); got != tt.want { + t.Errorf("MessageFromErr() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_codeFromString(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want ErrorCode + }{ + { + name: "default", + args: args{ + s: "asijdfoaijdsfoaijdf", + }, + want: ErrorCodeUnknown, + }, + { + name: "internal", + args: args{ + s: "STORMRPC_CODE_INTERNAL", + }, + want: ErrorCodeInternal, + }, + { + name: "not found", + args: args{ + s: "STORMRPC_CODE_NOT_FOUND", + }, + want: ErrorCodeNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := codeFromString(tt.args.s); got != tt.want { + t.Errorf("codeFromString() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/examples/simple/client/main.go b/examples/simple/client/main.go new file mode 100644 index 0000000..64d2c19 --- /dev/null +++ b/examples/simple/client/main.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + "fmt" + "log" + "stormrpc" + "time" +) + +func main() { + client, err := stormrpc.NewClient("nats://0.0.0.0:40897") + if err != nil { + log.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + r, err := stormrpc.NewRequest(ctx, "echo", map[string]string{"hello": "me"}) + if err != nil { + log.Fatal(err) + } + + resp := client.Do(r) + if resp.Err != nil { + log.Fatal(resp.Err) + } + + fmt.Println(resp.Header) + + var result map[string]string + if err = resp.Decode(&result); err != nil { + log.Fatal(err) + } + + fmt.Printf("Result: %v\n", result) +} diff --git a/examples/simple/server/main.go b/examples/simple/server/main.go new file mode 100644 index 0000000..888aebf --- /dev/null +++ b/examples/simple/server/main.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "log" + "os" + "os/signal" + "stormrpc" + "syscall" + "time" + + "github.com/nats-io/nats-server/v2/server" +) + +func echo(req stormrpc.Request) stormrpc.Response { + var b any + if err := req.Decode(&b); err != nil { + return stormrpc.NewErrorResponse(req.Reply, err) + } + + resp, err := stormrpc.NewResponse(req.Reply, b) + if err != nil { + return stormrpc.NewErrorResponse(req.Reply, err) + } + + return resp +} + +func main() { + ns, err := server.NewServer(&server.Options{ + Port: 40897, + }) + if err != nil { + log.Fatal(err) + } + go ns.Start() + defer func() { + ns.Shutdown() + ns.WaitForShutdown() + }() + + if !ns.ReadyForConnections(1 * time.Second) { + log.Fatal("timeout waiting for nats server") + } + + srv, err := stormrpc.NewServer("echo", ns.ClientURL()) + if err != nil { + log.Fatal(err) + } + srv.Handle("echo", echo) + + go func() { + _ = srv.Run() + }() + log.Printf("👋 Listening on %v", srv.Subjects()) + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + <-done + log.Printf("💀 Shutting down") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err = srv.Shutdown(ctx); err != nil { + log.Fatal(err) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8bb0dc3 --- /dev/null +++ b/go.mod @@ -0,0 +1,23 @@ +module stormrpc + +go 1.18 + +require ( + github.com/nats-io/nats-server/v2 v2.8.4 + github.com/nats-io/nats.go v1.16.0 + github.com/vmihailenco/msgpack/v5 v5.3.5 + google.golang.org/protobuf v1.28.0 +) + +require ( + github.com/golang/protobuf v1.5.2 // indirect + github.com/klauspost/compress v1.14.4 // indirect + github.com/minio/highwayhash v1.0.2 // indirect + github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a // indirect + github.com/nats-io/nkeys v0.3.0 // indirect + github.com/nats-io/nuid v1.0.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect + golang.org/x/sys v0.0.0-20220111092808-5a964db01320 // indirect + golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..214fda9 --- /dev/null +++ b/go.sum @@ -0,0 +1,52 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4= +github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= +github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= +github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I= +github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= +github.com/nats-io/nats-server/v2 v2.8.4 h1:0jQzze1T9mECg8YZEl8+WYUXb9JKluJfCBriPUtluB4= +github.com/nats-io/nats-server/v2 v2.8.4/go.mod h1:8zZa+Al3WsESfmgSs98Fi06dRWLH5Bnq90m5bKD/eT4= +github.com/nats-io/nats.go v1.16.0 h1:zvLE7fGBQYW6MWaFaRdsgm9qT39PJDQoju+DS8KsO1g= +github.com/nats-io/nats.go v1.16.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= +github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= +github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= +golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220111092808-5a964db01320 h1:0jf+tOCoZ3LyutmCOWpVni1chK4VfFLhRsDK7MhqGRY= +golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/headers.go b/headers.go new file mode 100644 index 0000000..2f23822 --- /dev/null +++ b/headers.go @@ -0,0 +1,56 @@ +package stormrpc + +import ( + "strconv" + "strings" + "time" + + "github.com/nats-io/nats.go" +) + +const ( + errorHeader = "stormrpc-error" + deadlineHeader = "stormrpc-deadline" +) + +func parseDeadlineHeader(header nats.Header) time.Time { + dh := header.Get(deadlineHeader) + if dh == "" { + return time.Time{} + } + + i, err := strconv.ParseInt(dh, 10, 64) + if err != nil { + return time.Time{} + } + + return time.Unix(i, 0) +} + +func parseErrorHeader(header nats.Header) *Error { + eh := header.Get(errorHeader) + if eh == "" { + return nil + } + + sp := strings.Split(eh, ":") + + if len(sp) < 2 { + return &Error{ + Code: ErrorCodeUnknown, + Message: "unknown error", + } + } + + code := codeFromString(strings.TrimSpace(sp[0])) + msg := strings.TrimSpace(sp[1]) + + if code == ErrorCodeUnknown { + msg = "unknown error" + } + + return &Error{ + Code: code, + Message: msg, + } +} diff --git a/headers_test.go b/headers_test.go new file mode 100644 index 0000000..6396531 --- /dev/null +++ b/headers_test.go @@ -0,0 +1,117 @@ +package stormrpc + +import ( + "reflect" + "strconv" + "testing" + "time" + + "github.com/nats-io/nats.go" +) + +func Test_parseErrorHeader(t *testing.T) { + type args struct { + header nats.Header + } + tests := []struct { + name string + args args + want *Error + }{ + { + name: "no error header", + args: args{ + header: nats.Header{}, + }, + want: nil, + }, + { + name: "weirdly formatted error header", + args: args{ + header: nats.Header{ + errorHeader: []string{"BIG HEADER", "NICE ERROR"}, + }, + }, + want: &Error{ + Code: ErrorCodeUnknown, + Message: "unknown error", + }, + }, + { + name: "not found error", + args: args{ + header: nats.Header{ + errorHeader: []string{"STORMRPC_CODE_NOT_FOUND: new error"}, + }, + }, + want: &Error{ + Code: ErrorCodeNotFound, + Message: "new error", + }, + }, + { + name: "unknown error", + args: args{ + header: nats.Header{ + errorHeader: []string{"STORMRPC_CODE_UNKNOWN: xD"}, + }, + }, + want: &Error{ + Code: ErrorCodeUnknown, + Message: "unknown error", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseErrorHeader(tt.args.header); !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseErrorHeader() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseDeadlineHeader(t *testing.T) { + type args struct { + header nats.Header + } + tests := []struct { + name string + args args + want time.Time + }{ + // TODO: Add test cases. + { + name: "no header", + args: args{ + header: nats.Header{}, + }, + want: time.Time{}, + }, + { + name: "header non int", + args: args{ + header: nats.Header{ + deadlineHeader: []string{"bob"}, + }, + }, + want: time.Time{}, + }, + { + name: "header with unix time", + args: args{ + header: nats.Header{ + deadlineHeader: []string{strconv.FormatInt(time.Now().Round(1*time.Minute).Unix(), 10)}, + }, + }, + want: time.Now().Round(1 * time.Minute), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseDeadlineHeader(tt.args.header); !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseDeadlineHeader() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/prototest/test.pb.go b/prototest/test.pb.go new file mode 100644 index 0000000..836f48f --- /dev/null +++ b/prototest/test.pb.go @@ -0,0 +1,142 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.0 +// protoc v3.19.4 +// source: prototest/test.proto + +package prototest + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Greeting struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *Greeting) Reset() { + *x = Greeting{} + if protoimpl.UnsafeEnabled { + mi := &file_prototest_test_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Greeting) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Greeting) ProtoMessage() {} + +func (x *Greeting) ProtoReflect() protoreflect.Message { + mi := &file_prototest_test_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Greeting.ProtoReflect.Descriptor instead. +func (*Greeting) Descriptor() ([]byte, []int) { + return file_prototest_test_proto_rawDescGZIP(), []int{0} +} + +func (x *Greeting) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +var File_prototest_test_proto protoreflect.FileDescriptor + +var file_prototest_test_proto_rawDesc = []byte{ + 0x0a, 0x14, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x74, 0x65, 0x73, 0x74, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x24, 0x0a, 0x08, 0x47, 0x72, 0x65, 0x65, 0x74, 0x69, + 0x6e, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x42, 0x0d, 0x5a, 0x0b, + 0x2e, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x74, 0x65, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, +} + +var ( + file_prototest_test_proto_rawDescOnce sync.Once + file_prototest_test_proto_rawDescData = file_prototest_test_proto_rawDesc +) + +func file_prototest_test_proto_rawDescGZIP() []byte { + file_prototest_test_proto_rawDescOnce.Do(func() { + file_prototest_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_prototest_test_proto_rawDescData) + }) + return file_prototest_test_proto_rawDescData +} + +var file_prototest_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_prototest_test_proto_goTypes = []interface{}{ + (*Greeting)(nil), // 0: Greeting +} +var file_prototest_test_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_prototest_test_proto_init() } +func file_prototest_test_proto_init() { + if File_prototest_test_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_prototest_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Greeting); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_prototest_test_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_prototest_test_proto_goTypes, + DependencyIndexes: file_prototest_test_proto_depIdxs, + MessageInfos: file_prototest_test_proto_msgTypes, + }.Build() + File_prototest_test_proto = out.File + file_prototest_test_proto_rawDesc = nil + file_prototest_test_proto_goTypes = nil + file_prototest_test_proto_depIdxs = nil +} diff --git a/prototest/test.proto b/prototest/test.proto new file mode 100644 index 0000000..3b8f106 --- /dev/null +++ b/prototest/test.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +option go_package = ".;prototest"; + +message Greeting { + string message = 1; +} \ No newline at end of file diff --git a/request.go b/request.go new file mode 100644 index 0000000..a638e05 --- /dev/null +++ b/request.go @@ -0,0 +1,126 @@ +package stormrpc + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/nats-io/nats.go" + "github.com/vmihailenco/msgpack/v5" + "google.golang.org/protobuf/proto" +) + +type Request struct { + *nats.Msg + context.Context +} + +func NewRequest(ctx context.Context, subject string, body any, opts ...RequestOption) (*Request, error) { + options := requestOptions{ + encodeProto: false, + encodeMsgpack: false, + } + + for _, o := range opts { + o.apply(&options) + } + + var data []byte + var err error + var contentType string + + switch { + case options.encodeProto: + switch m := body.(type) { + case proto.Message: + data, err = proto.Marshal(m) + contentType = "application/protobuf" + default: + return nil, fmt.Errorf("failed to encode proto message: invalid type: %T", m) + } + case options.encodeMsgpack: + data, err = msgpack.Marshal(body) + contentType = "application/msgpack" + default: + data, err = json.Marshal(body) + contentType = "application/json" + } + if err != nil { + return nil, err + } + + headers := nats.Header{} + dl, ok := ctx.Deadline() + if ok { + headers.Set(deadlineHeader, strconv.FormatInt(dl.UnixNano(), 10)) + } + headers.Set("Content-Type", contentType) + msg := &nats.Msg{ + Data: data, + Subject: subject, + Header: headers, + } + + return &Request{ + Msg: msg, + Context: ctx, + }, nil +} + +type requestOptions struct { + encodeProto bool + encodeMsgpack bool +} + +type RequestOption interface { + apply(options *requestOptions) +} + +type encodeProtoOption bool + +func (p encodeProtoOption) apply(opts *requestOptions) { + opts.encodeProto = bool(p) +} + +func WithEncodeProto() RequestOption { + return encodeProtoOption(true) +} + +type encodeMsgpackOption bool + +func (p encodeMsgpackOption) apply(opts *requestOptions) { + opts.encodeMsgpack = bool(p) +} + +func WithEncodeMsgpack() RequestOption { + return encodeMsgpackOption(true) +} + +func (r *Request) Decode(v any) error { + var err error + + switch r.Header.Get("Content-Type") { + case "application/msgpack": + err = msgpack.Unmarshal(r.Data, v) + case "application/protobuf": + switch m := v.(type) { + case proto.Message: + err = proto.Unmarshal(r.Data, m) + default: + return fmt.Errorf("failed to decode proto message: invalid type :%T", v) + } + default: + err = json.Unmarshal(r.Data, v) + } + + if err != nil { + return fmt.Errorf("failed to decode request: %w", err) + } + + return nil +} + +func (r *Request) Subject() string { + return r.Msg.Subject +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..efec2a8 --- /dev/null +++ b/request_test.go @@ -0,0 +1,207 @@ +package stormrpc + +import ( + "context" + "encoding/json" + "reflect" + "testing" + + "stormrpc/prototest" + + "github.com/nats-io/nats.go" + "github.com/vmihailenco/msgpack/v5" + "google.golang.org/protobuf/proto" +) + +func TestNewRequest(t *testing.T) { + t.Run("defaults", func(t *testing.T) { + body := map[string]string{"hello": "world"} + data, _ := json.Marshal(body) + ctx := context.Background() + + want := &Request{ + Msg: &nats.Msg{ + Subject: "test", + Header: nats.Header{ + "Content-Type": []string{"application/json"}, + }, + Data: data, + }, + Context: ctx, + } + + got, err := NewRequest(ctx, "test", body) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("NewRequest() got = %v, want %v", got, want) + } + }) + + t.Run("msgpack option", func(t *testing.T) { + body := map[string]string{"hello": "world"} + data, _ := msgpack.Marshal(body) + ctx := context.Background() + + want := &Request{ + Msg: &nats.Msg{ + Subject: "test", + Header: nats.Header{ + "Content-Type": []string{"application/msgpack"}, + }, + Data: data, + }, + Context: ctx, + } + + got, err := NewRequest(ctx, "test", body, WithEncodeMsgpack()) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("NewRequest() got = %v, want %v", got, want) + } + }) + + t.Run("proto option", func(t *testing.T) { + body := &prototest.Greeting{Message: "hello"} + data, _ := proto.Marshal(body) + ctx := context.Background() + + want := &Request{ + Msg: &nats.Msg{ + Subject: "test", + Header: nats.Header{ + "Content-Type": []string{"application/protobuf"}, + }, + Data: data, + }, + Context: ctx, + } + + got, err := NewRequest(ctx, "test", body, WithEncodeProto()) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("NewRequest() got = %v, want %v", got, want) + } + }) + + t.Run("proto option w/non proto message", func(t *testing.T) { + body := map[string]string{"hello": "world"} + + _, err := NewRequest(context.Background(), "test", body, WithEncodeProto()) + if err == nil { + t.Fatal("expected error got nil") + } + }) +} + +func TestRequest_Decode(t *testing.T) { + t.Run("decode json", func(t *testing.T) { + body := map[string]string{"hello": "world"} + r, err := NewRequest(context.Background(), "test", body) + if err != nil { + t.Fatal(err) + } + + var got map[string]string + if err = r.Decode(&got); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, body) { + t.Fatalf("NewRequest() got = %v, want %v", got, body) + } + }) + + t.Run("decode msgpack", func(t *testing.T) { + body := map[string]string{"hello": "world"} + r, err := NewRequest(context.Background(), "test", body, WithEncodeMsgpack()) + if err != nil { + t.Fatal(err) + } + + var got map[string]string + if err = r.Decode(&got); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, body) { + t.Fatalf("NewRequest() got = %v, want %v", got, body) + } + }) + + t.Run("decode proto", func(t *testing.T) { + body := &prototest.Greeting{Message: "hi"} + r, err := NewRequest(context.Background(), "test", body, WithEncodeProto()) + if err != nil { + t.Fatal(err) + } + + var got prototest.Greeting + if err = r.Decode(&got); err != nil { + t.Fatal(err) + } + + if got.GetMessage() != body.GetMessage() { + t.Fatalf("got = %v, want %v", got.GetMessage(), body.GetMessage()) + } + + //if !reflect.DeepEqual(&got, body) { + // t.Fatalf("NewRequest() got = %v, want %v", &got, body) + //} + }) + + t.Run("decode proto w/non proto message", func(t *testing.T) { + body := &prototest.Greeting{Message: "hello"} + r, err := NewRequest(context.Background(), "test", body, WithEncodeProto()) + if err != nil { + t.Fatal(err) + } + + var got map[string]string + err = r.Decode(&got) + if err == nil { + t.Fatal(err) + } + }) +} + +func TestRequest_Subject(t *testing.T) { + type fields struct { + Msg *nats.Msg + Context context.Context + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "return msg subject", + fields: fields{ + Msg: &nats.Msg{ + Subject: "me", + }, + }, + want: "me", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &Request{ + Msg: tt.fields.Msg, + Context: tt.fields.Context, + } + if got := r.Subject(); got != tt.want { + t.Errorf("Subject() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..dc1d56c --- /dev/null +++ b/response.go @@ -0,0 +1,101 @@ +package stormrpc + +import ( + "encoding/json" + "fmt" + + "github.com/nats-io/nats.go" + "github.com/vmihailenco/msgpack/v5" + "google.golang.org/protobuf/proto" +) + +// TODO: Create constructor like request that handles encoding etc. + +type Response struct { + *nats.Msg + Err error +} + +func NewResponse(reply string, body any, opts ...ResponseOption) (Response, error) { + options := requestOptions{ + encodeProto: false, + encodeMsgpack: false, + } + + for _, o := range opts { + o.apply(&options) + } + + var data []byte + var err error + var contentType string + + switch { + case options.encodeProto: + switch m := body.(type) { + case proto.Message: + data, err = proto.Marshal(m) + contentType = "application/protobuf" + default: + return Response{}, fmt.Errorf("failed to encode proto message: invalid type: %T", m) + } + case options.encodeMsgpack: + data, err = msgpack.Marshal(body) + contentType = "application/msgpack" + default: + data, err = json.Marshal(body) + contentType = "application/json" + } + if err != nil { + return Response{}, err + } + + headers := nats.Header{} + headers.Set("Content-Type", contentType) + + msg := &nats.Msg{ + Subject: reply, + Header: headers, + Data: data, + } + + return Response{ + Msg: msg, + Err: nil, + }, nil +} + +func NewErrorResponse(reply string, err error) Response { + return Response{ + Msg: &nats.Msg{ + Subject: reply, + }, + Err: err, + } +} + +type ResponseOption RequestOption + +func (r *Response) Decode(v any) error { + var err error + + switch r.Header.Get("Content-Type") { + case "application/msgpack": + err = msgpack.Unmarshal(r.Data, v) + case "application/protobuf": + switch m := v.(type) { + case proto.Message: + err = proto.Unmarshal(r.Data, m) + default: + return fmt.Errorf("failed to decode proto message: invalid type: %T", m) + } + default: + err = json.Unmarshal(r.Data, v) + } + + if err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + + return nil +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..15bf6fc --- /dev/null +++ b/response_test.go @@ -0,0 +1,235 @@ +package stormrpc + +import ( + "encoding/json" + "fmt" + "reflect" + "stormrpc/prototest" + "testing" + + "github.com/nats-io/nats.go" + "github.com/vmihailenco/msgpack/v5" + "google.golang.org/protobuf/proto" +) + +func TestResponse_Decode(t *testing.T) { + t.Run("decode json", func(t *testing.T) { + body := map[string]string{"hello": "world"} + data, _ := json.Marshal(body) + resp := &Response{ + Msg: &nats.Msg{ + Header: nats.Header{ + "Content-Type": []string{"application/json"}, + }, + Data: data, + }, + Err: nil, + } + + var got map[string]string + if err := resp.Decode(&got); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, body) { + t.Fatalf("NewRequest() got = %v, want %v", got, body) + } + }) + + t.Run("decode msgpack", func(t *testing.T) { + body := map[string]string{"hello": "world"} + data, _ := msgpack.Marshal(body) + resp := &Response{ + Msg: &nats.Msg{ + Header: nats.Header{ + "Content-Type": []string{"application/msgpack"}, + }, + Data: data, + }, + Err: nil, + } + + var got map[string]string + if err := resp.Decode(&got); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, body) { + t.Fatalf("NewRequest() got = %v, want %v", got, body) + } + }) + + t.Run("decode proto", func(t *testing.T) { + body := &prototest.Greeting{Message: "hi"} + data, _ := proto.Marshal(body) + resp := &Response{ + Msg: &nats.Msg{ + Header: nats.Header{ + "Content-Type": []string{"application/protobuf"}, + }, + Data: data, + }, + Err: nil, + } + + var got prototest.Greeting + if err := resp.Decode(&got); err != nil { + t.Fatal(err) + } + + if got.GetMessage() != body.GetMessage() { + t.Fatalf("got = %v, want %v", got.GetMessage(), body.GetMessage()) + } + }) + + t.Run("decode proto w/non proto message", func(t *testing.T) { + body := map[string]string{"hello": "world"} + data, _ := json.Marshal(body) + resp := &Response{ + Msg: &nats.Msg{ + Header: nats.Header{ + "Content-Type": []string{"application/protobuf"}, + }, + Data: data, + }, + Err: nil, + } + + var got prototest.Greeting + err := resp.Decode(&got) + if err == nil { + t.Fatal("expected error got nil") + } + }) +} + +func TestNewErrorResponse(t *testing.T) { + type args struct { + reply string + err error + } + tests := []struct { + name string + args args + want Response + }{ + { + name: "non stormrpc error", + args: args{ + reply: "test", + err: fmt.Errorf("10"), + }, + want: Response{ + Msg: &nats.Msg{ + Subject: "test", + }, + Err: fmt.Errorf("10"), + }, + }, + { + name: "stormrpc error", + args: args{ + reply: "test", + err: Errorf(ErrorCodeNotFound, "hi"), + }, + want: Response{ + Msg: &nats.Msg{ + Subject: "test", + }, + Err: Errorf(ErrorCodeNotFound, "hi"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewErrorResponse(tt.args.reply, tt.args.err); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewErrorResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewResponse(t *testing.T) { + t.Run("defaults", func(t *testing.T) { + body := map[string]string{"hello": "world"} + data, _ := json.Marshal(body) + + want := Response{ + Msg: &nats.Msg{ + Subject: "test", + Header: nats.Header{ + "Content-Type": []string{"application/json"}, + }, + Data: data, + }, + Err: nil, + } + + got, err := NewResponse("test", body) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("NewResponse() got = %v, want %v", got, want) + } + }) + + t.Run("with msgpack", func(t *testing.T) { + body := map[string]string{"hello": "world"} + data, _ := msgpack.Marshal(body) + + want := Response{ + Msg: &nats.Msg{ + Subject: "test", + Header: nats.Header{ + "Content-Type": []string{"application/msgpack"}, + }, + Data: data, + }, + Err: nil, + } + + got, err := NewResponse("test", body, WithEncodeMsgpack()) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("NewResponse() got = %v, want %v", got, want) + } + }) + + t.Run("proto option", func(t *testing.T) { + body := &prototest.Greeting{Message: "hello"} + data, _ := proto.Marshal(body) + + want := Response{ + Msg: &nats.Msg{ + Subject: "test", + Header: nats.Header{ + "Content-Type": []string{"application/protobuf"}, + }, + Data: data, + }, + Err: nil, + } + + got, err := NewResponse("test", body, WithEncodeProto()) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("NewResponse() got = %v, want %v", got, want) + } + }) + + t.Run("proto option w/non proto message", func(t *testing.T) { + body := map[string]string{"hello": "world"} + _, err := NewResponse("test", body, WithEncodeProto()) + if err == nil { + t.Fatal("expected error got nil") + } + }) +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..e3edb0f --- /dev/null +++ b/server.go @@ -0,0 +1,150 @@ +package stormrpc + +import ( + "context" + "fmt" + "time" + + "github.com/nats-io/nats.go" +) + +var defaultServerTimeout = 5 * time.Second + +type Server struct { + nc *nats.Conn + name string + shutdownSignal chan struct{} + handlerFuncs map[string]HandlerFunc + errorHandler ErrorHandler + timeout time.Duration +} + +func NewServer(name, natsURL string, opts ...ServerOption) (*Server, error) { + options := serverOptions{ + errorHandler: func(ctx context.Context, err error) {}, + } + + for _, o := range opts { + o.apply(&options) + } + + nc, err := nats.Connect(natsURL) + if err != nil { + return nil, err + } + + return &Server{ + nc: nc, + name: name, + shutdownSignal: make(chan struct{}), + handlerFuncs: make(map[string]HandlerFunc), + timeout: defaultServerTimeout, + errorHandler: options.errorHandler, + }, nil +} + +type serverOptions struct { + errorHandler ErrorHandler +} + +type ServerOption interface { + apply(*serverOptions) +} + +type errorHandlerOption ErrorHandler + +func (h errorHandlerOption) apply(opts *serverOptions) { + opts.errorHandler = ErrorHandler(h) +} + +func WithErrorHandler(fn ErrorHandler) ServerOption { + return errorHandlerOption(fn) +} + +type HandlerFunc func(Request) Response + +type ErrorHandler func(context.Context, error) + +func (s *Server) Handle(subject string, fn HandlerFunc) { + s.handlerFuncs[subject] = fn +} + +// Run listens on the configured subjects. +func (s *Server) Run() error { + for k := range s.handlerFuncs { + _, err := s.nc.QueueSubscribe(k, s.name, s.handler) + if err != nil { + return err + } + } + + <-s.shutdownSignal + return nil +} + +// Shutdown stops the server. +func (s *Server) Shutdown(ctx context.Context) error { + if err := s.nc.FlushWithContext(ctx); err != nil { + return err + } + + s.nc.Close() + s.shutdownSignal <- struct{}{} + return nil +} + +// Subjects returns a list of all subjects with registered handler funcs. +func (s *Server) Subjects() []string { + subs := make([]string, 0, len(s.handlerFuncs)) + for k := range s.handlerFuncs { + subs = append(subs, k) + } + + return subs +} + +// handler serves the request to the specific request handler based on subject. +// wildcard subjects are not supported as you'll need to register a handler func for each +// rpc the server supports. +func (s *Server) handler(msg *nats.Msg) { + // TODO: remove this Printf + fmt.Printf("received msg on subject: %s = %s\n", msg.Subject, string(msg.Data)) + + fn := s.handlerFuncs[msg.Subject] + + ctx, cancel := context.WithTimeout(context.Background(), s.timeout) + defer cancel() + + dl := parseDeadlineHeader(msg.Header) + if !dl.IsZero() { // if deadline is present use it + ctx, cancel = context.WithDeadline(context.Background(), dl) + defer cancel() + } + + req := Request{ + Msg: msg, + Context: ctx, + } + + resp := fn(req) + + if resp.Err != nil { + if resp.Header == nil { + resp.Header = nats.Header{} + } + resp.Header.Set(errorHeader, resp.Err.Error()) + err := msg.RespondMsg(resp.Msg) + if err != nil { + s.errorHandler(ctx, err) + // TODO: remove the Printf + fmt.Printf("msg.RespondMsg: %v\n", err) + } + } + + err := msg.RespondMsg(resp.Msg) + if err != nil { + s.errorHandler(ctx, err) + // TODO: remove the Printf + fmt.Printf("msg.RespondMsg: %v\n", err) + } +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..b496038 --- /dev/null +++ b/server_test.go @@ -0,0 +1,170 @@ +package stormrpc + +import ( + "context" + "math/rand" + "reflect" + "strconv" + "testing" + "time" + + "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" +) + +func TestServer_RunAndShutdown(t *testing.T) { + ns, err := server.NewServer(&server.Options{ + Port: 40897, + }) + if err != nil { + t.Fatal(err) + } + go ns.Start() + t.Cleanup(func() { + ns.Shutdown() + ns.WaitForShutdown() + }) + + if !ns.ReadyForConnections(1 * time.Second) { + t.Error("timeout waiting for nats server") + return + } + + srv, err := NewServer("test", ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + + runCh := make(chan error) + go func(ch chan error) { + runErr := srv.Run() + runCh <- runErr + }(runCh) + time.Sleep(250 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + if err = srv.Shutdown(ctx); err != nil { + t.Fatal(err) + } + + err = <-runCh + if err != nil { + t.Fatal(err) + } +} + +func TestServer_handler(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + ns, err := server.NewServer(&server.Options{ + Port: 40897, + }) + if err != nil { + t.Fatal(err) + } + go ns.Start() + t.Cleanup(func() { + ns.Shutdown() + ns.WaitForShutdown() + }) + + if !ns.ReadyForConnections(1 * time.Second) { + t.Error("timeout waiting for nats server") + return + } + + t.Run("successful handle", func(t *testing.T) { + t.Parallel() + + srv, err := NewServer("test", ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + + subject := strconv.Itoa(rand.Int()) + srv.Handle(subject, func(r Request) Response { + return Response{ + Msg: &nats.Msg{ + Subject: r.Reply, + Data: []byte(`{"response":"1"}`), + }, + Err: nil, + } + }) + + runCh := make(chan error) + go func(ch chan error) { + runErr := srv.Run() + runCh <- runErr + }(runCh) + time.Sleep(250 * time.Millisecond) + + client, err := NewClient(ns.ClientURL()) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + req, err := NewRequest(ctx, subject, map[string]string{"x": "D"}) + if err != nil { + t.Fatal(err) + } + resp := client.Do(req) + if resp.Err != nil { + t.Fatal(resp.Err) + } + + var result map[string]string + if err = resp.Decode(&result); err != nil { + t.Fatal(err) + } + + if result["response"] != "1" { + t.Fatalf("got = %v, want %v", result["response"], "1") + } + + if err = srv.Shutdown(ctx); err != nil { + t.Fatal(err) + } + + err = <-runCh + if err != nil { + t.Fatal(err) + } + }) +} + +func TestServer_Handle(t *testing.T) { + s := Server{ + handlerFuncs: make(map[string]HandlerFunc), + } + + t.Run("OK", func(t *testing.T) { + s.Handle("testing", func(r Request) Response { return Response{} }) + + if _, ok := s.handlerFuncs["testing"]; !ok { + t.Fatal("expected key testing to contain a handler func") + } + }) +} + +func TestServer_Subjects(t *testing.T) { + s := Server{ + handlerFuncs: make(map[string]HandlerFunc), + } + + s.Handle("testing", func(r Request) Response { return Response{} }) + s.Handle("testing", func(r Request) Response { return Response{} }) + s.Handle("1, 2, 3", func(r Request) Response { return Response{} }) + + expected := []string{"testing", "1, 2, 3"} + + got := s.Subjects() + + if !reflect.DeepEqual(got, expected) { + t.Fatalf("got = %v, want %v", got, expected) + } +}