diff --git a/thrift/idl.go b/thrift/idl.go index 236000b5..dcaa81e3 100644 --- a/thrift/idl.go +++ b/thrift/idl.go @@ -52,6 +52,9 @@ type Options struct { // ParseServiceMode indicates how to parse service. ParseServiceMode meta.ParseServiceMode + // ServiceName indicates which idl service to be parsed. + ServiceName string + // MapFieldWay indicates StructDescriptor.FieldByKey() uses alias to map field. // By default, we use alias to map, and alias always equals to field name if not given. MapFieldWay meta.MapFieldWay @@ -247,15 +250,26 @@ func parse(ctx context.Context, tree *parser.Thrift, mode meta.ParseServiceMode, // support one service svcs := tree.Services - switch mode { - case meta.LastServiceOnly: - svcs = svcs[len(svcs)-1:] - sDsc.name = svcs[len(svcs)-1].Name - case meta.FirstServiceOnly: - svcs = svcs[:1] - sDsc.name = svcs[0].Name - case meta.CombineServices: - sDsc.name = "CombinedServices" + + // if an idl service name is specified, it takes precedence over parse mode + if opts.ServiceName != "" { + var err error + svcs, err = getTargetService(svcs, opts.ServiceName) + if err != nil { + return nil, err + } + sDsc.name = opts.ServiceName + } else { + switch mode { + case meta.LastServiceOnly: + svcs = svcs[len(svcs)-1:] + sDsc.name = svcs[len(svcs)-1].Name + case meta.FirstServiceOnly: + svcs = svcs[:1] + sDsc.name = svcs[0].Name + case meta.CombineServices: + sDsc.name = "CombinedServices" + } } for _, svc := range svcs { @@ -289,6 +303,15 @@ func parse(ctx context.Context, tree *parser.Thrift, mode meta.ParseServiceMode, return sDsc, nil } +func getTargetService(svcs []*parser.Service, serviceName string) ([]*parser.Service, error) { + for _, svc := range svcs { + if svc.Name == serviceName { + return []*parser.Service{svc}, nil + } + } + return nil, fmt.Errorf("the idl service name %s is not in the idl. Please check your idl", serviceName) +} + type funcTreePair struct { tree *parser.Thrift fn *parser.Function diff --git a/thrift/idl_test.go b/thrift/idl_test.go index 0cc1cf76..ebfd49d8 100644 --- a/thrift/idl_test.go +++ b/thrift/idl_test.go @@ -437,3 +437,41 @@ func TestStreamingFunctionDescriptorFromContent(t *testing.T) { require.Equal(t, "Request", dsc.Functions()["EchoClient"].Request().Struct().Name()) require.Equal(t, "", dsc.Functions()["EchoUnary"].Request().Struct().Name()) } + +func TestParseWithServiceName(t *testing.T) { + path := "a/b/main.thrift" + content := ` + namespace go thrift + + struct Request { + 1: required string message, + } + + struct Response { + 1: required string message, + } + + service Service1 { + Response Test(1: Request req) + } + + service Service2 { + Response Test(1: Request req) + } + + service Service3 { + Response Test(1: Request req) + } + ` + + opts := Options{ServiceName: "Service2"} + p, err := opts.NewDescritorFromContent(context.Background(), path, content, nil, false) + require.Nil(t, err) + require.Equal(t, p.Name(), "Service2") + + opts = Options{ServiceName: "UnknownService"} + p, err = opts.NewDescritorFromContent(context.Background(), path, content, nil, false) + require.NotNil(t, err) + require.Nil(t, p) + require.Equal(t, err.Error(), "the idl service name UnknownService is not in the idl. Please check your idl") +}