From 54621cc60117cf68183be24322119d85a80bb650 Mon Sep 17 00:00:00 2001 From: Mason Malone <651224+MasonM@users.noreply.github.com> Date: Mon, 30 Sep 2024 10:20:38 -0700 Subject: [PATCH] feat(cli): add version header + warning on client-server mismatch. Fixes #9212 (#13635) Signed-off-by: Mason Malone --- cmd/argo/commands/root.go | 4 + pkg/apiclient/argo-server-client.go | 7 +- server/apiserver/argoserver.go | 2 + test/e2e/argo_server_test.go | 19 ++-- util/cmd/cmd.go | 7 ++ util/cmd/cmd_test.go | 53 +++++++++++ util/grpc/interceptor.go | 48 ++++++++++ util/grpc/interceptor_test.go | 136 ++++++++++++++++++++++++++++ 8 files changed, 268 insertions(+), 8 deletions(-) create mode 100644 util/grpc/interceptor_test.go diff --git a/cmd/argo/commands/root.go b/cmd/argo/commands/root.go index 0e652f576346..0073adcae550 100644 --- a/cmd/argo/commands/root.go +++ b/cmd/argo/commands/root.go @@ -19,6 +19,7 @@ import ( "github.com/argoproj/argo-workflows/v3/cmd/argo/commands/executorplugin" "github.com/argoproj/argo-workflows/v3/cmd/argo/commands/template" cmdutil "github.com/argoproj/argo-workflows/v3/util/cmd" + grpcutil "github.com/argoproj/argo-workflows/v3/util/grpc" ) const ( @@ -125,6 +126,9 @@ If your server is behind an ingress with a path (running "argo server --base-hre var logLevel string var glogLevel int var verbose bool + command.PersistentPostRun = func(cmd *cobra.Command, args []string) { + cmdutil.PrintVersionMismatchWarning(argo.GetVersion(), grpcutil.LastSeenServerVersion) + } command.PersistentPreRun = func(cmd *cobra.Command, args []string) { if verbose { logLevel = "debug" diff --git a/pkg/apiclient/argo-server-client.go b/pkg/apiclient/argo-server-client.go index c7d08eb272bb..0a4419b2dc09 100644 --- a/pkg/apiclient/argo-server-client.go +++ b/pkg/apiclient/argo-server-client.go @@ -15,6 +15,7 @@ import ( workflowpkg "github.com/argoproj/argo-workflows/v3/pkg/apiclient/workflow" workflowarchivepkg "github.com/argoproj/argo-workflows/v3/pkg/apiclient/workflowarchive" workflowtemplatepkg "github.com/argoproj/argo-workflows/v3/pkg/apiclient/workflowtemplate" + grpcutil "github.com/argoproj/argo-workflows/v3/util/grpc" ) const ( @@ -65,7 +66,11 @@ func newClientConn(opts ArgoServerOpts) (*grpc.ClientConn, error) { if opts.Secure { creds = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{InsecureSkipVerify: opts.InsecureSkipVerify})) } - conn, err := grpc.Dial(opts.URL, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), creds) + conn, err := grpc.Dial(opts.URL, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), + creds, + grpc.WithUnaryInterceptor(grpcutil.GetVersionHeaderClientUnaryInterceptor), + ) if err != nil { return nil, err } diff --git a/server/apiserver/argoserver.go b/server/apiserver/argoserver.go index 1255f6cce62d..88b5ed9cc930 100644 --- a/server/apiserver/argoserver.go +++ b/server/apiserver/argoserver.go @@ -305,6 +305,7 @@ func (as *argoServer) newGRPCServer(instanceIDService instanceid.Service, workfl grpcutil.ErrorTranslationUnaryServerInterceptor, as.gatekeeper.UnaryServerInterceptor(), grpcutil.RatelimitUnaryServerInterceptor(as.apiRateLimiter), + grpcutil.SetVersionHeaderUnaryServerInterceptor(argo.GetVersion()), )), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( grpc_prometheus.StreamServerInterceptor, @@ -313,6 +314,7 @@ func (as *argoServer) newGRPCServer(instanceIDService instanceid.Service, workfl grpcutil.ErrorTranslationStreamServerInterceptor, as.gatekeeper.StreamServerInterceptor(), grpcutil.RatelimitStreamServerInterceptor(as.apiRateLimiter), + grpcutil.SetVersionHeaderStreamServerInterceptor(argo.GetVersion()), )), } diff --git a/test/e2e/argo_server_test.go b/test/e2e/argo_server_test.go index ffb0e02c1cb5..ff1a8e685649 100644 --- a/test/e2e/argo_server_test.go +++ b/test/e2e/argo_server_test.go @@ -88,12 +88,11 @@ func (s *ArgoServerSuite) TestInfo() { func (s *ArgoServerSuite) TestVersion() { s.Run("Version", func() { - s.e().GET("/api/v1/version"). + resp := s.e().GET("/api/v1/version"). Expect(). - Status(200). - JSON(). - Path("$.version"). - NotNull() + Status(200) + resp.JSON().Path("$.version").NotNull() + resp.Header("Grpc-Metadata-Argo-Version").NotEmpty() }) } @@ -343,14 +342,20 @@ func (s *ArgoServerSuite) TestUnauthorized() { defer func() { s.bearerToken = token }() s.e().GET("/api/v1/workflows/argo"). Expect(). - Status(401) + Status(401). + // Version header shouldn't be set on 401s for security, since that could be used by attackers to find vulnerable servers + Header("Grpc-Metadata-Argo-Version"). + IsEmpty() }) s.Run("Basic", func() { s.username = "garbage" defer func() { s.username = "" }() s.e().GET("/api/v1/workflows/argo"). Expect(). - Status(401) + Status(401). + // Version header shouldn't be set on 401s for security, since that could be used by attackers to find vulnerable servers + Header("Grpc-Metadata-Argo-Version"). + IsEmpty() }) } diff --git a/util/cmd/cmd.go b/util/cmd/cmd.go index b7e0bffc6884..87c136cb4f9f 100644 --- a/util/cmd/cmd.go +++ b/util/cmd/cmd.go @@ -45,6 +45,13 @@ func PrintVersion(cliName string, version wfv1.Version, short bool) { fmt.Printf(" Platform: %s\n", version.Platform) } +// PrintVersionMismatchWarning detects if there's a mismatch between the client and server versions and prints a warning if so +func PrintVersionMismatchWarning(clientVersion wfv1.Version, serverVersion string) { + if serverVersion != "" && clientVersion.GitTag != "" && serverVersion != clientVersion.Version { + log.Warnf("CLI version (%s) does not match server version (%s). This can lead to unexpected behavior.", clientVersion.Version, serverVersion) + } +} + // MustIsDir returns whether or not the given filePath is a directory. Exits if path does not exist func MustIsDir(filePath string) bool { fileInfo, err := os.Stat(filePath) diff --git a/util/cmd/cmd_test.go b/util/cmd/cmd_test.go index c71840192e8a..9e8bb6a3e4d8 100644 --- a/util/cmd/cmd_test.go +++ b/util/cmd/cmd_test.go @@ -3,6 +3,12 @@ package cmd import ( "reflect" "testing" + + log "github.com/sirupsen/logrus" + logtest "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + + wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" ) func TestMakeParseLabels(t *testing.T) { @@ -103,3 +109,50 @@ func TestIsURL(t *testing.T) { }) } } + +func TestPrintVersionMismatchWarning(t *testing.T) { + tests := []struct { + name string + clientVersion *wfv1.Version + serverVersion string + expectedLog string + }{ + { + name: "server version not set", + clientVersion: &wfv1.Version{ + Version: "v3.1.0", + GitTag: "v3.1.0", + }, + serverVersion: "", + }, + { + name: "client version is untagged", + clientVersion: &wfv1.Version{ + Version: "v3.1.0", + }, + serverVersion: "v3.1.1", + }, + { + name: "version mismatch", + clientVersion: &wfv1.Version{ + Version: "v3.1.0", + GitTag: "v3.1.0", + }, + serverVersion: "v3.1.1", + expectedLog: "CLI version (v3.1.0) does not match server version (v3.1.1). This can lead to unexpected behavior.", + }, + } + hook := &logtest.Hook{} + log.AddHook(hook) + defer log.StandardLogger().ReplaceHooks(nil) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + PrintVersionMismatchWarning(*tt.clientVersion, tt.serverVersion) + if tt.expectedLog != "" { + assert.Equal(t, tt.expectedLog, hook.LastEntry().Message) + } else { + assert.Nil(t, hook.LastEntry()) + } + }) + } +} diff --git a/util/grpc/interceptor.go b/util/grpc/interceptor.go index df6b5e40d5f0..9773cb1b3f77 100644 --- a/util/grpc/interceptor.go +++ b/util/grpc/interceptor.go @@ -5,9 +5,12 @@ import ( "runtime/debug" "strings" + wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" @@ -40,7 +43,12 @@ func PanicLoggerStreamServerInterceptor(log *log.Entry) grpc.StreamServerInterce } } +const ( + ArgoVersionHeader = "argo-version" +) + var ( + LastSeenServerVersion string ErrorTranslationUnaryServerInterceptor = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { resp, err = handler(ctx, req) return resp, TranslateError(err) @@ -50,6 +58,46 @@ var ( } ) +// SetVersionHeaderUnaryServerInterceptor returns a new unary server interceptor that sets the argo-version header +func SetVersionHeaderUnaryServerInterceptor(version wfv1.Version) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + m, origErr := handler(ctx, req) + if origErr == nil { + // Don't set header if there was an error because attackers could use it to find vulnerable Argo servers + err := grpc.SetHeader(ctx, metadata.Pairs(ArgoVersionHeader, version.Version)) + if err != nil { + log.Warnf("Failed to set header '%s': %s", ArgoVersionHeader, err) + } + } + return m, origErr + } +} + +// SetVersionHeaderStreamServerInterceptor returns a new stream server interceptor that sets the argo-version header +func SetVersionHeaderStreamServerInterceptor(version wfv1.Version) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + origErr := handler(srv, ss) + if origErr == nil { + // Don't set header if there was an error because attackers could use it to find vulnerable Argo servers + err := ss.SetHeader(metadata.Pairs(ArgoVersionHeader, version.Version)) + if err != nil { + log.Warnf("Failed to set header '%s': %s", ArgoVersionHeader, err) + } + } + return origErr + } +} + +// GetVersionHeaderClientUnaryInterceptor returns a new unary client interceptor that extracts the argo-version from the response and sets the global variable LastSeenServerVersion +func GetVersionHeaderClientUnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + var headers metadata.MD + err := invoker(ctx, method, req, reply, cc, append(opts, grpc.Header(&headers))...) + if err == nil && headers != nil && headers.Get(ArgoVersionHeader) != nil { + LastSeenServerVersion = headers.Get(ArgoVersionHeader)[0] + } + return err +} + // RatelimitUnaryServerInterceptor returns a new unary server interceptor that performs request rate limiting. func RatelimitUnaryServerInterceptor(ratelimiter limiter.Store) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { diff --git a/util/grpc/interceptor_test.go b/util/grpc/interceptor_test.go new file mode 100644 index 000000000000..5cec14305acd --- /dev/null +++ b/util/grpc/interceptor_test.go @@ -0,0 +1,136 @@ +package grpc + +import ( + "context" + "errors" + "testing" + + wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type mockServerTransportStream struct { + header metadata.MD + isError bool +} + +func (mockServerTransportStream) Method() string { return "" } +func (msts *mockServerTransportStream) SetHeader(md metadata.MD) error { + if msts.isError { + return errors.New("simulate error setting header") + } + msts.header = md + return nil +} +func (mockServerTransportStream) SendHeader(md metadata.MD) error { return nil } +func (mockServerTransportStream) SetTrailer(md metadata.MD) error { return nil } + +var _ grpc.ServerTransportStream = &mockServerTransportStream{} + +func TestSetVersionHeaderUnaryServerInterceptor(t *testing.T) { + version := &wfv1.Version{Version: "v3.1.0"} + mockReturn := "successful return" + + t.Run("success", func(t *testing.T) { + handler := func(ctx context.Context, req interface{}) (interface{}, error) { return mockReturn, nil } + msts := &mockServerTransportStream{} + ctx := grpc.NewContextWithServerTransportStream(context.Background(), msts) + interceptor := SetVersionHeaderUnaryServerInterceptor(*version) + + m, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{}, handler) + + require.NoError(t, err) + assert.Equal(t, mockReturn, m) + assert.Equal(t, metadata.Pairs(ArgoVersionHeader, version.Version), msts.header) + }) + + t.Run("upstream error handling", func(t *testing.T) { + handler := func(ctx context.Context, req interface{}) (interface{}, error) { return nil, errors.New("error") } + msts := &mockServerTransportStream{} + ctx := grpc.NewContextWithServerTransportStream(context.Background(), msts) + interceptor := SetVersionHeaderUnaryServerInterceptor(*version) + + _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{}, handler) + + require.Error(t, err) + assert.Empty(t, msts.header) + }) + + t.Run("SetHeader error handling", func(t *testing.T) { + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return mockReturn, nil + } + msts := &mockServerTransportStream{isError: true} + ctx := grpc.NewContextWithServerTransportStream(context.Background(), msts) + interceptor := SetVersionHeaderUnaryServerInterceptor(*version) + + m, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{}, handler) + + require.NoError(t, err) + require.Equal(t, mockReturn, m) + assert.Empty(t, msts.header) + }) +} + +type mockServerStream struct { + header metadata.MD + isError bool +} + +func (msts mockServerStream) SetHeader(md metadata.MD) error { + if msts.isError { + return errors.New("simulate error setting header") + } + msts.header.Set(ArgoVersionHeader, md.Get(ArgoVersionHeader)...) + return nil +} +func (mockServerStream) SendHeader(md metadata.MD) error { return nil } +func (mockServerStream) SetTrailer(md metadata.MD) {} +func (mockServerStream) Context() context.Context { return context.Background() } +func (mockServerStream) SendMsg(m any) error { return nil } +func (mockServerStream) RecvMsg(m any) error { return nil } + +var _ grpc.ServerStream = &mockServerStream{} + +func TestSetVersionHeaderStreamServerInterceptor(t *testing.T) { + version := &wfv1.Version{Version: "v3.1.0"} + + t.Run("success", func(t *testing.T) { + handler := func(srv any, stream grpc.ServerStream) error { return nil } + msts := &mockServerStream{header: metadata.New(nil)} + interceptor := SetVersionHeaderStreamServerInterceptor(*version) + + err := interceptor(nil, msts, nil, handler) + + require.NoError(t, err) + assert.Equal(t, metadata.Pairs(ArgoVersionHeader, version.Version), msts.header) + }) + + t.Run("upstream error handling", func(t *testing.T) { + handler := func(srv any, stream grpc.ServerStream) error { + return errors.New("test error") + } + msts := &mockServerStream{header: metadata.New(nil)} + interceptor := SetVersionHeaderStreamServerInterceptor(*version) + + err := interceptor(nil, msts, nil, handler) + + require.Error(t, err, "test error") + assert.Empty(t, msts.header) + }) + + t.Run("SetHeader error handling", func(t *testing.T) { + handler := func(srv any, stream grpc.ServerStream) error { return nil } + msts := &mockServerStream{isError: true} + interceptor := SetVersionHeaderStreamServerInterceptor(*version) + + err := interceptor(nil, msts, nil, handler) + + require.NoError(t, err) + assert.Empty(t, msts.header) + }) +}