Skip to content

Commit

Permalink
Add "ReadExt" method and update KFTO tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ChughShilpa authored and openshift-merge-bot[bot] committed Nov 21, 2024
1 parent 8803090 commit 139f8d2
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 56 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ toolchain go1.21.5
require (
github.com/kubeflow/training-operator v1.7.0
github.com/onsi/gomega v1.31.1
github.com/project-codeflare/codeflare-common v0.0.0-20241108084652-0d76fd215a22
github.com/project-codeflare/codeflare-common v0.0.0-20241121090634-e99e941c6921
github.com/prometheus/client_golang v1.20.4
github.com/prometheus/common v0.57.0
github.com/ray-project/kuberay/ray-operator v1.1.0-alpha.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/project-codeflare/appwrapper v0.8.0 h1:vWHNtXUtHutN2EzYb6rryLdESnb8iDXsCokXOuNYXvg=
github.com/project-codeflare/appwrapper v0.8.0/go.mod h1:FMQ2lI3fz6LakUVXgN1FTdpsc3BBkNIZZgtMmM9J5UM=
github.com/project-codeflare/codeflare-common v0.0.0-20241108084652-0d76fd215a22 h1:wzIJHoGAmNZupO3ZI7gbONuXgIUireabHsZvMt+3fqQ=
github.com/project-codeflare/codeflare-common v0.0.0-20241108084652-0d76fd215a22/go.mod h1:v7XKwaDoCspsHQlWJNarO7gOpR+iumSS+c1bWs3kJOI=
github.com/project-codeflare/codeflare-common v0.0.0-20241121090634-e99e941c6921 h1:OI9jKDW4yxbXDTpf4Y+8H4uVfdCH+jIqN0JTQfdUMYw=
github.com/project-codeflare/codeflare-common v0.0.0-20241121090634-e99e941c6921/go.mod h1:v7XKwaDoCspsHQlWJNarO7gOpR+iumSS+c1bWs3kJOI=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
Expand Down
28 changes: 0 additions & 28 deletions tests/kfto/core/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,12 @@ const (
minioCliImageEnvVar = "MINIO_CLI_IMAGE"
// The environment variable for HuggingFace token to download models which require authentication
huggingfaceTokenEnvVar = "HF_TOKEN"
// The environment variable specifying existing namespace name to be used for tests
testNamespaceEnvVar = "TEST_NAMESPACE_NAME"
// The environment variable specifying name of PersistenceVolumeClaim containing GPTQ models
gptqModelPvcNameEnvVar = "GPTQ_MODEL_PVC_NAME"
// The environment variable referring to image simulating sleep condition in container
sleepImageEnvVar = "SLEEP_IMAGE"
// The environment variable specifying s3 bucket folder path used to store model
storageBucketModelPath = "AWS_STORAGE_BUCKET_MODEL_PATH"
// The environment variable for the CUDA training image
cudaTrainingImageEnvVar = "CUDA_TRAINING_IMAGE"
// The environment variable for the ROCm training image
rocmTrainingImageEnvVar = "ROCM_TRAINING_IMAGE"
)

func GetFmsHfTuningImage(t Test) string {
Expand All @@ -57,24 +51,6 @@ func GetFmsHfTuningImage(t Test) string {
return image
}

func GetCudaTrainingImage(t Test) string {
t.T().Helper()
image, ok := os.LookupEnv(cudaTrainingImageEnvVar)
if !ok {
t.T().Fatalf("Expected environment variable %s not found, please use this environment variable to specify the cuda training image to be tested.", cudaTrainingImageEnvVar)
}
return image
}

func GetROCmTrainingImage(t Test) string {
t.T().Helper()
image, ok := os.LookupEnv(rocmTrainingImageEnvVar)
if !ok {
t.T().Fatalf("Expected environment variable %s not found, please use this environment variable to specify the cuda training image to be tested.", rocmTrainingImageEnvVar)
}
return image
}

func GetBloomModelImage() string {
return lookupEnvOrDefault(bloomModelImageEnvVar, "quay.io/ksuta/bloom-560m@sha256:f6db02bb7b5d09a8d698c04994d747bfb9e581bbb4c07d00290244d207623733")
}
Expand All @@ -96,10 +72,6 @@ func GetHuggingFaceToken(t Test) string {
return image
}

func GetTestNamespaceName() (namespaceName string, exists bool) {
return os.LookupEnv(testNamespaceEnvVar)
}

func GetGptqModelPvcName() (string, error) {
image, ok := os.LookupEnv(gptqModelPvcNameEnvVar)
if !ok {
Expand Down
13 changes: 5 additions & 8 deletions tests/kfto/core/kfto_pytorchjob_failed_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package core

import (
"testing"

. "github.com/onsi/gomega"
. "github.com/project-codeflare/codeflare-common/support"
"testing"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
Expand All @@ -13,15 +14,11 @@ import (
)

func TestPyTorchJobFailureWithCuda(t *testing.T) {
test := With(t)
cudaBaseImage := GetCudaTrainingImage(test)
runFailedPyTorchJobTest(t, cudaBaseImage)
runFailedPyTorchJobTest(t, GetCudaTrainingImage())
}

func TestPyTorchJobFailureWithROCm(t *testing.T) {
test := With(t)
rocmBaseImage := GetROCmTrainingImage(test)
runFailedPyTorchJobTest(t, rocmBaseImage)
runFailedPyTorchJobTest(t, GetROCmTrainingImage())
}

func runFailedPyTorchJobTest(t *testing.T, image string) {
Expand Down Expand Up @@ -65,7 +62,7 @@ func createFailedPyTorchJob(test Test, namespace string, config corev1.ConfigMap
{
Name: "pytorch",
Image: baseImage,
Command: []string{"python", "-c", "raise Exception('Test failure')"},
Command: []string{"python", "-c", "raise Exception('Test failure')"},
ImagePullPolicy: corev1.PullIfNotPresent,
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
Expand Down
20 changes: 3 additions & 17 deletions tests/kfto/core/kfto_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package core

import (
"fmt"
"os"
"testing"

. "github.com/onsi/gomega"
Expand All @@ -32,17 +31,11 @@ import (
)

func TestPyTorchJobWithCuda(t *testing.T) {
test := With(t)
cudaBaseImage := GetCudaTrainingImage(test)
gpuLabel := "nvidia.com/gpu"
runKFTOPyTorchJob(t, cudaBaseImage, gpuLabel, 1)
runKFTOPyTorchJob(t, GetCudaTrainingImage(), "nvidia.com/gpu", 1)
}

func TestPyTorchJobWithROCm(t *testing.T) {
test := With(t)
rocmBaseImage := GetROCmTrainingImage(test)
gpuLabel := "amd.com/gpu"
runKFTOPyTorchJob(t, rocmBaseImage, gpuLabel, 1)
runKFTOPyTorchJob(t, GetROCmTrainingImage(), "amd.com/gpu", 1)
}

func runKFTOPyTorchJob(t *testing.T, image string, gpuLabel string, numGpus int) {
Expand All @@ -51,16 +44,9 @@ func runKFTOPyTorchJob(t *testing.T, image string, gpuLabel string, numGpus int)
// Create a namespace
namespace := GetOrCreateTestNamespace(test)

// Parse training script
trainingScriptPath := "hf_llm_training.py"
trainingScript, err := os.ReadFile(trainingScriptPath)
if err != nil {
test.T().Fatalf("Error reading training script file: %v", err)
}

// Create a ConfigMap with training script
configData := map[string][]byte{
"hf_llm_training.py": trainingScript,
"hf_llm_training.py": ReadFileExt(test, "hf_llm_training.py"),
}
config := CreateConfigMap(test, namespace, configData)

Expand Down
9 changes: 9 additions & 0 deletions tests/kfto/core/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package core
import (
"embed"
"fmt"
"os"
"time"

"github.com/onsi/gomega"
. "github.com/onsi/gomega"
. "github.com/project-codeflare/codeflare-common/support"

Expand All @@ -41,6 +43,13 @@ func ReadFile(t Test, fileName string) []byte {
return file
}

func ReadFileExt(t Test, fileName string) []byte {
t.T().Helper()
file, err := os.ReadFile(fileName)
t.Expect(err).NotTo(gomega.HaveOccurred())
return file
}

func PyTorchJob(t Test, namespace, name string) func(g Gomega) *kftov1.PyTorchJob {
return func(g Gomega) *kftov1.PyTorchJob {
job, err := t.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Get(t.Ctx(), name, metav1.GetOptions{})
Expand Down

0 comments on commit 139f8d2

Please sign in to comment.