diff --git a/cmd/protoc-gen-go-axe/grpc.go b/cmd/protoc-gen-go-axe/grpc.go index 6926fdd..8170d8b 100644 --- a/cmd/protoc-gen-go-axe/grpc.go +++ b/cmd/protoc-gen-go-axe/grpc.go @@ -73,16 +73,28 @@ func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen. func genHttpService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) { serverType := service.GoName + "Server" - g.P("func Register", service.GoName, "HttpServer(s *", httpPackage.Ident("Server"), ",srv ", serverType, ") {") + g.P("type MiddlewareFunc func(", httpPackage.Ident("Handler"), ") ", httpPackage.Ident("Handler")) + g.P() + g.P("func Register", service.GoName, "HttpServer(s *", httpPackage.Ident("Server"), ",srv ", serverType, ", middlewares ...MiddlewareFunc) {") g.P("mux := ", httpPackage.Ident("NewServeMux"), "()") g.P() for _, method := range service.Methods { hname := genHttpServerMethod(gen, file, g, method) - g.P("mux.HandleFunc(", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())), ", ", hname, ")") + + hf := fmt.Sprintf("_%s_%s_HandlerFunc", service.GoName, method.GoName) + g.P("var ", hf, " ", httpPackage.Ident("Handler")) + + g.P(hf, " = ", httpPackage.Ident("Handler"), "(", httpPackage.Ident("HandlerFunc"), "(", hname, "))") + g.P("for _, m := range middlewares {") + g.P(hf, " = m(", hf, ")") + g.P("}") + + g.P("mux.Handle(", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())), ", ", hf, ")") g.P() } + g.P("s.Handler = mux") g.P("}") g.P() diff --git a/demo/main.go b/demo/main.go index 0a48096..1200ed6 100644 --- a/demo/main.go +++ b/demo/main.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "log" + "net/http" "path" "runtime" "time" @@ -18,7 +19,27 @@ type echoServer struct { } func (s *echoServer) Echo(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) { - return &pb.EchoResponse{Value: req.GetValue()}, nil + return &pb.EchoResponse{Value: "echo1_" + req.GetValue()}, nil +} + +func (s *echoServer) Echo2(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) { + return &pb.EchoResponse{Value: "echo2_" + req.GetValue()}, nil +} + +func middleware1(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Println("middleware1 - Before Handler") + next.ServeHTTP(w, r) + log.Println("middleware1 - After Handler") + }) +} + +func middleware2(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Println("middleware2 - Before Handler") + next.ServeHTTP(w, r) + log.Println("middleware2 - After Handler") + }) } ///////////////////////////// @@ -56,7 +77,7 @@ func main() { // register rpc pb.RegisterEchoServiceServer(s.GrpcServer(), &echoServer{}) // register http, pattern和handler会自动生成 - pb.RegisterEchoServiceHttpServer(s.HttpServer(), &echoServer{}) + pb.RegisterEchoServiceHttpServer(s.HttpServer(), &echoServer{}, []pb.MiddlewareFunc{middleware1, middleware2}...) // 调用client的例子 //go clientExample() diff --git a/demo/pb/echo.pb.go b/demo/pb/echo.pb.go index 597e87d..6c5ace3 100644 --- a/demo/pb/echo.pb.go +++ b/demo/pb/echo.pb.go @@ -122,11 +122,14 @@ var file_pb_echo_proto_rawDesc = []byte{ 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x24, 0x0a, 0x0c, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x32, 0x34, 0x0a, 0x0b, 0x45, 0x63, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x32, 0x5c, 0x0a, 0x0b, 0x45, 0x63, 0x68, 0x6f, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x25, 0x0a, 0x04, 0x45, 0x63, 0x68, 0x6f, 0x12, 0x0c, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x42, 0x06, 0x5a, 0x04, 0x2e, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x12, 0x26, 0x0a, 0x05, 0x45, 0x63, 0x68, 0x6f, 0x32, 0x12, 0x0c, 0x2e, 0x45, 0x63, 0x68, 0x6f, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0d, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x06, 0x5a, 0x04, 0x2e, 0x2f, 0x70, 0x62, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -148,9 +151,11 @@ var file_pb_echo_proto_goTypes = []interface{}{ } var file_pb_echo_proto_depIdxs = []int32{ 0, // 0: EchoService.Echo:input_type -> EchoRequest - 1, // 1: EchoService.Echo:output_type -> EchoResponse - 1, // [1:2] is the sub-list for method output_type - 0, // [0:1] is the sub-list for method input_type + 0, // 1: EchoService.Echo2:input_type -> EchoRequest + 1, // 2: EchoService.Echo:output_type -> EchoResponse + 1, // 3: EchoService.Echo2:output_type -> EchoResponse + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name diff --git a/demo/pb/echo.proto b/demo/pb/echo.proto index f33262d..b59c216 100644 --- a/demo/pb/echo.proto +++ b/demo/pb/echo.proto @@ -12,4 +12,5 @@ message EchoResponse{ service EchoService{ rpc Echo(EchoRequest) returns (EchoResponse){} + rpc Echo2(EchoRequest) returns (EchoResponse){} } \ No newline at end of file diff --git a/demo/pb/echo_axe.pb.go b/demo/pb/echo_axe.pb.go index 22c885e..c83c9d4 100644 --- a/demo/pb/echo_axe.pb.go +++ b/demo/pb/echo_axe.pb.go @@ -22,6 +22,7 @@ const _ = grpc.SupportPackageIsVersion7 // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type EchoServiceClient interface { Echo(ctx context.Context, in *EchoRequest, opts ...grpc.CallOption) (*EchoResponse, error) + Echo2(ctx context.Context, in *EchoRequest, opts ...grpc.CallOption) (*EchoResponse, error) } type echoServiceClient struct { @@ -41,11 +42,21 @@ func (c *echoServiceClient) Echo(ctx context.Context, in *EchoRequest, opts ...g return out, nil } +func (c *echoServiceClient) Echo2(ctx context.Context, in *EchoRequest, opts ...grpc.CallOption) (*EchoResponse, error) { + out := new(EchoResponse) + err := c.cc.Invoke(ctx, "/EchoService/Echo2", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // EchoServiceServer is the server API for EchoService service. // All implementations must embed UnimplementedEchoServiceServer // for forward compatibility type EchoServiceServer interface { Echo(context.Context, *EchoRequest) (*EchoResponse, error) + Echo2(context.Context, *EchoRequest) (*EchoResponse, error) mustEmbedUnimplementedEchoServiceServer() } @@ -56,6 +67,9 @@ type UnimplementedEchoServiceServer struct { func (UnimplementedEchoServiceServer) Echo(context.Context, *EchoRequest) (*EchoResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Echo not implemented") } +func (UnimplementedEchoServiceServer) Echo2(context.Context, *EchoRequest) (*EchoResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Echo2 not implemented") +} func (UnimplementedEchoServiceServer) mustEmbedUnimplementedEchoServiceServer() {} // UnsafeEchoServiceServer may be embedded to opt out of forward compatibility for this service. @@ -87,6 +101,24 @@ func _EchoService_Echo_Handler(srv interface{}, ctx context.Context, dec func(in return interceptor(ctx, in, info, handler) } +func _EchoService_Echo2_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EchoRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(EchoServiceServer).Echo2(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/EchoService/Echo2", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(EchoServiceServer).Echo2(ctx, req.(*EchoRequest)) + } + return interceptor(ctx, in, info, handler) +} + // EchoService_ServiceDesc is the grpc.ServiceDesc for EchoService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -98,12 +130,18 @@ var EchoService_ServiceDesc = grpc.ServiceDesc{ MethodName: "Echo", Handler: _EchoService_Echo_Handler, }, + { + MethodName: "Echo2", + Handler: _EchoService_Echo2_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "pb/echo.proto", } -func RegisterEchoServiceHttpServer(s *http.Server, srv EchoServiceServer) { +type MiddlewareFunc func(http.Handler) http.Handler + +func RegisterEchoServiceHttpServer(s *http.Server, srv EchoServiceServer, middlewares ...MiddlewareFunc) { mux := http.NewServeMux() _EchoService_Echo_Http_Handler := func(w http.ResponseWriter, req *http.Request) { @@ -133,7 +171,46 @@ func RegisterEchoServiceHttpServer(s *http.Server, srv EchoServiceServer) { } w.Write(b) } - mux.HandleFunc("/EchoService/Echo", _EchoService_Echo_Http_Handler) + var _EchoService_Echo_HandlerFunc http.Handler + _EchoService_Echo_HandlerFunc = http.Handler(http.HandlerFunc(_EchoService_Echo_Http_Handler)) + for _, m := range middlewares { + _EchoService_Echo_HandlerFunc = m(_EchoService_Echo_HandlerFunc) + } + mux.Handle("/EchoService/Echo", _EchoService_Echo_HandlerFunc) + + _EchoService_Echo2_Http_Handler := func(w http.ResponseWriter, req *http.Request) { + data, err := ioutil.ReadAll(req.Body) + defer req.Body.Close() + if err != nil { + w.Write([]byte(err.Error())) + return + } + var reqData EchoRequest + if len(data) != 0 { + err = json.Unmarshal(data, &reqData) + if err != nil { + w.Write([]byte(err.Error())) + return + } + } + respData, err := srv.Echo2(context.Background(), &reqData) + if err != nil { + w.Write([]byte(err.Error())) + return + } + b, err := json.Marshal(respData) + if err != nil { + w.Write([]byte(err.Error())) + return + } + w.Write(b) + } + var _EchoService_Echo2_HandlerFunc http.Handler + _EchoService_Echo2_HandlerFunc = http.Handler(http.HandlerFunc(_EchoService_Echo2_Http_Handler)) + for _, m := range middlewares { + _EchoService_Echo2_HandlerFunc = m(_EchoService_Echo2_HandlerFunc) + } + mux.Handle("/EchoService/Echo2", _EchoService_Echo2_HandlerFunc) s.Handler = mux }