From 610451d21ca2fa8f598ebc26b8ebb96e423e270a Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Mon, 15 May 2023 09:35:59 -0700 Subject: [PATCH] Inject user identifier to ExecutionSpec (#549) Signed-off-by: byhsu This pr provides a middleware to inject user identifier to ExecutionSpec. By default, the value of the user identifier is userid from access/id token. Users can customize their own middleware and inject different values. --- auth/identity_context.go | 15 +++++++++++++++ auth/identity_context_test.go | 10 ++++++++++ auth/interceptor.go | 9 +++++++++ auth/interceptor_test.go | 18 ++++++++++++++++++ go.mod | 2 +- go.sum | 4 ++-- pkg/manager/impl/execution_manager.go | 15 ++++++++++++++- pkg/manager/impl/execution_manager_test.go | 7 ++++++- pkg/manager/impl/util/shared.go | 3 ++- pkg/server/service.go | 2 +- 10 files changed, 78 insertions(+), 7 deletions(-) diff --git a/auth/identity_context.go b/auth/identity_context.go index 4f36bb83e8..f1cfabe188 100644 --- a/auth/identity_context.go +++ b/auth/identity_context.go @@ -32,6 +32,11 @@ type IdentityContext struct { scopes *sets.String // Raw JWT token from the IDP. Set to a pointer to support the equal operator for this struct. claims *claimsType + // executionIdentity stores a unique string that can be used to identify the user associated with a given task. + // This identifier is passed down to the ExecutionSpec and can be used for various purposes, such as setting the user identifier on a pod label. + // By default, the execution user identifier is filled with the value of IdentityContext.userID. However, you can customize your middleware to assign other values if needed. + // Providing a user identifier can be useful for tracking tasks and associating them with specific users, especially in multi-user environments. + executionIdentity string } func (c IdentityContext) Audience() string { @@ -81,6 +86,16 @@ func (c IdentityContext) AuthenticatedAt() time.Time { return c.authenticatedAt } +func (c IdentityContext) ExecutionIdentity() string { + return c.executionIdentity +} + +// WithExecutionUserIdentifier creates a copy of the original identity context and attach ExecutionIdentity +func (c IdentityContext) WithExecutionUserIdentifier(euid string) IdentityContext { + c.executionIdentity = euid + return c +} + // NewIdentityContext creates a new IdentityContext. func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse, claims map[string]interface{}) ( IdentityContext, error) { diff --git a/auth/identity_context_test.go b/auth/identity_context_test.go index 5bee6347f3..1e72042be0 100644 --- a/auth/identity_context_test.go +++ b/auth/identity_context_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" ) func TestGetClaims(t *testing.T) { @@ -23,3 +24,12 @@ func TestGetClaims(t *testing.T) { assert.NotEmpty(t, withClaimsCtx.UserInfo().AdditionalClaims) } + +func TestWithExecutionUserIdentifier(t *testing.T) { + idctx, err := NewIdentityContext("", "", "", time.Now(), sets.String{}, nil, nil) + assert.NoError(t, err) + newIDCtx := idctx.WithExecutionUserIdentifier("byhsu") + // make sure the original one is intact + assert.Equal(t, "", idctx.ExecutionIdentity()) + assert.Equal(t, "byhsu", newIDCtx.ExecutionIdentity()) +} diff --git a/auth/interceptor.go b/auth/interceptor.go index a4d78d4d6e..696cdb5bee 100644 --- a/auth/interceptor.go +++ b/auth/interceptor.go @@ -22,3 +22,12 @@ func BlanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnarySer return handler(ctx, req) } + +// ExecutionUserIdentifierInterceptor injects identityContext.UserID() to identityContext.executionIdentity +func ExecutionUserIdentifierInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + identityContext := IdentityContextFromContext(ctx) + identityContext = identityContext.WithExecutionUserIdentifier(identityContext.UserID()) + ctx = identityContext.WithContext(ctx) + return handler(ctx, req) +} diff --git a/auth/interceptor_test.go b/auth/interceptor_test.go index 862f76a131..8418ae067d 100644 --- a/auth/interceptor_test.go +++ b/auth/interceptor_test.go @@ -59,3 +59,21 @@ func TestBlanketAuthorization(t *testing.T) { assert.False(t, handlerCalled) }) } + +func TestGetUserIdentityFromContext(t *testing.T) { + identityContext := IdentityContext{ + userID: "yeee", + } + + ctx := identityContext.WithContext(context.Background()) + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + identityContext := IdentityContextFromContext(ctx) + euid := identityContext.ExecutionIdentity() + assert.Equal(t, euid, "yeee") + return nil, nil + } + + _, err := ExecutionUserIdentifierInterceptor(ctx, nil, nil, handler) + assert.NoError(t, err) +} diff --git a/go.mod b/go.mod index 7b9c13c417..ee218205c6 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/cloudevents/sdk-go/v2 v2.8.0 github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.12.0+incompatible - github.com/flyteorg/flyteidl v1.5.3 + github.com/flyteorg/flyteidl v1.5.5 github.com/flyteorg/flyteplugins v1.0.56 github.com/flyteorg/flytepropeller v1.1.87 github.com/flyteorg/flytestdlib v1.0.15 diff --git a/go.sum b/go.sum index 1ef64e3889..8ac127b642 100644 --- a/go.sum +++ b/go.sum @@ -312,8 +312,8 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.5.3 h1:qHyU9kvcxGIkXoloi768ayx9FHrs961dZC3WYziGGZA= -github.com/flyteorg/flyteidl v1.5.3/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I= +github.com/flyteorg/flyteidl v1.5.5 h1:tNNhuXPog4atAMSGE2kyAg6JzYy1TvjqrrQeh1EZVHs= +github.com/flyteorg/flyteidl v1.5.5/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flyteplugins v1.0.56 h1:kBTDgTpdSi7wcptk4cMwz5vfh1MU82VaUMMboe1InXw= github.com/flyteorg/flyteplugins v1.0.56/go.mod h1:aFCKSn8TPzxSAILIiogHtUnHlUCN9+y6Vf+r9R4KZDU= github.com/flyteorg/flytepropeller v1.1.87 h1:Px7ASDjrWyeVrUb15qXmhw9QK7xPcFjL5Yetr2P6iGM= diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index bf926319c0..a8e823ec94 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -403,6 +403,18 @@ func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admi RunAs: &core.Identity{}, } } + + if workflowExecConfig.GetSecurityContext().GetRunAs() == nil { + workflowExecConfig.SecurityContext.RunAs = &core.Identity{} + } + + // In the case of reference_launch_plan subworkflow, the context comes from flytepropeller instead of the user side, so user auth is missing. + // We skip getUserIdentityFromContext but can still get ExecUserId because flytepropeller passes it in the execution request. + // https://github.com/flyteorg/flytepropeller/blob/03a6672960ed04e7687ba4f790fee9a02a4057fb/pkg/controller/nodes/subworkflow/launchplan/admin.go#L114 + if workflowExecConfig.GetSecurityContext().GetRunAs().GetExecutionIdentity() == "" { + workflowExecConfig.SecurityContext.RunAs.ExecutionIdentity = auth.IdentityContextFromContext(ctx).ExecutionIdentity() + } + logger.Infof(ctx, "getting the workflow execution config from application configuration") // Defaults to one from the application config return &workflowExecConfig, nil @@ -676,7 +688,8 @@ func resolveSecurityCtx(ctx context.Context, executionConfigSecurityCtx *core.Se // Use security context from the executionConfigSecurityCtx if its set and non empty or else resolve from authRole if executionConfigSecurityCtx != nil && executionConfigSecurityCtx.RunAs != nil && (len(executionConfigSecurityCtx.RunAs.K8SServiceAccount) > 0 || - len(executionConfigSecurityCtx.RunAs.IamRole) > 0) { + len(executionConfigSecurityCtx.RunAs.IamRole) > 0 || + len(executionConfigSecurityCtx.RunAs.ExecutionIdentity) > 0) { return executionConfigSecurityCtx } logger.Warn(ctx, "Setting security context from auth Role") diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 7549e03f56..4e00115aff 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -4540,7 +4540,11 @@ func TestGetExecutionConfigOverrides(t *testing.T) { Envs: &admin.Envs{Values: requestEnvironmentVariables}, }, } - execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, nil) + identityContext, err := auth.NewIdentityContext("", "", "", time.Now(), sets.String{}, nil, nil) + assert.NoError(t, err) + identityContext = identityContext.WithExecutionUserIdentifier("yeee") + ctx := identityContext.WithContext(context.Background()) + execConfig, err := executionManager.getExecutionConfig(ctx, request, nil) assert.NoError(t, err) assert.Equal(t, requestMaxParallelism, execConfig.MaxParallelism) assert.Equal(t, requestK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) @@ -4549,6 +4553,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Equal(t, requestOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, requestLabels, execConfig.GetLabels().Values) assert.Equal(t, requestAnnotations, execConfig.GetAnnotations().Values) + assert.Equal(t, "yeee", execConfig.GetSecurityContext().GetRunAs().GetExecutionIdentity()) assert.Equal(t, requestEnvironmentVariables, execConfig.GetEnvs().Values) }) t.Run("request with partial config", func(t *testing.T) { diff --git a/pkg/manager/impl/util/shared.go b/pkg/manager/impl/util/shared.go index 31cf896086..d9c61b121a 100644 --- a/pkg/manager/impl/util/shared.go +++ b/pkg/manager/impl/util/shared.go @@ -297,7 +297,8 @@ func MergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec if workflowExecConfig.GetSecurityContext() == nil && spec.GetSecurityContext() != nil { if spec.GetSecurityContext().GetRunAs() != nil && (len(spec.GetSecurityContext().GetRunAs().GetK8SServiceAccount()) > 0 || - len(spec.GetSecurityContext().GetRunAs().GetIamRole()) > 0) { + len(spec.GetSecurityContext().GetRunAs().GetIamRole()) > 0 || + len(spec.GetSecurityContext().GetRunAs().GetExecutionIdentity()) > 0) { workflowExecConfig.SecurityContext = spec.GetSecurityContext() } } diff --git a/pkg/server/service.go b/pkg/server/service.go index 1fe2f57c14..f3b27416f6 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -77,7 +77,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) { logger.Infof(ctx, "Registering default middleware with blanket auth validation") - pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization)) + pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor)) // Not yet implemented for streaming var chainedUnaryInterceptors grpc.UnaryServerInterceptor