diff --git a/client.go b/client.go index 63d4e39..cc3a7aa 100644 --- a/client.go +++ b/client.go @@ -38,7 +38,22 @@ func (c *Client) Close() { } // Do completes a request to a stormRPC Server. -func (c *Client) Do(ctx context.Context, r Request) Response { +func (c *Client) Do(ctx context.Context, r Request, opts ...CallOption) Response { + options := callOptions{ + headers: make(map[string]string), + } + for _, o := range opts { + err := o.before(&options) + if err != nil { + return Response{ + Msg: &nats.Msg{}, + Err: err, + } + } + } + + applyOptions(&r, &options) + msg, err := c.nc.RequestMsgWithContext(ctx, r.Msg) if errors.Is(err, nats.ErrNoResponders) { return Response{ @@ -67,3 +82,9 @@ func (c *Client) Do(ctx context.Context, r Request) Response { Err: nil, } } + +func applyOptions(r *Request, options *callOptions) { + for k, v := range options.headers { + r.Header.Set(k, v) + } +} diff --git a/context.go b/context.go new file mode 100644 index 0000000..25ad17b --- /dev/null +++ b/context.go @@ -0,0 +1,28 @@ +package stormrpc + +import ( + "context" + + "github.com/nats-io/nats.go" +) + +type ctxKey int + +const ( + headerContextKey ctxKey = iota +) + +// HeadersFromContext retrieves RPC headers from the given context. +func HeadersFromContext(ctx context.Context) nats.Header { + h, ok := ctx.Value(headerContextKey).(nats.Header) + if !ok { + return make(nats.Header) + } + + return h +} + +// newContextWithHeaders creates a new context with Header information stored in it. +func newContextWithHeaders(ctx context.Context, headers nats.Header) context.Context { + return context.WithValue(ctx, headerContextKey, headers) +} diff --git a/examples/protogen/client/main.go b/examples/protogen/client/main.go index 46c7d49..57442e6 100644 --- a/examples/protogen/client/main.go +++ b/examples/protogen/client/main.go @@ -22,7 +22,8 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - out, err := c.Echo(ctx, &pb.EchoRequest{Message: "protogen"}) + headers := map[string]string{"Authorization": "Bearer xy.eay"} + out, err := c.Echo(ctx, &pb.EchoRequest{Message: "protogen"}, stormrpc.WithHeaders(headers)) if err != nil { log.Fatal(err) } diff --git a/examples/protogen/pb/echo_stormrpc.pb.go b/examples/protogen/pb/echo_stormrpc.pb.go index 368f4f8..bbcf082 100644 --- a/examples/protogen/pb/echo_stormrpc.pb.go +++ b/examples/protogen/pb/echo_stormrpc.pb.go @@ -10,7 +10,7 @@ import ( // EchoerClient is the client API for Echoer service. type EchoerClient interface { - Echo(ctx context.Context, in *EchoRequest) (*EchoResponse, error) + Echo(ctx context.Context, in *EchoRequest, opts ...stormrpc.CallOption) (*EchoResponse, error) } type echoerClient struct { @@ -21,14 +21,14 @@ func NewEchoerClient(c *stormrpc.Client) EchoerClient { return &echoerClient{c} } -func (c *echoerClient) Echo(ctx context.Context, in *EchoRequest) (*EchoResponse, error) { +func (c *echoerClient) Echo(ctx context.Context, in *EchoRequest, opts ...stormrpc.CallOption) (*EchoResponse, error) { var out EchoResponse r, err := stormrpc.NewRequest("rpc.Echoer.Echo", in, stormrpc.WithEncodeProto()) if err != nil { return nil, err } - resp := c.c.Do(ctx, r) + resp := c.c.Do(ctx, r, opts...) if resp.Err != nil { return nil, resp.Err } diff --git a/examples/protogen/server/main.go b/examples/protogen/server/main.go index 6a67f78..aa61ddb 100644 --- a/examples/protogen/server/main.go +++ b/examples/protogen/server/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "log" "os" "os/signal" @@ -16,6 +17,8 @@ import ( type echoServer struct{} func (s *echoServer) Echo(ctx context.Context, in *pb.EchoRequest) (*pb.EchoResponse, error) { + h := stormrpc.HeadersFromContext(ctx) + fmt.Printf("headers: %v\n", h) return &pb.EchoResponse{ Message: in.GetMessage(), }, nil diff --git a/internal/gen/gen.go b/internal/gen/gen.go index 09efd13..9713733 100644 --- a/internal/gen/gen.go +++ b/internal/gen/gen.go @@ -149,7 +149,7 @@ func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string if !method.Desc.IsStreamingClient() { s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent) } - s += ") (" + s += ", opts ...stormrpc.CallOption) (" if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { s += "*" + g.QualifiedGoIdent(method.Output.GoIdent) } else { @@ -177,7 +177,7 @@ func genClientMethod( g.P(`r, err := stormrpc.NewRequest("` + routeSignature(service, method) + `", in, stormrpc.WithEncodeProto())`) g.P("if err != nil { return nil, err }") g.P() - g.P("resp := c.c.Do(ctx, r)") + g.P("resp := c.c.Do(ctx, r, opts...)") g.P("if resp.Err != nil { return nil, resp.Err }") g.P() g.P("if err = resp.Decode(&out); err != nil { return nil, err }") diff --git a/options.go b/options.go new file mode 100644 index 0000000..b1ebeb5 --- /dev/null +++ b/options.go @@ -0,0 +1,34 @@ +package stormrpc + +// CallOption configures an RPC to perform actions before it starts or after +// the RPC has completed. +type CallOption interface { + // before is called before the RPC is sent to any server. + // If before returns a non-nil error, the RPC fails with that error. + before(*callOptions) error + + // after is called after the RPC has completed after cannot return an error. + after(*callOptions) +} + +// callOptions contains all configuration for an RPC. +type callOptions struct { + headers map[string]string +} + +// HeaderCallOption is used to configure which headers to append to the outgoing RPC. +type HeaderCallOption struct { + Headers map[string]string +} + +func (o *HeaderCallOption) before(c *callOptions) error { + c.headers = o.Headers + return nil +} + +func (o *HeaderCallOption) after(c *callOptions) {} + +// WithHeaders returns a CallOption that appends the given headers to the request. +func WithHeaders(h map[string]string) CallOption { + return &HeaderCallOption{Headers: h} +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..6825e30 --- /dev/null +++ b/options_test.go @@ -0,0 +1,41 @@ +package stormrpc + +import ( + "reflect" + "testing" +) + +func TestWithHeaders(t *testing.T) { + type args struct { + h map[string]string + } + tests := []struct { + name string + args args + want CallOption + }{ + // TODO: Add test cases. + { + name: "add some headers", + args: args{ + h: map[string]string{ + "Authorization": "Bearer ey.xyz", + "X-Request-Id": "abc", + }, + }, + want: &HeaderCallOption{ + Headers: map[string]string{ + "Authorization": "Bearer ey.xyz", + "X-Request-Id": "abc", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := WithHeaders(tt.args.h); !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithHeaders() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/prototest/protoc.sh b/prototest/protoc.sh index 25eeb7c..52421d5 100755 --- a/prototest/protoc.sh +++ b/prototest/protoc.sh @@ -1,3 +1,3 @@ go install ./protoc-gen-stormrpc protoc --proto_path prototest -I=. prototest/test.proto \ - --stormrpc_out=./prototest/gen_out --go_out=./prototest \ No newline at end of file + --stormrpc_out=./prototest/gen_out --go_out=./prototest diff --git a/server.go b/server.go index 3bf7393..e27df85 100644 --- a/server.go +++ b/server.go @@ -147,6 +147,7 @@ func (s *Server) handler(msg *nats.Msg) { req := Request{ Msg: msg, } + ctx = newContextWithHeaders(ctx, req.Header) resp := fn(ctx, req)