diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index c1e5dd264d..ba92b6353f 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -27,7 +27,8 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + pluginsUtils "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + utils "github.com/flyteorg/flyte/flytestdlib/utils" ) const ( @@ -66,7 +67,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } rayJob := plugins.RayJob{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &rayJob) + err = utils.UnmarshalStructToPb(taskTemplate.GetCustom(), &rayJob) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } @@ -379,8 +380,8 @@ func buildHeadPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodSpe ObjectMeta: *objectMeta, } cfg := config.GetK8sPluginConfig() - podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) - podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) + podTemplateSpec.SetLabels(pluginsUtils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) + podTemplateSpec.SetAnnotations(pluginsUtils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) return podTemplateSpec, nil } @@ -393,8 +394,8 @@ func buildSubmitterPodTemplate(podSpec *v1.PodSpec, objectMeta *metav1.ObjectMet } cfg := config.GetK8sPluginConfig() - podTemplateSpec.SetLabels(utils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) - podTemplateSpec.SetAnnotations(utils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) + podTemplateSpec.SetLabels(pluginsUtils.UnionMaps(cfg.DefaultLabels, podTemplateSpec.GetLabels(), pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) + podTemplateSpec.SetAnnotations(pluginsUtils.UnionMaps(cfg.DefaultAnnotations, podTemplateSpec.GetAnnotations(), pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) return podTemplateSpec } @@ -506,8 +507,8 @@ func buildWorkerPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodS Spec: *basePodSpec, ObjectMeta: *objectMetadata, } - podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) - podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) + podTemplateSpec.SetLabels(pluginsUtils.UnionMaps(podTemplateSpec.GetLabels(), pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) + podTemplateSpec.SetAnnotations(pluginsUtils.UnionMaps(podTemplateSpec.GetAnnotations(), pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) return podTemplateSpec, nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 708939485b..4af6f759bb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -27,7 +27,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyte/flytestdlib/utils" ) const (