diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 2a17751c62..69111dc030 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -759,7 +759,7 @@ func DemystifySuccess(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCo for _, status := range append( append(status.InitContainerStatuses, status.ContainerStatuses...), status.EphemeralContainerStatuses...) { if status.State.Terminated != nil && strings.Contains(status.State.Terminated.Reason, OOMKilled) { - return pluginsCore.PhaseInfoRetryableFailure("OOMKilled", + return pluginsCore.PhaseInfoRetryableFailure(OOMKilled, "Pod reported success despite being OOMKilled", &info), nil } } @@ -777,9 +777,14 @@ func DeterminePrimaryContainerPhase(primaryContainerName string, statuses []v1.C } if s.State.Terminated != nil { - if s.State.Terminated.ExitCode != 0 { + if s.State.Terminated.ExitCode != 0 || strings.Contains(s.State.Terminated.Reason, OOMKilled) { + message := fmt.Sprintf("\r\n[%v] terminated with exit code (%v). Reason [%v]. Message: \n%v.", + s.Name, + s.State.Terminated.ExitCode, + s.State.Terminated.Reason, + s.State.Terminated.Message) return pluginsCore.PhaseInfoRetryableFailure( - s.State.Terminated.Reason, s.State.Terminated.Message, info) + s.State.Terminated.Reason, message, info) } return pluginsCore.PhaseInfoSuccess(info) } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index b5a51323d2..f25d499188 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -1728,7 +1728,7 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) { }, info) assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) assert.Equal(t, "foo", phaseInfo.Err().Code) - assert.Equal(t, "foo failed", phaseInfo.Err().Message) + assert.Equal(t, "\r\n[primary] terminated with exit code (1). Reason [foo]. Message: \nfoo failed.", phaseInfo.Err().Message) }) t.Run("primary container succeeded", func(t *testing.T) { phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{ @@ -1751,6 +1751,23 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) { assert.Equal(t, PrimaryContainerNotFound, phaseInfo.Err().Code) assert.Equal(t, "Primary container [primary] not found in pod's container statuses", phaseInfo.Err().Message) }) + t.Run("primary container failed with OOMKilled", func(t *testing.T) { + phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{ + secondaryContainer, { + Name: primaryContainerName, + State: v1.ContainerState{ + Terminated: &v1.ContainerStateTerminated{ + ExitCode: 0, + Reason: OOMKilled, + Message: "foo failed", + }, + }, + }, + }, info) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase()) + assert.Equal(t, OOMKilled, phaseInfo.Err().Code) + assert.Equal(t, "\r\n[primary] terminated with exit code (0). Reason [OOMKilled]. Message: \nfoo failed.", phaseInfo.Err().Message) + }) } func TestGetPodTemplate(t *testing.T) {