diff --git a/.gitignore b/.gitignore index b9268f8..e5f023c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ cmd -.idea \ No newline at end of file +.idea +prototest/gen_out \ No newline at end of file diff --git a/README.md b/README.md index ab6a6f7..c3fafbd 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,26 @@ It provides some convenient features including: Responses have an `Error` attribute and these are propagated across the wire without needing to tweak your request/response schemas. +## Installation + +### Runtime Library + +The runtime library package ```github.com/actatum/stormrpc``` contains common types like ```stormrpc.Error```, ```stormrpc.Client``` and ```stormrpc.Server```. If you aren't generating servers and clients from protobuf definitions you only need to import the stormrpc package. + +```bash +$ go get github.com/actatum/stormrpc +``` + +### Code Generator + +You need to install ```go``` and the ```protoc``` compiler on your system. Then, install the protoc plugins ```protoc-gen-stormrpc``` and ```protoc-gen-go``` to generate Go code. + +```bash +$ go install github.com/actatum/stormrpc/protoc-gen-stormrpc +$ go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.28 +``` +Code generation examples can be found [here](https://github.com/actatum/stormrpc/tree/main/examples/protogen) + ## Basic Usage ### Server diff --git a/examples/protogen/client/main.go b/examples/protogen/client/main.go new file mode 100644 index 0000000..46c7d49 --- /dev/null +++ b/examples/protogen/client/main.go @@ -0,0 +1,31 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/actatum/stormrpc" + "github.com/actatum/stormrpc/examples/protogen/pb" +) + +func main() { + client, err := stormrpc.NewClient("nats://0.0.0.0:40897") + if err != nil { + log.Fatal(err) + } + defer client.Close() + + c := pb.NewEchoerClient(client) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + out, err := c.Echo(ctx, &pb.EchoRequest{Message: "protogen"}) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Response: %s\n", out.GetMessage()) +} diff --git a/examples/protogen/genproto.sh b/examples/protogen/genproto.sh new file mode 100755 index 0000000..65f31cb --- /dev/null +++ b/examples/protogen/genproto.sh @@ -0,0 +1,2 @@ +go install ../../protoc-gen-stormrpc +protoc --go_out=./pb --stormrpc_out=./pb pb/echo.proto \ No newline at end of file diff --git a/examples/protogen/pb/echo.pb.go b/examples/protogen/pb/echo.pb.go new file mode 100644 index 0000000..a93c605 --- /dev/null +++ b/examples/protogen/pb/echo.pb.go @@ -0,0 +1,210 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.0 +// protoc v3.6.1 +// source: pb/echo.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type EchoRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *EchoRequest) Reset() { + *x = EchoRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_echo_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoRequest) ProtoMessage() {} + +func (x *EchoRequest) ProtoReflect() protoreflect.Message { + mi := &file_pb_echo_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoRequest.ProtoReflect.Descriptor instead. +func (*EchoRequest) Descriptor() ([]byte, []int) { + return file_pb_echo_proto_rawDescGZIP(), []int{0} +} + +func (x *EchoRequest) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type EchoResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *EchoResponse) Reset() { + *x = EchoResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_echo_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EchoResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EchoResponse) ProtoMessage() {} + +func (x *EchoResponse) ProtoReflect() protoreflect.Message { + mi := &file_pb_echo_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EchoResponse.ProtoReflect.Descriptor instead. +func (*EchoResponse) Descriptor() ([]byte, []int) { + return file_pb_echo_proto_rawDescGZIP(), []int{1} +} + +func (x *EchoResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +var File_pb_echo_proto protoreflect.FileDescriptor + +var file_pb_echo_proto_rawDesc = []byte{ + 0x0a, 0x0d, 0x70, 0x62, 0x2f, 0x65, 0x63, 0x68, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x67, 0x65, 0x6e, 0x22, 0x27, 0x0a, 0x0b, 0x45, 0x63, 0x68, + 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x22, 0x28, 0x0a, 0x0c, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x3f, 0x0a, 0x06, + 0x45, 0x63, 0x68, 0x6f, 0x65, 0x72, 0x12, 0x35, 0x0a, 0x04, 0x45, 0x63, 0x68, 0x6f, 0x12, 0x15, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x67, 0x65, 0x6e, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x67, 0x65, 0x6e, + 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x06, 0x5a, + 0x04, 0x2e, 0x3b, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pb_echo_proto_rawDescOnce sync.Once + file_pb_echo_proto_rawDescData = file_pb_echo_proto_rawDesc +) + +func file_pb_echo_proto_rawDescGZIP() []byte { + file_pb_echo_proto_rawDescOnce.Do(func() { + file_pb_echo_proto_rawDescData = protoimpl.X.CompressGZIP(file_pb_echo_proto_rawDescData) + }) + return file_pb_echo_proto_rawDescData +} + +var file_pb_echo_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_pb_echo_proto_goTypes = []interface{}{ + (*EchoRequest)(nil), // 0: protogen.EchoRequest + (*EchoResponse)(nil), // 1: protogen.EchoResponse +} +var file_pb_echo_proto_depIdxs = []int32{ + 0, // 0: protogen.Echoer.Echo:input_type -> protogen.EchoRequest + 1, // 1: protogen.Echoer.Echo:output_type -> protogen.EchoResponse + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_pb_echo_proto_init() } +func file_pb_echo_proto_init() { + if File_pb_echo_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pb_echo_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_echo_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EchoResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pb_echo_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_pb_echo_proto_goTypes, + DependencyIndexes: file_pb_echo_proto_depIdxs, + MessageInfos: file_pb_echo_proto_msgTypes, + }.Build() + File_pb_echo_proto = out.File + file_pb_echo_proto_rawDesc = nil + file_pb_echo_proto_goTypes = nil + file_pb_echo_proto_depIdxs = nil +} diff --git a/examples/protogen/pb/echo.proto b/examples/protogen/pb/echo.proto new file mode 100644 index 0000000..d0f7646 --- /dev/null +++ b/examples/protogen/pb/echo.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package protogen; + +option go_package = ".;pb"; + +service Echoer { + rpc Echo (EchoRequest) returns (EchoResponse); +} + +message EchoRequest { + string message = 1; +} + +message EchoResponse { + string message = 1; +} \ No newline at end of file diff --git a/examples/protogen/pb/echo_stormrpc.pb.go b/examples/protogen/pb/echo_stormrpc.pb.go new file mode 100644 index 0000000..368f4f8 --- /dev/null +++ b/examples/protogen/pb/echo_stormrpc.pb.go @@ -0,0 +1,95 @@ +// Code generated by protoc-gen-stormrpc. DO NOT EDIT. + +package pb + +import ( + context "context" + fmt "fmt" + stormrpc "github.com/actatum/stormrpc" +) + +// EchoerClient is the client API for Echoer service. +type EchoerClient interface { + Echo(ctx context.Context, in *EchoRequest) (*EchoResponse, error) +} + +type echoerClient struct { + c *stormrpc.Client +} + +func NewEchoerClient(c *stormrpc.Client) EchoerClient { + return &echoerClient{c} +} + +func (c *echoerClient) Echo(ctx context.Context, in *EchoRequest) (*EchoResponse, error) { + var out EchoResponse + r, err := stormrpc.NewRequest("rpc.Echoer.Echo", in, stormrpc.WithEncodeProto()) + if err != nil { + return nil, err + } + + resp := c.c.Do(ctx, r) + if resp.Err != nil { + return nil, resp.Err + } + + if err = resp.Decode(&out); err != nil { + return nil, err + } + + return &out, nil +} + +// EchoerServer is the server API for Echoer service. +type EchoerServer interface { + Echo(context.Context, *EchoRequest) (*EchoResponse, error) +} + +func RegisterEchoerServer(s *stormrpc.Server, srv EchoerServer) { + for _, handler := range echoerHandlers { + handler.SetService(srv) + s.Handle(handler.Route(), handler.HandlerFunc()) + } +} + +type _Echoer_Echo_Handler struct { + route string + svc interface{} +} + +func (h *_Echoer_Echo_Handler) HandlerFunc() stormrpc.HandlerFunc { + return func(ctx context.Context, r stormrpc.Request) stormrpc.Response { + var in EchoRequest + if err := r.Decode(&in); err != nil { + return stormrpc.NewErrorResponse(r.Reply, fmt.Errorf("error decoding request")) + } + + out, err := h.svc.(EchoerServer).Echo(ctx, &in) + if err != nil { + return stormrpc.NewErrorResponse(r.Reply, err) + } + + resp, err := stormrpc.NewResponse(r.Reply, out, stormrpc.WithEncodeProto()) + if err != nil { + return stormrpc.NewErrorResponse(r.Reply, err) + } + + return resp + } +} +func (h *_Echoer_Echo_Handler) Route() string { + return h.route +} +func (h *_Echoer_Echo_Handler) SetService(svc interface{}) { + h.svc = svc +} + +type handler interface { + Route() string + HandlerFunc() stormrpc.HandlerFunc + SetService(interface{}) +} + +var echoerHandlers = []handler{ + &_Echoer_Echo_Handler{route: "rpc.Echoer.Echo"}, +} diff --git a/examples/protogen/server/main.go b/examples/protogen/server/main.go new file mode 100644 index 0000000..6a67f78 --- /dev/null +++ b/examples/protogen/server/main.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/actatum/stormrpc" + "github.com/actatum/stormrpc/examples/protogen/pb" + "github.com/nats-io/nats-server/v2/server" +) + +type echoServer struct{} + +func (s *echoServer) Echo(ctx context.Context, in *pb.EchoRequest) (*pb.EchoResponse, error) { + return &pb.EchoResponse{ + Message: in.GetMessage(), + }, nil +} + +func main() { + ns, err := server.NewServer(&server.Options{ + Port: 40897, + }) + if err != nil { + log.Fatal(err) + } + go ns.Start() + defer func() { + ns.Shutdown() + ns.WaitForShutdown() + }() + + if !ns.ReadyForConnections(1 * time.Second) { + log.Fatal("timeout waiting for nats server") + } + + srv, err := stormrpc.NewServer("echo", ns.ClientURL(), stormrpc.WithErrorHandler(logError)) + if err != nil { + log.Fatal(err) + } + + svc := &echoServer{} + + pb.RegisterEchoerServer(srv, svc) + + go func() { + _ = srv.Run() + }() + log.Printf("👋 Listening on %v", srv.Subjects()) + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + <-done + log.Printf("💀 Shutting down") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err = srv.Shutdown(ctx); err != nil { + log.Fatal(err) + } +} + +func logError(ctx context.Context, err error) { + log.Printf("Server Error: %v\n", err) +} diff --git a/internal/gen/gen.go b/internal/gen/gen.go new file mode 100644 index 0000000..09efd13 --- /dev/null +++ b/internal/gen/gen.go @@ -0,0 +1,259 @@ +package gen + +import ( + "fmt" + "strconv" + "strings" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/types/descriptorpb" +) + +const ( + fmtPackage = protogen.GoImportPath("fmt") + contextPackage = protogen.GoImportPath("context") + stormrpcPackage = protogen.GoImportPath("github.com/actatum/stormrpc") +) + +const deprectationComment = "// Deprecated: Do not use." + +// GenerateFile generates a _stormrpc.pb.go file containing stormrpc service defintions. +func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile { + if len(file.Services) == 0 { + return nil + } + filename := file.GeneratedFilenamePrefix + "_stormrpc.pb.go" + g := gen.NewGeneratedFile(filename, file.GoImportPath) + g.P("// Code generated by protoc-gen-stormrpc. DO NOT EDIT.") + g.P() + g.P("package ", file.GoPackageName) + g.P() + GenerateFileContent(gen, file, g) + return g +} + +// GenerateFileContent generates the stormrpc service definitions, excluding the package statement. +func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) { + if len(file.Services) == 0 { + return + } + + g.P() + for _, service := range file.Services { + genService(gen, file, g, service) + } +} + +func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) { + clientName := service.GoName + "Client" + + g.P("// ", clientName, " is the client API for ", service.GoName, " service.") + + // Client interface. + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { + g.P("//") + g.P(deprectationComment) + } + g.Annotate(clientName, service.Location) + g.P("type ", clientName, " interface {") + for _, method := range service.Methods { + g.Annotate(clientName+"."+method.GoName, method.Location) + if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { + g.P(deprectationComment) + } + g.P(method.Comments.Leading, clientSignature(g, method)) + } + g.P("}") + g.P() + + // Client structure. + g.P("type ", unexport(clientName), " struct {") + g.P("c *", stormrpcPackage.Ident("Client")) + g.P("}") + g.P() + + // NewClient constructor. + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { + g.P(deprectationComment) + } + g.P("func New", clientName, " (c *", stormrpcPackage.Ident("Client"), ") ", clientName, " {") + g.P("return &", unexport(clientName), "{c}") + g.P("}") + g.P() + + var methodIndex int + for _, method := range service.Methods { + if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { + // Unary RPC method + genClientMethod(gen, file, g, method, methodIndex) + methodIndex++ + } + } + + // Server interface. + serverType := service.GoName + "Server" + g.P("// ", serverType, " is the server API for ", service.GoName, " service.") + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { + g.P("//") + g.P(deprectationComment) + } + g.Annotate(serverType, service.Location) + g.P("type ", serverType, " interface {") + for _, method := range service.Methods { + g.Annotate(serverType+"."+method.GoName, method.Location) + if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { + g.P(deprectationComment) + } + g.P(method.Comments.Leading, serverSignature(g, method)) + } + g.P("}") + g.P() + + // Server registration. + if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { + g.P(deprectationComment) + } + g.P("func Register", service.GoName, "Server(s *", stormrpcPackage.Ident("Server"), ", srv ", serverType, ") {") + g.P("for _, handler := range ", unexport(service.GoName), "Handlers {") + g.P("handler.SetService(srv)") + g.P("s.Handle(handler.Route(), handler.HandlerFunc())") + g.P("}") + g.P("}") + g.P() + + // Server handler implementations. + var handlerNames []string + for _, method := range service.Methods { + hname := genServerHandler(gen, file, g, method) + handlerNames = append(handlerNames, hname) + } + + // Handlers + g.P("type handler interface {") + g.P("Route() string") + g.P("HandlerFunc() stormrpc.HandlerFunc") + g.P("SetService(interface{})") + g.P("}") + + // HandlerFuncs + g.P("var ", unexport(service.GoName), "Handlers = []handler{") + for i, method := range service.Methods { + g.P("&", handlerNames[i], "{route: ", strconv.Quote(routeSignature(service, method)), "},") + } + g.P("}") + g.P() +} + +func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string { + s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + if !method.Desc.IsStreamingClient() { + s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent) + } + s += ") (" + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { + s += "*" + g.QualifiedGoIdent(method.Output.GoIdent) + } else { + s += method.Parent.GoName + "_" + method.GoName + "Client" + } + s += ", error)" + return s +} + +func genClientMethod( + gen *protogen.Plugin, + file *protogen.File, + g *protogen.GeneratedFile, + method *protogen.Method, + index int, +) { + service := method.Parent + + if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { + g.P(deprectationComment) + } + g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{") + if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { + g.P("var out ", method.Output.GoIdent) + g.P(`r, err := stormrpc.NewRequest("` + routeSignature(service, method) + `", in, stormrpc.WithEncodeProto())`) + g.P("if err != nil { return nil, err }") + g.P() + g.P("resp := c.c.Do(ctx, r)") + g.P("if resp.Err != nil { return nil, resp.Err }") + g.P() + g.P("if err = resp.Decode(&out); err != nil { return nil, err }") + g.P() + g.P("return &out, nil") + g.P("}") + g.P() + return + } + +} + +func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string { + var reqArgs []string + ret := "error" + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { + reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context"))) + ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)" + } + if !method.Desc.IsStreamingClient() { + reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent)) + } + if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { + reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server") + } + return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret +} + +func genServerHandler( + gen *protogen.Plugin, + file *protogen.File, + g *protogen.GeneratedFile, + method *protogen.Method, +) string { + service := method.Parent + hname := unexport(fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)) + + if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { + g.P("type ", hname, " struct {") + g.P("route string") + g.P("svc interface{}") + g.P("}") + g.P() + g.P("func (h *", hname, ") HandlerFunc() stormrpc.HandlerFunc {") + g.P("return func(ctx ", contextPackage.Ident("Context"), ", r stormrpc.Request) stormrpc.Response {") + g.P("var in ", method.Input.GoIdent) + g.P(`if err := r.Decode(&in); err != nil { return stormrpc.NewErrorResponse(r.Reply, `, + fmtPackage.Ident("Errorf"), `("error decoding request")) }`) + g.P() + g.P("out, err := h.svc.(", service.GoName, "Server).", method.GoName, "(ctx, &in)") + g.P("if err != nil { return stormrpc.NewErrorResponse(r.Reply, err) }") + g.P() + g.P("resp, err := stormrpc.NewResponse(r.Reply, out, stormrpc.WithEncodeProto())") + g.P("if err != nil { return stormrpc.NewErrorResponse(r.Reply, err) }") + g.P() + g.P("return resp") + g.P("}") + g.P("}") + + g.P("func (h *", hname, ") Route() string {") + g.P("return h.route") + g.P("}") + + g.P("func (h *", hname, ") SetService(svc interface{}) {") + g.P("h.svc = svc") + g.P("}") + g.P() + + return hname + } + + return hname +} + +func routeSignature(service *protogen.Service, method *protogen.Method) string { + return fmt.Sprintf("rpc.%s.%s", service.GoName, method.GoName) +} + +func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] } diff --git a/protoc-gen-stormrpc/main.go b/protoc-gen-stormrpc/main.go new file mode 100644 index 0000000..8fc819b --- /dev/null +++ b/protoc-gen-stormrpc/main.go @@ -0,0 +1,18 @@ +package main + +import ( + stormrpcgen "github.com/actatum/stormrpc/internal/gen" + "google.golang.org/protobuf/compiler/protogen" +) + +func main() { + protogen.Options{}.Run(func(gen *protogen.Plugin) error { + for _, f := range gen.Files { + if !f.Generate { + continue + } + stormrpcgen.GenerateFile(gen, f) + } + return nil + }) +} diff --git a/prototest/protoc.sh b/prototest/protoc.sh new file mode 100755 index 0000000..25eeb7c --- /dev/null +++ b/prototest/protoc.sh @@ -0,0 +1,3 @@ +go install ./protoc-gen-stormrpc +protoc --proto_path prototest -I=. prototest/test.proto \ + --stormrpc_out=./prototest/gen_out --go_out=./prototest \ No newline at end of file diff --git a/prototest/test.pb.go b/prototest/test.pb.go index 836f48f..35028a1 100644 --- a/prototest/test.pb.go +++ b/prototest/test.pb.go @@ -1,8 +1,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.0 -// protoc v3.19.4 -// source: prototest/test.proto +// protoc v3.6.1 +// source: test.proto package prototest @@ -20,7 +20,54 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -type Greeting struct { +type HelloRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` +} + +func (x *HelloRequest) Reset() { + *x = HelloRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_test_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HelloRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HelloRequest) ProtoMessage() {} + +func (x *HelloRequest) ProtoReflect() protoreflect.Message { + mi := &file_test_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HelloRequest.ProtoReflect.Descriptor instead. +func (*HelloRequest) Descriptor() ([]byte, []int) { + return file_test_proto_rawDescGZIP(), []int{0} +} + +func (x *HelloRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +type HelloReply struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields @@ -28,23 +75,23 @@ type Greeting struct { Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` } -func (x *Greeting) Reset() { - *x = Greeting{} +func (x *HelloReply) Reset() { + *x = HelloReply{} if protoimpl.UnsafeEnabled { - mi := &file_prototest_test_proto_msgTypes[0] + mi := &file_test_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } -func (x *Greeting) String() string { +func (x *HelloReply) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Greeting) ProtoMessage() {} +func (*HelloReply) ProtoMessage() {} -func (x *Greeting) ProtoReflect() protoreflect.Message { - mi := &file_prototest_test_proto_msgTypes[0] +func (x *HelloReply) ProtoReflect() protoreflect.Message { + mi := &file_test_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -55,61 +102,82 @@ func (x *Greeting) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Greeting.ProtoReflect.Descriptor instead. -func (*Greeting) Descriptor() ([]byte, []int) { - return file_prototest_test_proto_rawDescGZIP(), []int{0} +// Deprecated: Use HelloReply.ProtoReflect.Descriptor instead. +func (*HelloReply) Descriptor() ([]byte, []int) { + return file_test_proto_rawDescGZIP(), []int{1} } -func (x *Greeting) GetMessage() string { +func (x *HelloReply) GetMessage() string { if x != nil { return x.Message } return "" } -var File_prototest_test_proto protoreflect.FileDescriptor - -var file_prototest_test_proto_rawDesc = []byte{ - 0x0a, 0x14, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x74, 0x65, 0x73, 0x74, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x24, 0x0a, 0x08, 0x47, 0x72, 0x65, 0x65, 0x74, 0x69, - 0x6e, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x42, 0x0d, 0x5a, 0x0b, - 0x2e, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x74, 0x65, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, +var File_test_proto protoreflect.FileDescriptor + +var file_test_proto_rawDesc = []byte{ + 0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x74, 0x65, + 0x73, 0x74, 0x22, 0x22, 0x0a, 0x0c, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x26, 0x0a, 0x0a, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, + 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x3d, + 0x0a, 0x07, 0x47, 0x72, 0x65, 0x65, 0x74, 0x65, 0x72, 0x12, 0x32, 0x0a, 0x08, 0x53, 0x61, 0x79, + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x12, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x48, 0x65, 0x6c, + 0x6c, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x10, 0x2e, 0x74, 0x65, 0x73, 0x74, + 0x2e, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x0d, 0x5a, + 0x0b, 0x2e, 0x3b, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x74, 0x65, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, } var ( - file_prototest_test_proto_rawDescOnce sync.Once - file_prototest_test_proto_rawDescData = file_prototest_test_proto_rawDesc + file_test_proto_rawDescOnce sync.Once + file_test_proto_rawDescData = file_test_proto_rawDesc ) -func file_prototest_test_proto_rawDescGZIP() []byte { - file_prototest_test_proto_rawDescOnce.Do(func() { - file_prototest_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_prototest_test_proto_rawDescData) +func file_test_proto_rawDescGZIP() []byte { + file_test_proto_rawDescOnce.Do(func() { + file_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_proto_rawDescData) }) - return file_prototest_test_proto_rawDescData + return file_test_proto_rawDescData } -var file_prototest_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_prototest_test_proto_goTypes = []interface{}{ - (*Greeting)(nil), // 0: Greeting +var file_test_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_test_proto_goTypes = []interface{}{ + (*HelloRequest)(nil), // 0: test.HelloRequest + (*HelloReply)(nil), // 1: test.HelloReply } -var file_prototest_test_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type +var file_test_proto_depIdxs = []int32{ + 0, // 0: test.Greeter.SayHello:input_type -> test.HelloRequest + 1, // 1: test.Greeter.SayHello:output_type -> test.HelloReply + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name } -func init() { file_prototest_test_proto_init() } -func file_prototest_test_proto_init() { - if File_prototest_test_proto != nil { +func init() { file_test_proto_init() } +func file_test_proto_init() { + if File_test_proto != nil { return } if !protoimpl.UnsafeEnabled { - file_prototest_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Greeting); i { + file_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HelloRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_test_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HelloReply); i { case 0: return &v.state case 1: @@ -125,18 +193,18 @@ func file_prototest_test_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_prototest_test_proto_rawDesc, + RawDescriptor: file_test_proto_rawDesc, NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, - NumServices: 0, + NumServices: 1, }, - GoTypes: file_prototest_test_proto_goTypes, - DependencyIndexes: file_prototest_test_proto_depIdxs, - MessageInfos: file_prototest_test_proto_msgTypes, + GoTypes: file_test_proto_goTypes, + DependencyIndexes: file_test_proto_depIdxs, + MessageInfos: file_test_proto_msgTypes, }.Build() - File_prototest_test_proto = out.File - file_prototest_test_proto_rawDesc = nil - file_prototest_test_proto_goTypes = nil - file_prototest_test_proto_depIdxs = nil + File_test_proto = out.File + file_test_proto_rawDesc = nil + file_test_proto_goTypes = nil + file_test_proto_depIdxs = nil } diff --git a/prototest/test.proto b/prototest/test.proto index 3b8f106..35135eb 100644 --- a/prototest/test.proto +++ b/prototest/test.proto @@ -1,7 +1,17 @@ syntax = "proto3"; +package test; + option go_package = ".;prototest"; -message Greeting { - string message = 1; +service Greeter { + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +message HelloRequest { + string name = 1; +} + +message HelloReply { + string message = 1; } \ No newline at end of file diff --git a/request_test.go b/request_test.go index 3919a1e..cc695c9 100644 --- a/request_test.go +++ b/request_test.go @@ -61,7 +61,7 @@ func TestNewRequest(t *testing.T) { }) t.Run("proto option", func(t *testing.T) { - body := &prototest.Greeting{Message: "hello"} + body := &prototest.HelloRequest{Name: "aaron"} data, _ := proto.Marshal(body) want := Request{ @@ -130,13 +130,13 @@ func TestRequest_Decode(t *testing.T) { }) t.Run("decode proto", func(t *testing.T) { - body := &prototest.Greeting{Message: "hi"} + body := &prototest.HelloReply{Message: "hi"} r, err := NewRequest("test", body, WithEncodeProto()) if err != nil { t.Fatal(err) } - var got prototest.Greeting + var got prototest.HelloReply if err = r.Decode(&got); err != nil { t.Fatal(err) } @@ -147,7 +147,7 @@ func TestRequest_Decode(t *testing.T) { }) t.Run("decode proto w/non proto message", func(t *testing.T) { - body := &prototest.Greeting{Message: "hello"} + body := &prototest.HelloReply{Message: "hello"} r, err := NewRequest("test", body, WithEncodeProto()) if err != nil { t.Fatal(err) diff --git a/response_test.go b/response_test.go index f749f9c..62c58dd 100644 --- a/response_test.go +++ b/response_test.go @@ -60,7 +60,7 @@ func TestResponse_Decode(t *testing.T) { }) t.Run("decode proto", func(t *testing.T) { - body := &prototest.Greeting{Message: "hi"} + body := &prototest.HelloReply{Message: "hi"} data, _ := proto.Marshal(body) resp := &Response{ Msg: &nats.Msg{ @@ -72,7 +72,7 @@ func TestResponse_Decode(t *testing.T) { Err: nil, } - var got prototest.Greeting + var got prototest.HelloReply if err := resp.Decode(&got); err != nil { t.Fatal(err) } @@ -95,7 +95,7 @@ func TestResponse_Decode(t *testing.T) { Err: nil, } - var got prototest.Greeting + var got prototest.HelloReply err := resp.Decode(&got) if err == nil { t.Fatal("expected error got nil") @@ -201,7 +201,7 @@ func TestNewResponse(t *testing.T) { }) t.Run("proto option", func(t *testing.T) { - body := &prototest.Greeting{Message: "hello"} + body := &prototest.HelloReply{Message: "hello"} data, _ := proto.Marshal(body) want := Response{