From 25531535a39ca66a915e45b2d6f32251e4c547f3 Mon Sep 17 00:00:00 2001 From: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Date: Wed, 6 Dec 2023 12:04:59 -0800 Subject: [PATCH 1/2] [BUG] Fix setting of service_account from PodTemplate (#4536) * don't override service account from security context if already set Signed-off-by: Paul Dittamo * update unit test Signed-off-by: Paul Dittamo * cleanup Signed-off-by: Paul Dittamo * typo Signed-off-by: Paul Dittamo * clean up sytling Signed-off-by: Paul Dittamo --------- Signed-off-by: Paul Dittamo Co-authored-by: Dan Rammer --- .../tasks/plugins/k8s/pod/container_test.go | 122 ++++++++++++++---- .../go/tasks/plugins/k8s/pod/plugin.go | 4 +- 2 files changed, 101 insertions(+), 25 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go b/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go index 19000e0c72..9a70f906b9 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go @@ -19,6 +19,7 @@ import ( flytek8sConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" pluginsIOMock "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" ) var containerResourceRequirements = &v1.ResourceRequirements{ @@ -28,6 +29,12 @@ var containerResourceRequirements = &v1.ResourceRequirements{ }, } +var ( + serviceAccount = "service-account" + podTemplateServiceAccount = "test-service-account" + securityContextServiceAccount = "security-context-service-account" +) + func dummyContainerTaskTemplate(command []string, args []string) *core.TaskTemplate { return &core.TaskTemplate{ Type: "test", @@ -40,7 +47,40 @@ func dummyContainerTaskTemplate(command []string, args []string) *core.TaskTempl } } -func dummyContainerTaskMetadata(resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionMetadata { +func dummyContainerTaskTemplateWithPodSpec(command []string, args []string) *core.TaskTemplate { + + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "test-image", + Command: command, + Args: args, + }, + }, + ServiceAccountName: podTemplateServiceAccount, + } + + podSpecPb, err := utils.MarshalObjToStruct(podSpec) + if err != nil { + panic(err) + } + + taskTemplate := &core.TaskTemplate{ + Type: "test", + Target: &core.TaskTemplate_K8SPod{ + K8SPod: &core.K8SPod{ + PodSpec: podSpecPb, + }, + }, + Config: map[string]string{ + "primary_container_name": "test-image", + }, + } + + return taskTemplate +} + +func dummyContainerTaskMetadata(resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, returnsServiceAccount bool) pluginsCore.TaskExecutionMetadata { taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{} taskMetadata.On("GetNamespace").Return("test-namespace") taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) @@ -49,9 +89,13 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements, extendedReso Kind: "node", Name: "blah", }) - taskMetadata.On("GetK8sServiceAccount").Return("service-account") + if returnsServiceAccount { + taskMetadata.On("GetK8sServiceAccount").Return(serviceAccount) + } else { + taskMetadata.On("GetK8sServiceAccount").Return("") + } taskMetadata.On("GetSecurityContext").Return(core.SecurityContext{ - RunAs: &core.Identity{K8SServiceAccount: "service-account"}, + RunAs: &core.Identity{K8SServiceAccount: securityContextServiceAccount}, }) taskMetadata.On("GetOwnerID").Return(types.NamespacedName{ Namespace: "test-namespace", @@ -81,8 +125,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements, extendedReso return taskMetadata } -func dummyContainerTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { - dummyTaskMetadata := dummyContainerTaskMetadata(resources, extendedResources) +func dummyContainerTaskContext(taskTemplate *core.TaskTemplate, taskMetadata pluginsCore.TaskExecutionMetadata) pluginsCore.TaskExecutionContext { taskCtx := &pluginsCoreMock.TaskExecutionContext{} inputReader := &pluginsIOMock.InputReader{} inputReader.OnGetInputPrefixPath().Return("test-data-reference") @@ -103,7 +146,7 @@ func dummyContainerTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Re taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) taskCtx.OnTaskReader().Return(taskReader) - taskCtx.OnTaskExecutionMetadata().Return(dummyTaskMetadata) + taskCtx.OnTaskExecutionMetadata().Return(taskMetadata) pluginStateReader := &pluginsCoreMock.PluginStateReader{} pluginStateReader.OnGetMatch(mock.Anything).Return(0, nil) @@ -125,26 +168,54 @@ func TestContainerTaskExecutor_BuildIdentityResource(t *testing.T) { func TestContainerTaskExecutor_BuildResource(t *testing.T) { command := []string{"command"} args := []string{"{{.Input}}"} - taskTemplate := dummyContainerTaskTemplate(command, args) - taskCtx := dummyContainerTaskContext(taskTemplate, containerResourceRequirements, nil) + testCases := []struct { + name string + taskTemplate *core.TaskTemplate + taskMetadata pluginsCore.TaskExecutionMetadata + expectServiceAccount string + }{ + { + name: "BuildResource", + taskTemplate: dummyContainerTaskTemplate(command, args), + taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, true), + expectServiceAccount: serviceAccount, + }, + { + name: "BuildResource_PodTemplate", + taskTemplate: dummyContainerTaskTemplateWithPodSpec(command, args), + taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, true), + expectServiceAccount: podTemplateServiceAccount, + }, + { + name: "BuildResource_SecurityContext", + taskTemplate: dummyContainerTaskTemplate(command, args), + taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, false), + expectServiceAccount: securityContextServiceAccount, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + taskCtx := dummyContainerTaskContext(tc.taskTemplate, tc.taskMetadata) - r, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) - assert.NoError(t, err) - assert.NotNil(t, r) - j, ok := r.(*v1.Pod) - assert.True(t, ok) + r, err := DefaultPodPlugin.BuildResource(context.TODO(), taskCtx) + assert.NoError(t, err) + assert.NotNil(t, r) + j, ok := r.(*v1.Pod) + assert.True(t, ok) - assert.NotEmpty(t, j.Spec.Containers) - assert.Equal(t, containerResourceRequirements.Limits[v1.ResourceCPU], j.Spec.Containers[0].Resources.Limits[v1.ResourceCPU]) + assert.NotEmpty(t, j.Spec.Containers) + assert.Equal(t, containerResourceRequirements.Limits[v1.ResourceCPU], j.Spec.Containers[0].Resources.Limits[v1.ResourceCPU]) - // TODO: Once configurable, test when setting storage is supported on the cluster vs not. - storageRes := j.Spec.Containers[0].Resources.Limits[v1.ResourceStorage] - assert.Equal(t, int64(0), (&storageRes).Value()) + // TODO: Once configurable, test when setting storage is supported on the cluster vs not. + storageRes := j.Spec.Containers[0].Resources.Limits[v1.ResourceStorage] + assert.Equal(t, int64(0), (&storageRes).Value()) - assert.Equal(t, command, j.Spec.Containers[0].Command) - assert.Equal(t, []string{"test-data-reference"}, j.Spec.Containers[0].Args) + assert.Equal(t, command, j.Spec.Containers[0].Command) + assert.Equal(t, []string{"test-data-reference"}, j.Spec.Containers[0].Args) - assert.Equal(t, "service-account", j.Spec.ServiceAccountName) + assert.Equal(t, tc.expectServiceAccount, j.Spec.ServiceAccountName) + }) + } } func TestContainerTaskExecutor_BuildResource_ExtendedResources(t *testing.T) { @@ -252,7 +323,8 @@ func TestContainerTaskExecutor_BuildResource_ExtendedResources(t *testing.T) { t.Run(f.name, func(t *testing.T) { taskTemplate := dummyContainerTaskTemplate([]string{"command"}, []string{"{{.Input}}"}) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyContainerTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride) + taskMetadata := dummyContainerTaskMetadata(f.resources, f.extendedResourcesOverride, true) + taskContext := dummyContainerTaskContext(taskTemplate, taskMetadata) r, err := DefaultPodPlugin.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -277,7 +349,8 @@ func TestContainerTaskExecutor_GetTaskStatus(t *testing.T) { command := []string{"command"} args := []string{"{{.Input}}"} taskTemplate := dummyContainerTaskTemplate(command, args) - taskCtx := dummyContainerTaskContext(taskTemplate, containerResourceRequirements, nil) + taskMetadata := dummyContainerTaskMetadata(containerResourceRequirements, nil, true) + taskCtx := dummyContainerTaskContext(taskTemplate, taskMetadata) j := &v1.Pod{ Status: v1.PodStatus{}, @@ -366,7 +439,8 @@ func TestContainerTaskExecutor_GetTaskStatus_InvalidImageName(t *testing.T) { command := []string{"command"} args := []string{"{{.Input}}"} taskTemplate := dummyContainerTaskTemplate(command, args) - taskCtx := dummyContainerTaskContext(taskTemplate, containerResourceRequirements, nil) + taskMetadata := dummyContainerTaskMetadata(containerResourceRequirements, nil, true) + taskCtx := dummyContainerTaskContext(taskTemplate, taskMetadata) ctx := context.TODO() reason := "InvalidImageName" diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go index 11de877021..b266a6f5e8 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go @@ -126,7 +126,9 @@ func (p plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecu objectMeta.Annotations[flytek8s.PrimaryContainerKey] = primaryContainerName } - podSpec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + if len(podSpec.ServiceAccountName) == 0 { + podSpec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + } pod := flytek8s.BuildIdentityPod() pod.ObjectMeta = *objectMeta From b9907f43568edcd3d248ad19d22b5643b0cf7285 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 6 Dec 2023 15:52:14 -0800 Subject: [PATCH 2/2] Fix flaky test_monitor (#4537) * Fix flaky test_monitor Signed-off-by: Kevin Su * fix test Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> --- .../go/tasks/pluginmachinery/internal/webapi/monitor_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/monitor_test.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/monitor_test.go index c47d77ae76..1628582156 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/monitor_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/monitor_test.go @@ -21,6 +21,7 @@ import ( func Test_monitor(t *testing.T) { ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) tCtx := &mocks.TaskExecutionContext{} ctxMeta := &mocks.TaskExecutionMetadata{} execID := &mocks.TaskExecutionID{} @@ -70,6 +71,7 @@ func Test_monitor(t *testing.T) { // Wait for sync to run to actually delete the resource wg.Wait() + cancel() cachedItem, err = cacheObj.GetOrCreate("generated_name", CacheItem{Resource: "new_resource"}) assert.NoError(t, err) assert.Equal(t, "new_resource", cachedItem.(CacheItem).Resource.(string))