Skip to content

Commit

Permalink
feat: call the validate subtool from a model-provider tool if available
Browse files Browse the repository at this point in the history
before saving the provider config
  • Loading branch information
iwilltry42 committed Jan 3, 2025
1 parent d3363e6 commit e4bbfe1
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 12 deletions.
82 changes: 81 additions & 1 deletion pkg/api/handlers/modelprovider.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
package handlers

import (
"encoding/json"
"fmt"
"strings"

"github.com/gptscript-ai/go-gptscript"
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/logger"
"github.com/obot-platform/obot/pkg/api"
"github.com/obot-platform/obot/pkg/gateway/server/dispatcher"
"github.com/obot-platform/obot/pkg/invoke"
v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1"
"github.com/obot-platform/obot/pkg/system"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)

type ModelProviderHandler struct {
gptscript *gptscript.GPTScript
dispatcher *dispatcher.Dispatcher
invoker *invoke.Invoker
}

func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler {
func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher, invoker *invoke.Invoker) *ModelProviderHandler {
return &ModelProviderHandler{
gptscript: gClient,
dispatcher: dispatcher,
invoker: invoker,
}
}

Expand Down Expand Up @@ -89,6 +97,78 @@ func (mp *ModelProviderHandler) List(req api.Context) error {
return req.Write(types.ModelProviderList{Items: resp})
}

type ValidationError struct {
Err string `json:"error"`
}

func (ve *ValidationError) Error() string {
return fmt.Sprintf("model-provider credentials validation failed: {\"error\": \"%s\"}", ve.Err)
}

func (mp *ModelProviderHandler) Validate(req api.Context) error {
var ref v1.ToolReference
if err := req.Get(&ref, req.PathValue("id")); err != nil {
return err
}

if ref.Spec.Type != types.ToolReferenceTypeModelProvider {
return types.NewErrBadRequest("%q is not a model provider", ref.Name)
}

l := logger.Package()
l.Debugf("Validating model provider %q", ref.Name)

var envVars map[string]string
if err := req.Read(&envVars); err != nil {
return err
}

envs := make([]string, 0, len(envVars))
for key, val := range envVars {
envs = append(envs, key+"="+val)
}

thread := &v1.Thread{
ObjectMeta: metav1.ObjectMeta{
GenerateName: system.ThreadPrefix + "-" + ref.Name + "-validate-",
Namespace: ref.Namespace,
},
Spec: v1.ThreadSpec{
SystemTask: true,
},
}

if err := req.Get(thread, thread.Name); apierrors.IsNotFound(err) {
if err = req.Create(thread); err != nil {
return fmt.Errorf("failed to create thread: %w", err)
}
}

defer func() { _ = req.Delete(thread) }()

task, err := mp.invoker.SystemTask(req.Context(), thread, "validate from "+ref.Spec.Reference, "", invoke.SystemTaskOptions{Env: envs})
if err != nil {
return err
}
defer task.Close()

res, err := task.Result(req.Context())
if err != nil {
if strings.Contains(err.Error(), "tool not found: validate from "+ref.Spec.Reference) { // there's no simple way to do errors.As/.Is at this point unfortunately
l.Infof("Model provider %q does not provide a validate tool. Looking for 'validate from %s'", ref.Name, ref.Spec.Reference)
return nil // Do not fail if model provider doesn't provide a validate tool
}
return &ValidationError{Err: strings.Trim(err.Error(), "\"'")}
}

var validationError ValidationError
if json.Unmarshal([]byte(res.Output), &validationError) == nil && validationError.Err != "" {
return &validationError
}

return nil
}

func (mp *ModelProviderHandler) Configure(req api.Context) error {
var ref v1.ToolReference
if err := req.Get(&ref, req.PathValue("id")); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func Router(services *services.Services) (http.Handler, error) {
cronJobs := handlers.NewCronJobHandler()
models := handlers.NewModelHandler()
availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher, services.Invoker)
prompt := handlers.NewPromptHandler(services.GPTClient)
emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName)
defaultModelAliases := handlers.NewDefaultModelAliasHandler()
Expand Down Expand Up @@ -276,6 +276,7 @@ func Router(services *services.Services) (http.Handler, error) {
// Model providers
mux.HandleFunc("GET /api/model-providers", modelProviders.List)
mux.HandleFunc("GET /api/model-providers/{id}", modelProviders.ByID)
mux.HandleFunc("POST /api/model-providers/{id}/validate", modelProviders.Validate)
mux.HandleFunc("POST /api/model-providers/{id}/configure", modelProviders.Configure)
mux.HandleFunc("POST /api/model-providers/{id}/deconfigure", modelProviders.Deconfigure)
mux.HandleFunc("POST /api/model-providers/{id}/reveal", modelProviders.Reveal)
Expand Down
8 changes: 4 additions & 4 deletions pkg/availablemodels/availablemodels.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr

r, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String()+"/v1/models", nil)
if err != nil {
return nil, fmt.Errorf("failed to create request to model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to create request to model provider %q: %w", modelProviderName, err)
}

if token != "" {
Expand All @@ -28,18 +28,18 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr

resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("failed to make request to model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to make request to model provider %q: %w", modelProviderName, err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
message, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to get model list from model provider %s: %s", modelProviderName, message)
return nil, fmt.Errorf("failed to get model list from model provider %q: %s", modelProviderName, message)
}

var oModels openai.ModelsList
if err = json.NewDecoder(resp.Body).Decode(&oModels); err != nil {
return nil, fmt.Errorf("failed to decode model list from model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to decode model list from model provider %q: %w", modelProviderName, err)
}

return &oModels, nil
Expand Down
20 changes: 19 additions & 1 deletion pkg/controller/handlers/toolreference/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package toolreference
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"path"
"regexp"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -450,7 +452,23 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro
availableModels, err := availablemodels.ForProvider(req.Ctx, h.dispatcher, req.Namespace, req.Name)
if err != nil {
// Don't error and retry because it will likely fail again. Log the error, and the user can re-sync manually.
log.Errorf("Failed to get available models for model provider %q: %v", toolRef.Name, err)
// Also, the toolRef.Status.Error field will bubble up to the user in the UI.

// Check if the model provider returned a properly formatted error message and set it as status
re := regexp.MustCompile(`\{.*"error":.*}`)
match := re.FindString(err.Error())
if match != "" {
toolRef.Status.Error = match
type errorResponse struct {
Error string `json:"error"`
}
var eR errorResponse
if err := json.Unmarshal([]byte(match), &eR); err == nil {
toolRef.Status.Error = eR.Error
}
}

log.Errorf("%v", err)
return nil
}

Expand Down
61 changes: 56 additions & 5 deletions ui/admin/app/components/model-providers/ModelProviderForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,24 @@ export function ModelProviderForm({
}
);

const validateAndConfigureModelProvider = useAsync(
ModelProviderApiService.validateModelProviderById,
{
onSuccess: async (data, params) => {
// Only configure the model provider if validation was successful
const [modelProviderId, configParams] = params;
await configureModelProvider.execute(
modelProviderId,
configParams
);
},
onError: (error) => {
// Handle validation errors
console.error("Validation failed:", error);
},
}
);

const configureModelProvider = useAsync(
ModelProviderApiService.configureModelProviderById,
{
Expand Down Expand Up @@ -185,7 +203,7 @@ export function ModelProviderForm({
}
);

await configureModelProvider.execute(
await validateAndConfigureModelProvider.execute(
modelProvider.id,
allConfigParams
);
Expand All @@ -197,24 +215,57 @@ export function ModelProviderForm({
modelProvider.id === "azure-openai-model-provider";

const loading =
validateAndConfigureModelProvider.isLoading ||
fetchAvailableModels.isLoading ||
configureModelProvider.isLoading ||
isLoading;

return (
<div className="flex flex-col">
{fetchAvailableModels.error !== null && (
{validateAndConfigureModelProvider.error !== null && (
<div className="px-4">
<Alert variant="destructive">
<CircleAlertIcon className="w-4 h-4" />
<AlertTitle>An error occurred!</AlertTitle>
<AlertDescription>
Your configuration was saved, but we were not able
to connect to the model provider. Please check your
configuration and try again.
Your configuration could not be saved, because it
failed validation:{" "}
<strong>
{(typeof validateAndConfigureModelProvider.error ===
"object" &&
"message" in
validateAndConfigureModelProvider.error &&
(validateAndConfigureModelProvider.error
.message as string)) ??
"Unknown error"}
</strong>
</AlertDescription>
</Alert>
</div>
)}
{validateAndConfigureModelProvider.error === null &&
fetchAvailableModels.error !== null && (
<div className="px-4">
<Alert variant="destructive">
<CircleAlertIcon className="w-4 h-4" />
<AlertTitle>An error occurred!</AlertTitle>
<AlertDescription>
Your configuration was saved, but we were not
able to connect to the model provider. Please
check your configuration and try again:{" "}
<strong>
{(typeof fetchAvailableModels.error ===
"object" &&
"message" in
fetchAvailableModels.error &&
(fetchAvailableModels.error
.message as string)) ??
"Unknown error"}
</strong>
</AlertDescription>
</Alert>
</div>
)}
<ScrollArea className="max-h-[50vh]">
<div className="flex flex-col gap-4 p-4">
<TypographyH4 className="font-semibold text-md">
Expand Down
14 changes: 14 additions & 0 deletions ui/admin/app/hooks/useAsync.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ export function useAsync<TData, TParams extends unknown[]>(
onSuccess?.(data, params);
})
.catch((error) => {
if (
error.response &&
typeof error.response.data === "string"
) {
const errorMessageMatch =
error.response.data.match(/{"error":\s+"(.*?)"}/);
if (errorMessageMatch) {
const errorMessage = JSON.parse(
errorMessageMatch[0]
).error;
console.log("Error: ", errorMessage);
error.message = errorMessage;
}
}
setError(error);
onError?.(error, params);
})
Expand Down
2 changes: 2 additions & 0 deletions ui/admin/app/lib/routers/apiRoutes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ export const ApiRoutes = {
getModelProviders: () => buildUrl("/model-providers"),
getModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}`),
validateModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}/validate`),
configureModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}/configure`),
revealModelProviderById: (modelProviderKey: string) =>
Expand Down
18 changes: 18 additions & 0 deletions ui/admin/app/lib/service/api/modelProviderApiService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ getModelProviderById.key = (modelProviderId?: string) => {
};
};

const validateModelProviderById = async (
modelProviderKey: string,
modelProviderConfig: ModelProviderConfig
) => {
const res = await request<ModelProvider>({
url: ApiRoutes.modelProviders.validateModelProviderById(
modelProviderKey
).url,
method: "POST",
data: modelProviderConfig,
errorMessage:
"Failed to validate configuration values on the requested modal provider.",
});

return res.data;
};

const configureModelProviderById = async (
modelProviderKey: string,
modelProviderConfig: ModelProviderConfig
Expand Down Expand Up @@ -87,6 +104,7 @@ const deconfigureModelProviderById = async (modelProviderKey: string) => {
export const ModelProviderApiService = {
getModelProviders,
getModelProviderById,
validateModelProviderById,
configureModelProviderById,
revealModelProviderById,
deconfigureModelProviderById,
Expand Down

0 comments on commit e4bbfe1

Please sign in to comment.