From 7fa29f2d5592bb398ac5ac9670b07b0747704284 Mon Sep 17 00:00:00 2001 From: peterghaddad <107633597+peterghaddad@users.noreply.github.com> Date: Tue, 13 Feb 2024 14:28:08 -0700 Subject: [PATCH] Leverage KubeRay v1 instead of v1alpha1 for resources (#4818) * initial * Clean up * init * Clean up * Add TestGetEventInfo_LogTemplatesV1 * Add more tests * Fix tests * Remove dupe * Fix lint * add comment Signed-off-by: peterghaddad --------- Signed-off-by: peterghaddad Co-authored-by: Neil <150836163+neilisaur@users.noreply.github.com> --- .../go/tasks/plugins/k8s/ray/config.go | 2 + flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 254 +++++++++++++++++- .../go/tasks/plugins/k8s/ray/ray_test.go | 196 ++++++++++++++ 3 files changed, 438 insertions(+), 14 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index 67ea4d2aebb..e73fc4dc7d2 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -22,6 +22,7 @@ var ( IncludeDashboard: true, DashboardHost: "0.0.0.0", EnableUsageStats: false, + KubeRayCrdVersion: "v1alpha1", Defaults: DefaultConfig{ HeadNode: NodeConfig{ StartParameters: map[string]string{ @@ -85,6 +86,7 @@ type Config struct { DashboardURLTemplate *tasklog.TemplateLogPlugin `json:"dashboardURLTemplate" pflag:"-,Template for URL of Ray dashboard running on a head node."` Defaults DefaultConfig `json:"defaults" pflag:"-,Default configuration for ray jobs"` EnableUsageStats bool `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. 
These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"` + KubeRayCrdVersion string `json:"kubeRayCrdVersion" pflag:",Version of the Ray CRD to use when creating RayClusters or RayJobs."` } type DefaultConfig struct { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 6b2af127dfa..25291b20663 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -8,6 +8,7 @@ import ( "strings" "time" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -28,14 +29,15 @@ import ( ) const ( - rayStateMountPath = "/tmp/ray" - defaultRayStateVolName = "system-ray-state" - rayTaskType = "ray" - KindRayJob = "RayJob" - IncludeDashboard = "include-dashboard" - NodeIPAddress = "node-ip-address" - DashboardHost = "dashboard-host" - DisableUsageStatsStartParameter = "disable-usage-stats" + rayStateMountPath = "/tmp/ray" + defaultRayStateVolName = "system-ray-state" + rayTaskType = "ray" + KindRayJob = "RayJob" + IncludeDashboard = "include-dashboard" + NodeIPAddress = "node-ip-address" + DashboardHost = "dashboard-host" + DisableUsageStatsStartParameter = "disable-usage-stats" + DisableUsageStatsStartParameterVal = "true" ) var logTemplateRegexes = struct { @@ -52,7 +54,7 @@ func (rayJobResourceHandler) GetProperties() k8s.PluginProperties { return k8s.PluginProperties{} } -// BuildResource Creates a new ray job resource. +// BuildResource Creates a new ray job resource for v1 or v1alpha1. 
func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { @@ -109,11 +111,22 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } if _, exists := headNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { - headNodeRayStartParams[DisableUsageStatsStartParameter] = "true" + headNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal } - enableIngress := true headPodSpec := podSpec.DeepCopy() + + if cfg.KubeRayCrdVersion == "v1" { + return constructV1Job(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headReplicas, headNodeRayStartParams, primaryContainerIdx, *primaryContainer), nil + } + + return constructV1Alpha1Job(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headReplicas, headNodeRayStartParams, primaryContainerIdx, *primaryContainer), nil + +} + +func constructV1Alpha1Job(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headReplicas int32, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) *rayv1alpha1.RayJob { + enableIngress := true + cfg := GetConfig() rayClusterSpec := rayv1alpha1.RayClusterSpec{ HeadGroupSpec: rayv1alpha1.HeadGroupSpec{ Template: buildHeadPodTemplate( @@ -152,7 +165,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { - workerNodeRayStartParams[DisableUsageStatsStartParameter] = "true" + workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal } minReplicas := spec.MinReplicas @@ -198,7 +211,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC 
RuntimeEnv: rayJob.RuntimeEnv, } - rayJobObject := rayv1alpha1.RayJob{ + return &rayv1alpha1.RayJob{ TypeMeta: metav1.TypeMeta{ Kind: KindRayJob, APIVersion: rayv1alpha1.SchemeGroupVersion.String(), @@ -206,8 +219,103 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC Spec: jobSpec, ObjectMeta: *objectMeta, } +} + +func constructV1Job(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headReplicas int32, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) *rayv1.RayJob { + enableIngress := true + cfg := GetConfig() + rayClusterSpec := rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: buildHeadPodTemplate( + &headPodSpec.Containers[primaryContainerIdx], + headPodSpec, + objectMeta, + taskCtx, + ), + ServiceType: v1.ServiceType(cfg.ServiceType), + Replicas: &headReplicas, + EnableIngress: &enableIngress, + RayStartParams: headNodeRayStartParams, + }, + WorkerGroupSpecs: []rayv1.WorkerGroupSpec{}, + EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling, + } + + for _, spec := range rayJob.RayCluster.WorkerGroupSpec { + workerPodSpec := podSpec.DeepCopy() + workerPodTemplate := buildWorkerPodTemplate( + &workerPodSpec.Containers[primaryContainerIdx], + workerPodSpec, + objectMeta, + taskCtx, + ) + + workerNodeRayStartParams := make(map[string]string) + if spec.RayStartParams != nil { + workerNodeRayStartParams = spec.RayStartParams + } else if workerNode := cfg.Defaults.WorkerNode; len(workerNode.StartParameters) > 0 { + workerNodeRayStartParams = workerNode.StartParameters + } + + if _, exist := workerNodeRayStartParams[NodeIPAddress]; !exist { + workerNodeRayStartParams[NodeIPAddress] = cfg.Defaults.WorkerNode.IPAddress + } + + if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats { + 
workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal + } - return &rayJobObject, nil + minReplicas := spec.MinReplicas + if minReplicas > spec.Replicas { + minReplicas = spec.Replicas + } + maxReplicas := spec.MaxReplicas + if maxReplicas < spec.Replicas { + maxReplicas = spec.Replicas + } + + workerNodeSpec := rayv1.WorkerGroupSpec{ + GroupName: spec.GroupName, + MinReplicas: &minReplicas, + MaxReplicas: &maxReplicas, + Replicas: &spec.Replicas, + RayStartParams: workerNodeRayStartParams, + Template: workerPodTemplate, + } + + rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec) + } + + serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + + rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName + for index := range rayClusterSpec.WorkerGroupSpecs { + rayClusterSpec.WorkerGroupSpecs[index].Template.Spec.ServiceAccountName = serviceAccountName + } + + shutdownAfterJobFinishes := cfg.ShutdownAfterJobFinishes + ttlSecondsAfterFinished := &cfg.TTLSecondsAfterFinished + if rayJob.ShutdownAfterJobFinishes { + shutdownAfterJobFinishes = true + ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished + } + + jobSpec := rayv1.RayJobSpec{ + RayClusterSpec: &rayClusterSpec, + Entrypoint: strings.Join(primaryContainer.Args, " "), + ShutdownAfterJobFinishes: shutdownAfterJobFinishes, + TTLSecondsAfterFinished: ttlSecondsAfterFinished, + RuntimeEnv: rayJob.RuntimeEnv, + } + + return &rayv1.RayJob{ + TypeMeta: metav1.TypeMeta{ + Kind: KindRayJob, + APIVersion: rayv1.SchemeGroupVersion.String(), + }, + Spec: jobSpec, + ObjectMeta: *objectMeta, + } } func injectLogsSidecar(primaryContainer *v1.Container, podSpec *v1.PodSpec) { @@ -503,7 +611,125 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon return &pluginsCore.TaskInfo{Logs: taskLogs}, nil } +func
getEventInfoForRayJobV1(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob) (*pluginsCore.TaskInfo, error) { + logPlugin, err := logs.InitializeLogPlugins(&logConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err) + } + + var taskLogs []*core.TaskLog + + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() + input := tasklog.Input{ + Namespace: rayJob.Namespace, + TaskExecutionID: taskExecID, + ExtraTemplateVars: []tasklog.TemplateVar{}, + } + if rayJob.Status.JobId != "" { + input.ExtraTemplateVars = append( + input.ExtraTemplateVars, + tasklog.TemplateVar{ + Regex: logTemplateRegexes.RayJobID, + Value: rayJob.Status.JobId, + }, + ) + } + if rayJob.Status.RayClusterName != "" { + input.ExtraTemplateVars = append( + input.ExtraTemplateVars, + tasklog.TemplateVar{ + Regex: logTemplateRegexes.RayClusterName, + Value: rayJob.Status.RayClusterName, + }, + ) + } + + // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs + // RayJob CRD does not include the name of the worker or head pod for now + logOutput, err := logPlugin.GetTaskLogs(input) + if err != nil { + return nil, fmt.Errorf("failed to generate task logs. Error: %w", err) + } + taskLogs = append(taskLogs, logOutput.TaskLogs...) + + // Handling for Ray Dashboard + dashboardURLTemplate := GetConfig().DashboardURLTemplate + if dashboardURLTemplate != nil && + rayJob.Status.DashboardURL != "" && + rayJob.Status.JobStatus == rayv1.JobStatusRunning { + dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input) + if err != nil { + return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err) + } + taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...) 
+ } + + return &pluginsCore.TaskInfo{Logs: taskLogs}, nil +} + func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { + crdVersion := GetConfig().KubeRayCrdVersion + if crdVersion == "v1" { + return plugin.GetTaskPhaseV1(ctx, pluginContext, resource) + } + + return plugin.GetTaskPhaseV1Alpha1(ctx, pluginContext, resource) +} + +func (plugin rayJobResourceHandler) GetTaskPhaseV1(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { + rayJob := resource.(*rayv1.RayJob) + info, err := getEventInfoForRayJobV1(GetConfig().Logs, pluginContext, rayJob) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + if len(rayJob.Status.JobDeploymentStatus) == 0 { + return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling"), nil + } + + // KubeRay creates a Ray cluster first, and then submits a Ray job to the cluster + switch rayJob.Status.JobDeploymentStatus { + case rayv1.JobDeploymentStatusInitializing: + return pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil + case rayv1.JobDeploymentStatusFailedToGetOrCreateRayCluster: + reason := fmt.Sprintf("Failed to create Ray cluster %s with error: %s", rayJob.Name, rayJob.Status.Message) + return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil + case rayv1.JobDeploymentStatusFailedJobDeploy: + reason := fmt.Sprintf("Failed to submit Ray job %s with error: %s", rayJob.Name, rayJob.Status.Message) + return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil + // JobDeploymentStatusSuspended is used when the suspend flag is set in rayJob. The suspend flag allows the temporary suspension of a Job's execution, which can be resumed later. 
+ // Certain versions of KubeRay use a K8s job to submit a Ray job to the Ray cluster. JobDeploymentStatusWaitForK8sJob indicates that the K8s job is under creation. + case rayv1.JobDeploymentStatusWaitForDashboard, rayv1.JobDeploymentStatusFailedToGetJobStatus, rayv1.JobDeploymentStatusWaitForDashboardReady, rayv1.JobDeploymentStatusWaitForK8sJob, rayv1.JobDeploymentStatusSuspended: + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil + case rayv1.JobDeploymentStatusRunning, rayv1.JobDeploymentStatusComplete: + switch rayJob.Status.JobStatus { + case rayv1.JobStatusFailed: + reason := fmt.Sprintf("Failed to run Ray job %s with error: %s", rayJob.Name, rayJob.Status.Message) + return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil + case rayv1.JobStatusSucceeded: + return pluginsCore.PhaseInfoSuccess(info), nil + // JobStatusStopped can occur when the suspend flag is set in rayJob. + case rayv1.JobStatusPending, rayv1.JobStatusStopped: + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil + case rayv1.JobStatusRunning: + phaseInfo := pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info) + if len(info.Logs) > 0 { + phaseInfo = phaseInfo.WithVersion(pluginsCore.DefaultPhaseVersion + 1) + } + return phaseInfo, nil + default: + // We already handle all known job status, so this should never happen unless a future version of ray + // introduced a new job status. + return pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job status: %s", rayJob.Status.JobStatus) + } + default: + // We already handle all known deployment status, so this should never happen unless a future version of ray + // introduced a new job deployment status.
+ return pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus) + } +} + +func (plugin rayJobResourceHandler) GetTaskPhaseV1Alpha1(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { rayJob := resource.(*rayv1alpha1.RayJob) info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index f4d802ff9ab..02ed83db142 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -6,6 +6,7 @@ import ( "time" structpb "github.com/golang/protobuf/ptypes/struct" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -706,6 +707,50 @@ func TestGetTaskPhase(t *testing.T) { } } +func TestGetTaskPhase_V1(t *testing.T) { + ctx := context.Background() + rayJobResourceHandler := rayJobResourceHandler{} + pluginCtx := newPluginContext() + + testCases := []struct { + rayJobPhase rayv1.JobStatus + rayClusterPhase rayv1.JobDeploymentStatus + expectedCorePhase pluginsCore.Phase + expectedError bool + }{ + {"", rayv1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing, false}, + {rayv1.JobStatusPending, rayv1.JobDeploymentStatusFailedToGetOrCreateRayCluster, pluginsCore.PhasePermanentFailure, false}, + {rayv1.JobStatusPending, rayv1.JobDeploymentStatusWaitForDashboard, pluginsCore.PhaseRunning, false}, + {rayv1.JobStatusPending, rayv1.JobDeploymentStatusWaitForDashboardReady, pluginsCore.PhaseRunning, false}, + {rayv1.JobStatusPending, rayv1.JobDeploymentStatusWaitForK8sJob, pluginsCore.PhaseRunning, false}, + {rayv1.JobStatusPending, rayv1.JobDeploymentStatusFailedJobDeploy, 
pluginsCore.PhasePermanentFailure, false}, + {rayv1.JobStatusPending, rayv1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, + {rayv1.JobStatusPending, rayv1.JobDeploymentStatusFailedToGetJobStatus, pluginsCore.PhaseRunning, false}, + {rayv1.JobStatusRunning, rayv1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, + {rayv1.JobStatusFailed, rayv1.JobDeploymentStatusRunning, pluginsCore.PhasePermanentFailure, false}, + {rayv1.JobStatusSucceeded, rayv1.JobDeploymentStatusRunning, pluginsCore.PhaseSuccess, false}, + {rayv1.JobStatusSucceeded, rayv1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess, false}, + {rayv1.JobStatusStopped, rayv1.JobDeploymentStatusSuspended, pluginsCore.PhaseRunning, false}, + } + + for _, tc := range testCases { + t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) { + rayObject := &rayv1.RayJob{} + rayObject.Status.JobStatus = tc.rayJobPhase + rayObject.Status.JobDeploymentStatus = tc.rayClusterPhase + startTime := metav1.NewTime(time.Now()) + rayObject.Status.StartTime = &startTime + phaseInfo, err := rayJobResourceHandler.GetTaskPhaseV1(ctx, pluginCtx, rayObject) + if tc.expectedError { + assert.Error(t, err) + } else { + assert.Nil(t, err) + } + assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String()) + }) + } +} + func TestGetEventInfo_LogTemplates(t *testing.T) { pluginCtx := newPluginContext() testCases := []struct { @@ -805,6 +850,105 @@ func TestGetEventInfo_LogTemplates(t *testing.T) { } } +func TestGetEventInfo_LogTemplates_V1(t *testing.T) { + pluginCtx := newPluginContext() + testCases := []struct { + name string + rayJob rayv1.RayJob + logPlugin tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "namespace", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "namespace", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace 
}}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "namespace", + Uri: "http://test/test-namespace", + }, + }, + }, + { + name: "task execution ID", + rayJob: rayv1.RayJob{}, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "taskExecID", + TemplateURIs: []tasklog.TemplateURI{ + "http://test/projects/{{ .executionProject }}/domains/{{ .executionDomain }}/executions/{{ .executionName }}/nodeId/{{ .nodeID }}/taskId/{{ .taskID }}/attempt/{{ .taskRetryAttempt }}", + }, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "taskExecID", + Uri: "http://test/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/unique-node/taskId/my-task-name/attempt/1", + }, + }, + }, + { + name: "ray cluster name", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + RayClusterName: "ray-cluster", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray cluster name", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayClusterName }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray cluster name", + Uri: "http://test/test-namespace/ray-cluster", + }, + }, + }, + { + name: "ray job ID", + rayJob: rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + Status: rayv1.RayJobStatus{ + JobId: "ray-job-1", + }, + }, + logPlugin: tasklog.TemplateLogPlugin{ + DisplayName: "ray job ID", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{ .namespace }}/{{ .rayJobID }}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "ray job ID", + Uri: "http://test/test-namespace/ray-job-1", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ti, err := getEventInfoForRayJobV1( + logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, + pluginCtx, + &tc.rayJob, + ) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } +} 
+ func TestGetEventInfo_DashboardURL(t *testing.T) { pluginCtx := newPluginContext() testCases := []struct { @@ -857,6 +1001,58 @@ func TestGetEventInfo_DashboardURL(t *testing.T) { } } +func TestGetEventInfo_DashboardURL_V1(t *testing.T) { + pluginCtx := newPluginContext() + testCases := []struct { + name string + rayJob rayv1.RayJob + dashboardURLTemplate tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "dashboard URL displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + DashboardURL: "exists", + JobStatus: rayv1.JobStatusRunning, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "Ray Dashboard", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "Ray Dashboard", + Uri: "http://test/generated-name", + }, + }, + }, + { + name: "dashboard URL is not displayed", + rayJob: rayv1.RayJob{ + Status: rayv1.RayJobStatus{ + JobStatus: rayv1.JobStatusPending, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "dummy", + TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, + }, + expectedTaskLogs: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) + ti, err := getEventInfoForRayJobV1(logs.LogConfig{}, pluginCtx, &tc.rayJob) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } +} + func TestGetPropertiesRay(t *testing.T) { rayJobResourceHandler := rayJobResourceHandler{} expected := k8s.PluginProperties{}