Skip to content

Commit

Permalink
Add tests to ensure the phase version is bumped in kubeflow plugin if…
Browse files Browse the repository at this point in the history
… reason changes within the same phase

Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 committed Apr 8, 2024
1 parent 24bddac commit 4d52bd8
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 47 deletions.
56 changes: 43 additions & 13 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mpi
import (
"context"
"fmt"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -117,7 +118,7 @@ func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate {
}
}

func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext {
func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}
inputReader := &pluginIOMocks.InputReader{}
inputReader.OnGetInputPrefixPath().Return("/input/prefix")
Expand Down Expand Up @@ -170,6 +171,18 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)

pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = pluginState
return 0
},
func(v interface{}) error {
return nil
})

taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock)
return taskCtx
}

Expand Down Expand Up @@ -275,7 +288,7 @@ func dummyMPIJobResource(mpiResourceHandler mpiOperatorResourceHandler,

mpiObj := dummyMPICustomObj(workers, launcher, slots)
taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj)
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
if err != nil {
panic(err)
}
Expand All @@ -302,7 +315,7 @@ func TestBuildResourceMPI(t *testing.T) {
mpiObj := dummyMPICustomObj(100, 50, 1)
taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj)

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand Down Expand Up @@ -338,13 +351,13 @@ func TestBuildResourceMPIForWrongInput(t *testing.T) {
mpiObj := dummyMPICustomObj(0, 0, 1)
taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj)

_, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
_, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.Error(t, err)

mpiObj = dummyMPICustomObj(1, 1, 1)
taskTemplate = dummyMPITaskTemplate(mpiID2, mpiObj)

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
app, ok := resource.(*kubeflowv1.MPIJob)
assert.Nil(t, err)
assert.Equal(t, true, ok)
Expand Down Expand Up @@ -458,7 +471,7 @@ func TestBuildResourceMPIExtendedResources(t *testing.T) {
mpiObj := dummyMPICustomObj(100, 50, 1)
taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj)
taskTemplate.ExtendedResources = f.extendedResourcesBase
taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride)
taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, k8s.PluginState{})
mpiResourceHandler := mpiOperatorResourceHandler{}
r, err := mpiResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down Expand Up @@ -490,7 +503,7 @@ func TestGetTaskPhase(t *testing.T) {
return dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, conditionType)
}

taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil)
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, k8s.PluginState{})
taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResourceCreator(mpiOp.JobCreated))
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase())
Expand Down Expand Up @@ -522,6 +535,23 @@ func TestGetTaskPhase(t *testing.T) {
assert.Nil(t, err)
}

func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) {
mpiResourceHandler := mpiOperatorResourceHandler{}
ctx := context.TODO()

pluginState := k8s.PluginState{
Phase: pluginsCore.PhaseQueued,
PhaseVersion: pluginsCore.DefaultPhaseVersion,
Reason: "task submitted to K8s",
}
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, pluginState)

taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, mpiOp.JobCreated))

assert.NoError(t, err)
assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1)
}

func TestGetLogs(t *testing.T) {
assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{
IsKubernetesEnabled: true,
Expand All @@ -534,7 +564,7 @@ func TestGetLogs(t *testing.T) {

mpiResourceHandler := mpiOperatorResourceHandler{}
mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning)
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil)
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil, k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
Expand Down Expand Up @@ -567,7 +597,7 @@ func TestReplicaCounts(t *testing.T) {
mpiObj := dummyMPICustomObj(test.workerReplicaCount, test.launcherReplicaCount, 1)
taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj)

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
if test.expectError {
assert.Error(t, err)
assert.Nil(t, resource)
Expand Down Expand Up @@ -653,7 +683,7 @@ func TestBuildResourceMPIV1(t *testing.T) {
taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
taskTemplate.TaskTypeVersion = 1

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand Down Expand Up @@ -705,7 +735,7 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) {
taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
taskTemplate.TaskTypeVersion = 1

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand Down Expand Up @@ -768,7 +798,7 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) {
taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
taskTemplate.TaskTypeVersion = 1

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand All @@ -783,7 +813,7 @@ func TestGetReplicaCount(t *testing.T) {
mpiResourceHandler := mpiOperatorResourceHandler{}
tfObj := dummyMPICustomObj(1, 1, 0)
taskTemplate := dummyMPITaskTemplate("the job", tfObj)
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)
MPIJob, ok := resource.(*kubeflowv1.MPIJob)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate
}
}

func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext {
func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}
inputReader := &pluginIOMocks.InputReader{}
inputReader.OnGetInputPrefixPath().Return("/input/prefix")
Expand Down Expand Up @@ -178,11 +178,10 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
*(v.(*k8s.PluginState)) = pluginState
return 0
},
func(v interface{}) error {
Expand Down Expand Up @@ -294,7 +293,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl

ptObj := dummyPytorchCustomObj(workers)
taskTemplate := dummyPytorchTaskTemplate("job1", ptObj)
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -322,7 +321,7 @@ func TestBuildResourcePytorchElastic(t *testing.T) {
ptObj := dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"})
taskTemplate := dummyPytorchTaskTemplate("job2", ptObj)

resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand Down Expand Up @@ -365,7 +364,7 @@ func TestBuildResourcePytorch(t *testing.T) {
ptObj := dummyPytorchCustomObj(100)
taskTemplate := dummyPytorchTaskTemplate("job3", ptObj)

res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, res)

Expand Down Expand Up @@ -447,7 +446,7 @@ func TestBuildResourcePytorchContainerImage(t *testing.T) {
for _, f := range fixtures {
t.Run(tCfg.name+" "+f.name, func(t *testing.T) {
taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin)
taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride)
taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, k8s.PluginState{})
pytorchResourceHandler := pytorchOperatorResourceHandler{}
r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext)
assert.NoError(t, err)
Expand Down Expand Up @@ -589,7 +588,7 @@ func TestBuildResourcePytorchExtendedResources(t *testing.T) {
t.Run(tCfg.name+" "+f.name, func(t *testing.T) {
taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin)
taskTemplate.ExtendedResources = f.extendedResourcesBase
taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "")
taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "", k8s.PluginState{})
pytorchResourceHandler := pytorchOperatorResourceHandler{}
r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext)
assert.NoError(t, err)
Expand Down Expand Up @@ -622,7 +621,7 @@ func TestGetTaskPhase(t *testing.T) {
return dummyPytorchJobResource(pytorchResourceHandler, 2, conditionType)
}

taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "")
taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "", k8s.PluginState{})
taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResourceCreator(commonOp.JobCreated))
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase())
Expand Down Expand Up @@ -654,6 +653,23 @@ func TestGetTaskPhase(t *testing.T) {
assert.Nil(t, err)
}

func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}
ctx := context.TODO()

pluginState := k8s.PluginState{
Phase: pluginsCore.PhaseQueued,
PhaseVersion: pluginsCore.DefaultPhaseVersion,
Reason: "task submitted to K8s",
}
taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "", pluginState)

taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResource(pytorchResourceHandler, 2, commonOp.JobCreated))

assert.NoError(t, err)
assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1)
}

func TestGetLogs(t *testing.T) {
assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{
IsKubernetesEnabled: true,
Expand All @@ -665,7 +681,7 @@ func TestGetLogs(t *testing.T) {

pytorchResourceHandler := pytorchOperatorResourceHandler{}
pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning)
taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "")
taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 3, len(jobLogs))
Expand All @@ -685,7 +701,7 @@ func TestGetLogsElastic(t *testing.T) {

pytorchResourceHandler := pytorchOperatorResourceHandler{}
pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning)
taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "")
taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
Expand Down Expand Up @@ -716,7 +732,7 @@ func TestReplicaCounts(t *testing.T) {
ptObj := dummyPytorchCustomObj(test.workerReplicaCount)
taskTemplate := dummyPytorchTaskTemplate("the job", ptObj)

res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
if test.expectError {
assert.Error(t, err)
assert.Nil(t, res)
Expand Down Expand Up @@ -798,7 +814,7 @@ func TestBuildResourcePytorchV1(t *testing.T) {
taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig)
taskTemplate.TaskTypeVersion = 1

res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, res)

Expand Down Expand Up @@ -842,7 +858,7 @@ func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) {
taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
taskTemplate.TaskTypeVersion = 1

res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, res)

Expand Down Expand Up @@ -902,7 +918,7 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) {
taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
taskTemplate.TaskTypeVersion = 1

res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, res)

Expand Down Expand Up @@ -973,7 +989,7 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) {
taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig)
taskTemplate.TaskTypeVersion = 1

res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, res)

Expand All @@ -995,7 +1011,7 @@ func TestBuildResourcePytorchV1WithElastic(t *testing.T) {
taskTemplate.TaskTypeVersion = 1

pytorchResourceHandler := pytorchOperatorResourceHandler{}
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand Down Expand Up @@ -1031,7 +1047,7 @@ func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}
taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
taskTemplate.TaskTypeVersion = 1
_, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
_, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.Error(t, err)
}

Expand All @@ -1048,7 +1064,7 @@ func TestGetReplicaCount(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}
tfObj := dummyPytorchCustomObj(1)
taskTemplate := dummyPytorchTaskTemplate("the job", tfObj)
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, ""))
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)
PytorchJob, ok := resource.(*kubeflowv1.PyTorchJob)
Expand Down
Loading

0 comments on commit 4d52bd8

Please sign in to comment.