Skip to content

Commit

Permalink
feat(cli): add version header + warning on client-server mismatch. Fixes
Browse files Browse the repository at this point in the history
 #9212 (#13635)

Signed-off-by: Mason Malone <[email protected]>
  • Loading branch information
MasonM authored Sep 30, 2024
1 parent 2c3423e commit 54621cc
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 8 deletions.
4 changes: 4 additions & 0 deletions cmd/argo/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/argoproj/argo-workflows/v3/cmd/argo/commands/executorplugin"
"github.com/argoproj/argo-workflows/v3/cmd/argo/commands/template"
cmdutil "github.com/argoproj/argo-workflows/v3/util/cmd"
grpcutil "github.com/argoproj/argo-workflows/v3/util/grpc"
)

const (
Expand Down Expand Up @@ -125,6 +126,9 @@ If your server is behind an ingress with a path (running "argo server --base-hre
var logLevel string
var glogLevel int
var verbose bool
command.PersistentPostRun = func(cmd *cobra.Command, args []string) {
cmdutil.PrintVersionMismatchWarning(argo.GetVersion(), grpcutil.LastSeenServerVersion)
}
command.PersistentPreRun = func(cmd *cobra.Command, args []string) {
if verbose {
logLevel = "debug"
Expand Down
7 changes: 6 additions & 1 deletion pkg/apiclient/argo-server-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
workflowpkg "github.com/argoproj/argo-workflows/v3/pkg/apiclient/workflow"
workflowarchivepkg "github.com/argoproj/argo-workflows/v3/pkg/apiclient/workflowarchive"
workflowtemplatepkg "github.com/argoproj/argo-workflows/v3/pkg/apiclient/workflowtemplate"
grpcutil "github.com/argoproj/argo-workflows/v3/util/grpc"
)

const (
Expand Down Expand Up @@ -65,7 +66,11 @@ func newClientConn(opts ArgoServerOpts) (*grpc.ClientConn, error) {
if opts.Secure {
creds = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{InsecureSkipVerify: opts.InsecureSkipVerify}))
}
conn, err := grpc.Dial(opts.URL, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), creds)
conn, err := grpc.Dial(opts.URL,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)),
creds,
grpc.WithUnaryInterceptor(grpcutil.GetVersionHeaderClientUnaryInterceptor),
)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions server/apiserver/argoserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ func (as *argoServer) newGRPCServer(instanceIDService instanceid.Service, workfl
grpcutil.ErrorTranslationUnaryServerInterceptor,
as.gatekeeper.UnaryServerInterceptor(),
grpcutil.RatelimitUnaryServerInterceptor(as.apiRateLimiter),
grpcutil.SetVersionHeaderUnaryServerInterceptor(argo.GetVersion()),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
grpc_prometheus.StreamServerInterceptor,
Expand All @@ -313,6 +314,7 @@ func (as *argoServer) newGRPCServer(instanceIDService instanceid.Service, workfl
grpcutil.ErrorTranslationStreamServerInterceptor,
as.gatekeeper.StreamServerInterceptor(),
grpcutil.RatelimitStreamServerInterceptor(as.apiRateLimiter),
grpcutil.SetVersionHeaderStreamServerInterceptor(argo.GetVersion()),
)),
}

Expand Down
19 changes: 12 additions & 7 deletions test/e2e/argo_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,11 @@ func (s *ArgoServerSuite) TestInfo() {

func (s *ArgoServerSuite) TestVersion() {
s.Run("Version", func() {
s.e().GET("/api/v1/version").
resp := s.e().GET("/api/v1/version").
Expect().
Status(200).
JSON().
Path("$.version").
NotNull()
Status(200)
resp.JSON().Path("$.version").NotNull()
resp.Header("Grpc-Metadata-Argo-Version").NotEmpty()
})
}

Expand Down Expand Up @@ -343,14 +342,20 @@ func (s *ArgoServerSuite) TestUnauthorized() {
defer func() { s.bearerToken = token }()
s.e().GET("/api/v1/workflows/argo").
Expect().
Status(401)
Status(401).
// Version header shouldn't be set on 401s for security, since that could be used by attackers to find vulnerable servers
Header("Grpc-Metadata-Argo-Version").
IsEmpty()
})
s.Run("Basic", func() {
s.username = "garbage"
defer func() { s.username = "" }()
s.e().GET("/api/v1/workflows/argo").
Expect().
Status(401)
Status(401).
// Version header shouldn't be set on 401s for security, since that could be used by attackers to find vulnerable servers
Header("Grpc-Metadata-Argo-Version").
IsEmpty()
})
}

Expand Down
7 changes: 7 additions & 0 deletions util/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ func PrintVersion(cliName string, version wfv1.Version, short bool) {
fmt.Printf(" Platform: %s\n", version.Platform)
}

// PrintVersionMismatchWarning detects if there's a mismatch between the client and server versions and prints a warning if so
func PrintVersionMismatchWarning(clientVersion wfv1.Version, serverVersion string) {
if serverVersion != "" && clientVersion.GitTag != "" && serverVersion != clientVersion.Version {
log.Warnf("CLI version (%s) does not match server version (%s). This can lead to unexpected behavior.", clientVersion.Version, serverVersion)
}
}

// MustIsDir returns whether or not the given filePath is a directory. Exits if path does not exist
func MustIsDir(filePath string) bool {
fileInfo, err := os.Stat(filePath)
Expand Down
53 changes: 53 additions & 0 deletions util/cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ package cmd
import (
"reflect"
"testing"

log "github.com/sirupsen/logrus"
logtest "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"

wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
)

func TestMakeParseLabels(t *testing.T) {
Expand Down Expand Up @@ -103,3 +109,50 @@ func TestIsURL(t *testing.T) {
})
}
}

func TestPrintVersionMismatchWarning(t *testing.T) {
tests := []struct {
name string
clientVersion *wfv1.Version
serverVersion string
expectedLog string
}{
{
name: "server version not set",
clientVersion: &wfv1.Version{
Version: "v3.1.0",
GitTag: "v3.1.0",
},
serverVersion: "",
},
{
name: "client version is untagged",
clientVersion: &wfv1.Version{
Version: "v3.1.0",
},
serverVersion: "v3.1.1",
},
{
name: "version mismatch",
clientVersion: &wfv1.Version{
Version: "v3.1.0",
GitTag: "v3.1.0",
},
serverVersion: "v3.1.1",
expectedLog: "CLI version (v3.1.0) does not match server version (v3.1.1). This can lead to unexpected behavior.",
},
}
hook := &logtest.Hook{}
log.AddHook(hook)
defer log.StandardLogger().ReplaceHooks(nil)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
PrintVersionMismatchWarning(*tt.clientVersion, tt.serverVersion)
if tt.expectedLog != "" {
assert.Equal(t, tt.expectedLog, hook.LastEntry().Message)
} else {
assert.Nil(t, hook.LastEntry())
}
})
}
}
48 changes: 48 additions & 0 deletions util/grpc/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ import (
"runtime/debug"
"strings"

wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"

log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -40,7 +43,12 @@ func PanicLoggerStreamServerInterceptor(log *log.Entry) grpc.StreamServerInterce
}
}

const (
ArgoVersionHeader = "argo-version"
)

var (
LastSeenServerVersion string
ErrorTranslationUnaryServerInterceptor = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
resp, err = handler(ctx, req)
return resp, TranslateError(err)
Expand All @@ -50,6 +58,46 @@ var (
}
)

// SetVersionHeaderUnaryServerInterceptor returns a new unary server interceptor that sets the argo-version header
func SetVersionHeaderUnaryServerInterceptor(version wfv1.Version) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
m, origErr := handler(ctx, req)
if origErr == nil {
// Don't set header if there was an error because attackers could use it to find vulnerable Argo servers
err := grpc.SetHeader(ctx, metadata.Pairs(ArgoVersionHeader, version.Version))
if err != nil {
log.Warnf("Failed to set header '%s': %s", ArgoVersionHeader, err)
}
}
return m, origErr
}
}

// SetVersionHeaderStreamServerInterceptor returns a new stream server interceptor that sets the argo-version header
func SetVersionHeaderStreamServerInterceptor(version wfv1.Version) grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
origErr := handler(srv, ss)
if origErr == nil {
// Don't set header if there was an error because attackers could use it to find vulnerable Argo servers
err := ss.SetHeader(metadata.Pairs(ArgoVersionHeader, version.Version))
if err != nil {
log.Warnf("Failed to set header '%s': %s", ArgoVersionHeader, err)
}
}
return origErr
}
}

// GetVersionHeaderClientUnaryInterceptor returns a new unary client interceptor that extracts the argo-version from the response and sets the global variable LastSeenServerVersion
func GetVersionHeaderClientUnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
var headers metadata.MD
err := invoker(ctx, method, req, reply, cc, append(opts, grpc.Header(&headers))...)
if err == nil && headers != nil && headers.Get(ArgoVersionHeader) != nil {
LastSeenServerVersion = headers.Get(ArgoVersionHeader)[0]
}
return err
}

// RatelimitUnaryServerInterceptor returns a new unary server interceptor that performs request rate limiting.
func RatelimitUnaryServerInterceptor(ratelimiter limiter.Store) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
Expand Down
136 changes: 136 additions & 0 deletions util/grpc/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package grpc

import (
"context"
"errors"
"testing"

wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

type mockServerTransportStream struct {
header metadata.MD
isError bool
}

func (mockServerTransportStream) Method() string { return "" }
func (msts *mockServerTransportStream) SetHeader(md metadata.MD) error {
if msts.isError {
return errors.New("simulate error setting header")
}
msts.header = md
return nil
}
func (mockServerTransportStream) SendHeader(md metadata.MD) error { return nil }
func (mockServerTransportStream) SetTrailer(md metadata.MD) error { return nil }

var _ grpc.ServerTransportStream = &mockServerTransportStream{}

func TestSetVersionHeaderUnaryServerInterceptor(t *testing.T) {
version := &wfv1.Version{Version: "v3.1.0"}
mockReturn := "successful return"

t.Run("success", func(t *testing.T) {
handler := func(ctx context.Context, req interface{}) (interface{}, error) { return mockReturn, nil }
msts := &mockServerTransportStream{}
ctx := grpc.NewContextWithServerTransportStream(context.Background(), msts)
interceptor := SetVersionHeaderUnaryServerInterceptor(*version)

m, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{}, handler)

require.NoError(t, err)
assert.Equal(t, mockReturn, m)
assert.Equal(t, metadata.Pairs(ArgoVersionHeader, version.Version), msts.header)
})

t.Run("upstream error handling", func(t *testing.T) {
handler := func(ctx context.Context, req interface{}) (interface{}, error) { return nil, errors.New("error") }
msts := &mockServerTransportStream{}
ctx := grpc.NewContextWithServerTransportStream(context.Background(), msts)
interceptor := SetVersionHeaderUnaryServerInterceptor(*version)

_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{}, handler)

require.Error(t, err)
assert.Empty(t, msts.header)
})

t.Run("SetHeader error handling", func(t *testing.T) {
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return mockReturn, nil
}
msts := &mockServerTransportStream{isError: true}
ctx := grpc.NewContextWithServerTransportStream(context.Background(), msts)
interceptor := SetVersionHeaderUnaryServerInterceptor(*version)

m, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{}, handler)

require.NoError(t, err)
require.Equal(t, mockReturn, m)
assert.Empty(t, msts.header)
})
}

type mockServerStream struct {
header metadata.MD
isError bool
}

func (msts mockServerStream) SetHeader(md metadata.MD) error {
if msts.isError {
return errors.New("simulate error setting header")
}
msts.header.Set(ArgoVersionHeader, md.Get(ArgoVersionHeader)...)
return nil
}
func (mockServerStream) SendHeader(md metadata.MD) error { return nil }
func (mockServerStream) SetTrailer(md metadata.MD) {}
func (mockServerStream) Context() context.Context { return context.Background() }
func (mockServerStream) SendMsg(m any) error { return nil }
func (mockServerStream) RecvMsg(m any) error { return nil }

var _ grpc.ServerStream = &mockServerStream{}

func TestSetVersionHeaderStreamServerInterceptor(t *testing.T) {
version := &wfv1.Version{Version: "v3.1.0"}

t.Run("success", func(t *testing.T) {
handler := func(srv any, stream grpc.ServerStream) error { return nil }
msts := &mockServerStream{header: metadata.New(nil)}
interceptor := SetVersionHeaderStreamServerInterceptor(*version)

err := interceptor(nil, msts, nil, handler)

require.NoError(t, err)
assert.Equal(t, metadata.Pairs(ArgoVersionHeader, version.Version), msts.header)
})

t.Run("upstream error handling", func(t *testing.T) {
handler := func(srv any, stream grpc.ServerStream) error {
return errors.New("test error")
}
msts := &mockServerStream{header: metadata.New(nil)}
interceptor := SetVersionHeaderStreamServerInterceptor(*version)

err := interceptor(nil, msts, nil, handler)

require.Error(t, err, "test error")
assert.Empty(t, msts.header)
})

t.Run("SetHeader error handling", func(t *testing.T) {
handler := func(srv any, stream grpc.ServerStream) error { return nil }
msts := &mockServerStream{isError: true}
interceptor := SetVersionHeaderStreamServerInterceptor(*version)

err := interceptor(nil, msts, nil, handler)

require.NoError(t, err)
assert.Empty(t, msts.header)
})
}

0 comments on commit 54621cc

Please sign in to comment.