diff --git a/pkg/api/handlers/modelprovider.go b/pkg/api/handlers/modelprovider.go index 230223ff..a9864c88 100644 --- a/pkg/api/handlers/modelprovider.go +++ b/pkg/api/handlers/modelprovider.go @@ -1,14 +1,20 @@ package handlers import ( + "encoding/json" "fmt" "strings" "github.com/gptscript-ai/go-gptscript" "github.com/obot-platform/obot/apiclient/types" + "github.com/obot-platform/obot/logger" "github.com/obot-platform/obot/pkg/api" "github.com/obot-platform/obot/pkg/gateway/server/dispatcher" + "github.com/obot-platform/obot/pkg/invoke" v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1" + "github.com/obot-platform/obot/pkg/system" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" kclient "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -16,12 +22,14 @@ import ( type ModelProviderHandler struct { gptscript *gptscript.GPTScript dispatcher *dispatcher.Dispatcher + invoker *invoke.Invoker } -func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler { +func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher, invoker *invoke.Invoker) *ModelProviderHandler { return &ModelProviderHandler{ gptscript: gClient, dispatcher: dispatcher, + invoker: invoker, } } @@ -89,6 +97,78 @@ func (mp *ModelProviderHandler) List(req api.Context) error { return req.Write(types.ModelProviderList{Items: resp}) } +type ValidationError struct { + Err string `json:"error"` +} + +func (ve *ValidationError) Error() string { + return fmt.Sprintf("model-provider credentials validation failed: {\"error\": \"%s\"}", ve.Err) +} + +func (mp *ModelProviderHandler) Validate(req api.Context) error { + var ref v1.ToolReference + if err := req.Get(&ref, req.PathValue("id")); err != nil { + return err + } + + if ref.Spec.Type != types.ToolReferenceTypeModelProvider { + return types.NewErrBadRequest("%q is not a model provider", ref.Name) + } + + l := logger.Package() + l.Debugf("Validating model provider %q", ref.Name) + + var envVars map[string]string + if err := req.Read(&envVars); err != nil { + return err + } + + envs := make([]string, 0, len(envVars)) + for key, val := range envVars { + envs = append(envs, key+"="+val) + } + + thread := &v1.Thread{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: system.ThreadPrefix + "-" + ref.Name + "-validate-", + Namespace: ref.Namespace, + }, + Spec: v1.ThreadSpec{ + SystemTask: true, + }, + } + + if err := req.Get(thread, thread.Name); apierrors.IsNotFound(err) { + if err = req.Create(thread); err != nil { + return fmt.Errorf("failed to create thread: %w", err) + } + } + + defer func() { _ = req.Delete(thread) }() + + task, err := mp.invoker.SystemTask(req.Context(), thread, "validate from "+ref.Spec.Reference, "", invoke.SystemTaskOptions{Env: envs}) + if err != nil { + return err + } + defer task.Close() + + res, err := task.Result(req.Context()) + if err != nil { + if strings.Contains(err.Error(), "tool not found: validate from "+ref.Spec.Reference) { // there's no simple way to do errors.As/.Is at this point unfortunately + l.Infof("Model provider %q does not provide a validate tool. Looking for 'validate from %s'", ref.Name, ref.Spec.Reference) + return nil // Do not fail if model provider doesn't provide a validate tool + } + return &ValidationError{Err: strings.Trim(err.Error(), "\"'")} + } + + var validationError ValidationError + if json.Unmarshal([]byte(res.Output), &validationError) == nil && validationError.Err != "" { + return &validationError + } + + return nil +} + func (mp *ModelProviderHandler) Configure(req api.Context) error { var ref v1.ToolReference if err := req.Get(&ref, req.PathValue("id")); err != nil { diff --git a/pkg/api/router/router.go b/pkg/api/router/router.go index f0f639fa..b9b22e20 100644 --- a/pkg/api/router/router.go +++ b/pkg/api/router/router.go @@ -24,7 +24,7 @@ func Router(services *services.Services) (http.Handler, error) { cronJobs := handlers.NewCronJobHandler() models := handlers.NewModelHandler() availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher) - modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher) + modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher, services.Invoker) prompt := handlers.NewPromptHandler(services.GPTClient) emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName) defaultModelAliases := handlers.NewDefaultModelAliasHandler() @@ -276,6 +276,7 @@ func Router(services *services.Services) (http.Handler, error) { // Model providers mux.HandleFunc("GET /api/model-providers", modelProviders.List) mux.HandleFunc("GET /api/model-providers/{id}", modelProviders.ByID) + mux.HandleFunc("POST /api/model-providers/{id}/validate", modelProviders.Validate) mux.HandleFunc("POST /api/model-providers/{id}/configure", modelProviders.Configure) mux.HandleFunc("POST /api/model-providers/{id}/deconfigure", modelProviders.Deconfigure) mux.HandleFunc("POST /api/model-providers/{id}/reveal", modelProviders.Reveal) diff --git a/pkg/availablemodels/availablemodels.go b/pkg/availablemodels/availablemodels.go index 40ae8dfc..6fec43e6 100644 --- a/pkg/availablemodels/availablemodels.go +++ b/pkg/availablemodels/availablemodels.go @@ -19,7 +19,7 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr r, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String()+"/v1/models", nil) if err != nil { - return nil, fmt.Errorf("failed to create request to model provider %s: %w", modelProviderName, err) + return nil, fmt.Errorf("failed to create request to model provider %q: %w", modelProviderName, err) } if token != "" { @@ -28,18 +28,18 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr resp, err := http.DefaultClient.Do(r) if err != nil { - return nil, fmt.Errorf("failed to make request to model provider %s: %w", modelProviderName, err) + return nil, fmt.Errorf("failed to make request to model provider %q: %w", modelProviderName, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { message, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("failed to get model list from model provider %s: %s", modelProviderName, message) + return nil, fmt.Errorf("failed to get model list from model provider %q: %s", modelProviderName, message) } var oModels openai.ModelsList if err = json.NewDecoder(resp.Body).Decode(&oModels); err != nil { - return nil, fmt.Errorf("failed to decode model list from model provider %s: %w", modelProviderName, err) + return nil, fmt.Errorf("failed to decode model list from model provider %q: %w", modelProviderName, err) } return &oModels, nil diff --git a/pkg/controller/handlers/toolreference/toolreference.go b/pkg/controller/handlers/toolreference/toolreference.go index b10338a3..3dd77b99 100644 --- a/pkg/controller/handlers/toolreference/toolreference.go +++ b/pkg/controller/handlers/toolreference/toolreference.go @@ -3,11 +3,13 @@ package toolreference import ( "context" "crypto/sha256" + "encoding/json" "errors" "fmt" "net/url" "os" "path" + "regexp" "slices" "strings" "time" @@ -450,7 +452,23 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro availableModels, err := availablemodels.ForProvider(req.Ctx, h.dispatcher, req.Namespace, req.Name) if err != nil { // Don't error and retry because it will likely fail again. Log the error, and the user can re-sync manually. - log.Errorf("Failed to get available models for model provider %q: %v", toolRef.Name, err) + // Also, the toolRef.Status.Error field will bubble up to the user in the UI. + + // Check if the model provider returned a properly formatted error message and set it as status + re := regexp.MustCompile(`\{.*"error":.*}`) + match := re.FindString(err.Error()) + if match != "" { + toolRef.Status.Error = match + type errorResponse struct { + Error string `json:"error"` + } + var eR errorResponse + if err := json.Unmarshal([]byte(match), &eR); err == nil { + toolRef.Status.Error = eR.Error + } + } + + log.Errorf("%v", err) return nil } diff --git a/ui/admin/app/components/model-providers/ModelProviderForm.tsx b/ui/admin/app/components/model-providers/ModelProviderForm.tsx index 9a8ca1f5..2bd5859b 100644 --- a/ui/admin/app/components/model-providers/ModelProviderForm.tsx +++ b/ui/admin/app/components/model-providers/ModelProviderForm.tsx @@ -124,6 +124,24 @@ export function ModelProviderForm({ } ); + const validateAndConfigureModelProvider = useAsync( + ModelProviderApiService.validateModelProviderById, + { + onSuccess: async (data, params) => { + // Only configure the model provider if validation was successful + const [modelProviderId, configParams] = params; + await configureModelProvider.execute( + modelProviderId, + configParams + ); + }, + onError: (error) => { + // Handle validation errors + console.error("Validation failed:", error); + }, + } + ); + const configureModelProvider = useAsync( ModelProviderApiService.configureModelProviderById, { @@ -185,7 +203,7 @@ export function ModelProviderForm({ } ); - await configureModelProvider.execute( + await validateAndConfigureModelProvider.execute( modelProvider.id, allConfigParams ); @@ -197,24 +215,57 @@ export function ModelProviderForm({ modelProvider.id === "azure-openai-model-provider"; const loading = + validateAndConfigureModelProvider.isLoading || fetchAvailableModels.isLoading || configureModelProvider.isLoading || isLoading; + return (