From fb1cb19910cf6a5664d38f405f0b113738d9b6c8 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Mon, 23 Dec 2024 10:04:11 +0100 Subject: [PATCH] feat: call the validate subtool from a model-provider tool if available before saving the provider config --- pkg/api/handlers/modelprovider.go | 82 ++++++++++++++++++- pkg/api/router/router.go | 3 +- pkg/availablemodels/availablemodels.go | 8 +- .../handlers/toolreference/toolreference.go | 20 ++++- .../model-providers/ModelProviderForm.tsx | 61 ++++++++++++-- ui/admin/app/hooks/useAsync.tsx | 14 ++++ ui/admin/app/lib/routers/apiRoutes.ts | 2 + .../service/api/modelProviderApiService.ts | 18 ++++ 8 files changed, 196 insertions(+), 12 deletions(-) diff --git a/pkg/api/handlers/modelprovider.go b/pkg/api/handlers/modelprovider.go index 230223ff..a9864c88 100644 --- a/pkg/api/handlers/modelprovider.go +++ b/pkg/api/handlers/modelprovider.go @@ -1,14 +1,20 @@ package handlers import ( + "encoding/json" "fmt" "strings" "github.com/gptscript-ai/go-gptscript" "github.com/obot-platform/obot/apiclient/types" + "github.com/obot-platform/obot/logger" "github.com/obot-platform/obot/pkg/api" "github.com/obot-platform/obot/pkg/gateway/server/dispatcher" + "github.com/obot-platform/obot/pkg/invoke" v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1" + "github.com/obot-platform/obot/pkg/system" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" kclient "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -16,12 +22,14 @@ import ( type ModelProviderHandler struct { gptscript *gptscript.GPTScript dispatcher *dispatcher.Dispatcher + invoker *invoke.Invoker } -func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler { +func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher, invoker *invoke.Invoker) *ModelProviderHandler { return &ModelProviderHandler{ gptscript: gClient, dispatcher: dispatcher, + invoker: invoker, } } @@ -89,6 +97,78 @@ func (mp *ModelProviderHandler) List(req api.Context) error { return req.Write(types.ModelProviderList{Items: resp}) } +type ValidationError struct { + Err string `json:"error"` +} + +func (ve *ValidationError) Error() string { + return fmt.Sprintf("model-provider credentials validation failed: {\"error\": \"%s\"}", ve.Err) +} + +func (mp *ModelProviderHandler) Validate(req api.Context) error { + var ref v1.ToolReference + if err := req.Get(&ref, req.PathValue("id")); err != nil { + return err + } + + if ref.Spec.Type != types.ToolReferenceTypeModelProvider { + return types.NewErrBadRequest("%q is not a model provider", ref.Name) + } + + l := logger.Package() + l.Debugf("Validating model provider %q", ref.Name) + + var envVars map[string]string + if err := req.Read(&envVars); err != nil { + return err + } + + envs := make([]string, 0, len(envVars)) + for key, val := range envVars { + envs = append(envs, key+"="+val) + } + + thread := &v1.Thread{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: system.ThreadPrefix + "-" + ref.Name + "-validate-", + Namespace: ref.Namespace, + }, + Spec: v1.ThreadSpec{ + SystemTask: true, + }, + } + + if err := req.Get(thread, thread.Name); apierrors.IsNotFound(err) { + if err = req.Create(thread); err != nil { + return fmt.Errorf("failed to create thread: %w", err) + } + } + + defer func() { _ = req.Delete(thread) }() + + task, err := mp.invoker.SystemTask(req.Context(), thread, "validate from "+ref.Spec.Reference, "", invoke.SystemTaskOptions{Env: envs}) + if err != nil { + return err + } + defer task.Close() + + res, err := task.Result(req.Context()) + if err != nil { + if strings.Contains(err.Error(), "tool not found: validate from "+ref.Spec.Reference) { // there's no simple way to do errors.As/.Is at this point unfortunately + l.Infof("Model provider %q does not provide a validate tool. Looking for 'validate from %s'", ref.Name, ref.Spec.Reference) + return nil // Do not fail if model provider doesn't provide a validate tool + } + return &ValidationError{Err: strings.Trim(err.Error(), "\"'")} + } + + var validationError ValidationError + if json.Unmarshal([]byte(res.Output), &validationError) == nil && validationError.Err != "" { + return &validationError + } + + return nil +} + func (mp *ModelProviderHandler) Configure(req api.Context) error { var ref v1.ToolReference if err := req.Get(&ref, req.PathValue("id")); err != nil { diff --git a/pkg/api/router/router.go b/pkg/api/router/router.go index f0f639fa..b9b22e20 100644 --- a/pkg/api/router/router.go +++ b/pkg/api/router/router.go @@ -24,7 +24,7 @@ func Router(services *services.Services) (http.Handler, error) { cronJobs := handlers.NewCronJobHandler() models := handlers.NewModelHandler() availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher) - modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher) + modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher, services.Invoker) prompt := handlers.NewPromptHandler(services.GPTClient) emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName) defaultModelAliases := handlers.NewDefaultModelAliasHandler() @@ -276,6 +276,7 @@ func Router(services *services.Services) (http.Handler, error) { // Model providers mux.HandleFunc("GET /api/model-providers", modelProviders.List) mux.HandleFunc("GET /api/model-providers/{id}", modelProviders.ByID) + mux.HandleFunc("POST /api/model-providers/{id}/validate", modelProviders.Validate) mux.HandleFunc("POST /api/model-providers/{id}/configure", modelProviders.Configure) mux.HandleFunc("POST /api/model-providers/{id}/deconfigure", modelProviders.Deconfigure) mux.HandleFunc("POST /api/model-providers/{id}/reveal", modelProviders.Reveal) diff --git a/pkg/availablemodels/availablemodels.go b/pkg/availablemodels/availablemodels.go index 40ae8dfc..6fec43e6 100644 --- a/pkg/availablemodels/availablemodels.go +++ b/pkg/availablemodels/availablemodels.go @@ -19,7 +19,7 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr r, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String()+"/v1/models", nil) if err != nil { - return nil, fmt.Errorf("failed to create request to model provider %s: %w", modelProviderName, err) + return nil, fmt.Errorf("failed to create request to model provider %q: %w", modelProviderName, err) } if token != "" { @@ -28,18 +28,18 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr resp, err := http.DefaultClient.Do(r) if err != nil { - return nil, fmt.Errorf("failed to make request to model provider %s: %w", modelProviderName, err) + return nil, fmt.Errorf("failed to make request to model provider %q: %w", modelProviderName, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { message, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("failed to get model list from model provider %s: %s", modelProviderName, message) + return nil, fmt.Errorf("failed to get model list from model provider %q: %s", modelProviderName, message) } var oModels openai.ModelsList if err = json.NewDecoder(resp.Body).Decode(&oModels); err != nil { - return nil, fmt.Errorf("failed to decode model list from model provider %s: %w", modelProviderName, err) + return nil, fmt.Errorf("failed to decode model list from model provider %q: %w", modelProviderName, err) } return &oModels, nil diff --git a/pkg/controller/handlers/toolreference/toolreference.go b/pkg/controller/handlers/toolreference/toolreference.go index b10338a3..3dd77b99 100644 --- a/pkg/controller/handlers/toolreference/toolreference.go +++ b/pkg/controller/handlers/toolreference/toolreference.go @@ -3,11 +3,13 @@ package toolreference import ( "context" "crypto/sha256" + "encoding/json" "errors" "fmt" "net/url" "os" "path" + "regexp" "slices" "strings" "time" @@ -450,7 +452,23 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro availableModels, err := availablemodels.ForProvider(req.Ctx, h.dispatcher, req.Namespace, req.Name) if err != nil { // Don't error and retry because it will likely fail again. Log the error, and the user can re-sync manually. - log.Errorf("Failed to get available models for model provider %q: %v", toolRef.Name, err) + // Also, the toolRef.Status.Error field will bubble up to the user in the UI. + + // Check if the model provider returned a properly formatted error message and set it as status + re := regexp.MustCompile(`\{.*"error":.*}`) + match := re.FindString(err.Error()) + if match != "" { + toolRef.Status.Error = match + type errorResponse struct { + Error string `json:"error"` + } + var eR errorResponse + if err := json.Unmarshal([]byte(match), &eR); err == nil { + toolRef.Status.Error = eR.Error + } + } + + log.Errorf("%v", err) return nil } diff --git a/ui/admin/app/components/model-providers/ModelProviderForm.tsx b/ui/admin/app/components/model-providers/ModelProviderForm.tsx index 9a8ca1f5..2bd5859b 100644 --- a/ui/admin/app/components/model-providers/ModelProviderForm.tsx +++ b/ui/admin/app/components/model-providers/ModelProviderForm.tsx @@ -124,6 +124,24 @@ export function ModelProviderForm({ } ); + const validateAndConfigureModelProvider = useAsync( + ModelProviderApiService.validateModelProviderById, + { + onSuccess: async (data, params) => { + // Only configure the model provider if validation was successful + const [modelProviderId, configParams] = params; + await configureModelProvider.execute( + modelProviderId, + configParams + ); + }, + onError: (error) => { + // Handle validation errors + console.error("Validation failed:", error); + }, + } + ); + const configureModelProvider = useAsync( ModelProviderApiService.configureModelProviderById, { @@ -185,7 +203,7 @@ export function ModelProviderForm({ } ); - await configureModelProvider.execute( + await validateAndConfigureModelProvider.execute( modelProvider.id, allConfigParams ); @@ -197,24 +215,57 @@ export function ModelProviderForm({ modelProvider.id === "azure-openai-model-provider"; const loading = + validateAndConfigureModelProvider.isLoading || fetchAvailableModels.isLoading || configureModelProvider.isLoading || isLoading; + return (
- {fetchAvailableModels.error !== null && ( + {validateAndConfigureModelProvider.error !== null && (
An error occurred! - Your configuration was saved, but we were not able - to connect to the model provider. Please check your - configuration and try again. + Your configuration could not be saved, because it + failed validation:{" "} + + {(typeof validateAndConfigureModelProvider.error === + "object" && + "message" in + validateAndConfigureModelProvider.error && + (validateAndConfigureModelProvider.error + .message as string)) ?? + "Unknown error"} +
)} + {validateAndConfigureModelProvider.error === null && + fetchAvailableModels.error !== null && ( +
+ + + An error occurred! + + Your configuration was saved, but we were not + able to connect to the model provider. Please + check your configuration and try again:{" "} + + {(typeof fetchAvailableModels.error === + "object" && + "message" in + fetchAvailableModels.error && + (fetchAvailableModels.error + .message as string)) ?? + "Unknown error"} + + + +
+ )}
diff --git a/ui/admin/app/hooks/useAsync.tsx b/ui/admin/app/hooks/useAsync.tsx index d0246107..6fa7a985 100644 --- a/ui/admin/app/hooks/useAsync.tsx +++ b/ui/admin/app/hooks/useAsync.tsx @@ -43,6 +43,20 @@ export function useAsync( onSuccess?.(data, params); }) .catch((error) => { + if ( + error.response && + typeof error.response.data === "string" + ) { + const errorMessageMatch = + error.response.data.match(/{"error":\s+"(.*?)"}/); + if (errorMessageMatch) { + const errorMessage = JSON.parse( + errorMessageMatch[0] + ).error; + console.log("Error: ", errorMessage); + error.message = errorMessage; + } + } setError(error); onError?.(error, params); }) diff --git a/ui/admin/app/lib/routers/apiRoutes.ts b/ui/admin/app/lib/routers/apiRoutes.ts index 95204b69..3306a48b 100644 --- a/ui/admin/app/lib/routers/apiRoutes.ts +++ b/ui/admin/app/lib/routers/apiRoutes.ts @@ -248,6 +248,8 @@ export const ApiRoutes = { getModelProviders: () => buildUrl("/model-providers"), getModelProviderById: (modelProviderKey: string) => buildUrl(`/model-providers/${modelProviderKey}`), + validateModelProviderById: (modelProviderKey: string) => + buildUrl(`/model-providers/${modelProviderKey}/validate`), configureModelProviderById: (modelProviderKey: string) => buildUrl(`/model-providers/${modelProviderKey}/configure`), revealModelProviderById: (modelProviderKey: string) => diff --git a/ui/admin/app/lib/service/api/modelProviderApiService.ts b/ui/admin/app/lib/service/api/modelProviderApiService.ts index 7b0fab58..ad8a1ef1 100644 --- a/ui/admin/app/lib/service/api/modelProviderApiService.ts +++ b/ui/admin/app/lib/service/api/modelProviderApiService.ts @@ -34,6 +34,23 @@ getModelProviderById.key = (modelProviderId?: string) => { }; }; +const validateModelProviderById = async ( + modelProviderKey: string, + modelProviderConfig: ModelProviderConfig +) => { + const res = await request({ + url: ApiRoutes.modelProviders.validateModelProviderById( + modelProviderKey + ).url, + method: "POST", + data: modelProviderConfig, + errorMessage: + "Failed to validate configuration values on the requested modal provider.", + }); + + return res.data; +}; + const configureModelProviderById = async ( modelProviderKey: string, modelProviderConfig: ModelProviderConfig @@ -87,6 +104,7 @@ const deconfigureModelProviderById = async (modelProviderKey: string) => { export const ModelProviderApiService = { getModelProviders, getModelProviderById, + validateModelProviderById, configureModelProviderById, revealModelProviderById, deconfigureModelProviderById,