Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Min Required Tuning Memory #440

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add oom default
ishaansehgal99 committed May 29, 2024
commit 8cc73a8b496a621c48da56455b39e806504ad02c
41 changes: 36 additions & 5 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/klog/v2"
"knative.dev/pkg/apis"
"path/filepath"
"reflect"
@@ -108,7 +109,7 @@ func UnmarshalTrainingConfig(cm *corev1.ConfigMap) (*Config, *apis.FieldError) {
return &config, nil
}

func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap) *apis.FieldError {
func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap, modelName, methodLowerCase, sku string) *apis.FieldError {
config, err := UnmarshalTrainingConfig(cm)
if err != nil {
return err
@@ -137,13 +138,42 @@ func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap) *apis.FieldError {
}
}

// TODO: Here we perform the tuning GPU Memory Checks!
fmt.Println(trainingArgsRaw)
// Validate GPU Memory Requirements using batch size, tuning method, model etc.
errs := validateTuningParameters(modelName, methodLowerCase, sku)
if errs != nil {
return errs
}
}
}
return nil
}

// validateTuningParameters compares the GPU memory of the requested SKU with
// the minimum recorded in modelTuningConfigs for the given model and tuning
// method.
//
// It returns an invalid-value error on the "sku" field when sku is not a key of
// SupportedGPUConfigs. In every other case it returns nil: a model/method pair
// with no recorded minimum is accepted without comment, and a SKU whose GPU
// memory is below the recorded minimum only produces a warning log.
func validateTuningParameters(modelName, methodLowerCase, sku string) *apis.FieldError {
	gpuConfig, supported := SupportedGPUConfigs[sku]
	if !supported {
		return apis.ErrInvalidValue(fmt.Sprintf("Unsupported SKU: '%s'", sku), "sku")
	}
	availableMem := gpuConfig.GPUMem

	// Indexing a missing model yields a nil inner map, and a lookup on a nil map
	// reports "not found", so one chained lookup covers both the untested-model
	// and the untested-method case. Neither case is logged.
	requiredMem, tested := modelTuningConfigs[modelName][methodLowerCase]
	if !tested {
		return nil
	}

	// Too little memory is advisory only: warn, but do not reject the request.
	if availableMem < requiredMem {
		klog.Warningf("Insufficient GPU memory: For model '%s' with tuning method '%s', the SKU '%s' with %dGi GPU memory does not support even a batch size of 1 in testing. Proceed at your own risk.", modelName, methodLowerCase, sku, availableMem)
	}
	return nil
}

func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
config, err := UnmarshalTrainingConfig(cm)
if err != nil {
@@ -250,7 +280,7 @@ func validateConfigMapSchema(cm *corev1.ConfigMap) *apis.FieldError {
return nil
}

func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, methodLowerCase string, configMapName string) (errs *apis.FieldError) {
func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace, methodLowerCase, sku string, configMapName string) (errs *apis.FieldError) {
var cm corev1.ConfigMap
if k8sclient.Client == nil {
errs = errs.Also(apis.ErrGeneric("Failed to obtain client from context.Context"))
@@ -270,7 +300,8 @@ func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, me
if err := validateMethodViaConfigMap(&cm, methodLowerCase); err != nil {
errs = errs.Also(err)
}
if err := validateTrainingArgsViaConfigMap(&cm); err != nil {

if err := validateTrainingArgsViaConfigMap(&cm, string(r.Preset.Name), methodLowerCase, sku); err != nil {
errs = errs.Also(err)
}
}
11 changes: 11 additions & 0 deletions api/v1alpha1/tuning_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package v1alpha1

// modelTuningConfigs records, per model and tuning method, the minimum per-GPU
// memory required to tune with a batch size of 1.
// Layout: ModelName -> TuningMethod -> MinGPUMemory.
// NOTE(review): the unit is presumably Gi, since the value is compared against
// the SKU's GPUMem, which the warning message prints as "Gi" — confirm.
var modelTuningConfigs = map[string]map[string]int{
"falcon-7b": {
// The LoRA entry below is commented out, so no minimum is recorded for LoRA.
//string(TuningMethodLora): 24,
string(TuningMethodQLora): 16,
},
// Add more configurations as needed
}
8 changes: 4 additions & 4 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
}
if w.Tuning != nil {
// TODO: Add validate resource based on Tuning Spec
errs = errs.Also(w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning"))
errs = errs.Also(w.Tuning.validateCreate(ctx, w.Namespace, w.Resource.InstanceType).ViaField("tuning"))
}
} else {
klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
@@ -89,7 +89,7 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) {
return errs
}

func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string) (errs *apis.FieldError) {
func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace, sku string) (errs *apis.FieldError) {
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
@@ -106,11 +106,11 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
} else if methodLowerCase == string(TuningMethodQLora) {
defaultConfigMapTemplateName = DefaultQloraConfigMapTemplate
}
if err := r.validateConfigMap(ctx, releaseNamespace, methodLowerCase, defaultConfigMapTemplateName); err != nil {
if err := r.validateConfigMap(ctx, releaseNamespace, methodLowerCase, sku, defaultConfigMapTemplateName); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
} else {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, r.ConfigTemplate); err != nil {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, sku, r.ConfigTemplate); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
}