Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: model provider configuration validation #1019

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion pkg/api/handlers/modelprovider.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
package handlers

import (
"encoding/json"
"fmt"
"strings"

"github.com/gptscript-ai/go-gptscript"
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/logger"
"github.com/obot-platform/obot/pkg/api"
"github.com/obot-platform/obot/pkg/gateway/server/dispatcher"
"github.com/obot-platform/obot/pkg/invoke"
v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1"
"github.com/obot-platform/obot/pkg/system"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)

type ModelProviderHandler struct {
gptscript *gptscript.GPTScript
dispatcher *dispatcher.Dispatcher
invoker *invoke.Invoker
}

func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher) *ModelProviderHandler {
func NewModelProviderHandler(gClient *gptscript.GPTScript, dispatcher *dispatcher.Dispatcher, invoker *invoke.Invoker) *ModelProviderHandler {
return &ModelProviderHandler{
gptscript: gClient,
dispatcher: dispatcher,
invoker: invoker,
}
}

Expand Down Expand Up @@ -89,6 +97,78 @@ func (mp *ModelProviderHandler) List(req api.Context) error {
return req.Write(types.ModelProviderList{Items: resp})
}

type ValidationError struct {
Err string `json:"error"`
}

func (ve *ValidationError) Error() string {
return fmt.Sprintf("model-provider credentials validation failed: {\"error\": \"%s\"}", ve.Err)
}

func (mp *ModelProviderHandler) Validate(req api.Context) error {
var ref v1.ToolReference
if err := req.Get(&ref, req.PathValue("id")); err != nil {
return err
}

if ref.Spec.Type != types.ToolReferenceTypeModelProvider {
return types.NewErrBadRequest("%q is not a model provider", ref.Name)
}

l := logger.Package()
l.Debugf("Validating model provider %q", ref.Name)

var envVars map[string]string
if err := req.Read(&envVars); err != nil {
return err
}

envs := make([]string, 0, len(envVars))
for key, val := range envVars {
envs = append(envs, key+"="+val)
}

thread := &v1.Thread{
ObjectMeta: metav1.ObjectMeta{
GenerateName: system.ThreadPrefix + "-" + ref.Name + "-validate-",
Namespace: ref.Namespace,
},
Spec: v1.ThreadSpec{
SystemTask: true,
},
}

if err := req.Get(thread, thread.Name); apierrors.IsNotFound(err) {
if err = req.Create(thread); err != nil {
return fmt.Errorf("failed to create thread: %w", err)
}
}

defer func() { _ = req.Delete(thread) }()

task, err := mp.invoker.SystemTask(req.Context(), thread, "validate from "+ref.Spec.Reference, "", invoke.SystemTaskOptions{Env: envs})
if err != nil {
return err
}
defer task.Close()

res, err := task.Result(req.Context())
if err != nil {
if strings.Contains(err.Error(), "tool not found: validate from "+ref.Spec.Reference) { // there's no simple way to do errors.As/.Is at this point unfortunately
l.Infof("Model provider %q does not provide a validate tool. Looking for 'validate from %s'", ref.Name, ref.Spec.Reference)
return nil // Do not fail if model provider doesn't provide a validate tool
}
return &ValidationError{Err: strings.Trim(err.Error(), "\"'")}
}

var validationError ValidationError
if json.Unmarshal([]byte(res.Output), &validationError) == nil && validationError.Err != "" {
return &validationError
}

return nil
}

func (mp *ModelProviderHandler) Configure(req api.Context) error {
var ref v1.ToolReference
if err := req.Get(&ref, req.PathValue("id")); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func Router(services *services.Services) (http.Handler, error) {
cronJobs := handlers.NewCronJobHandler()
models := handlers.NewModelHandler()
availableModels := handlers.NewAvailableModelsHandler(services.GPTClient, services.ModelProviderDispatcher)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher)
modelProviders := handlers.NewModelProviderHandler(services.GPTClient, services.ModelProviderDispatcher, services.Invoker)
prompt := handlers.NewPromptHandler(services.GPTClient)
emailreceiver := handlers.NewEmailReceiverHandler(services.EmailServerName)
defaultModelAliases := handlers.NewDefaultModelAliasHandler()
Expand Down Expand Up @@ -276,6 +276,7 @@ func Router(services *services.Services) (http.Handler, error) {
// Model providers
mux.HandleFunc("GET /api/model-providers", modelProviders.List)
mux.HandleFunc("GET /api/model-providers/{id}", modelProviders.ByID)
mux.HandleFunc("POST /api/model-providers/{id}/validate", modelProviders.Validate)
mux.HandleFunc("POST /api/model-providers/{id}/configure", modelProviders.Configure)
mux.HandleFunc("POST /api/model-providers/{id}/deconfigure", modelProviders.Deconfigure)
mux.HandleFunc("POST /api/model-providers/{id}/reveal", modelProviders.Reveal)
Expand Down
8 changes: 4 additions & 4 deletions pkg/availablemodels/availablemodels.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr

r, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String()+"/v1/models", nil)
if err != nil {
return nil, fmt.Errorf("failed to create request to model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to create request to model provider %q: %w", modelProviderName, err)
}

if token != "" {
Expand All @@ -28,18 +28,18 @@ func ForProvider(ctx context.Context, dispatcher *dispatcher.Dispatcher, modelPr

resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("failed to make request to model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to make request to model provider %q: %w", modelProviderName, err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
message, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to get model list from model provider %s: %s", modelProviderName, message)
return nil, fmt.Errorf("failed to get model list from model provider %q: %s", modelProviderName, message)
}

var oModels openai.ModelsList
if err = json.NewDecoder(resp.Body).Decode(&oModels); err != nil {
return nil, fmt.Errorf("failed to decode model list from model provider %s: %w", modelProviderName, err)
return nil, fmt.Errorf("failed to decode model list from model provider %q: %w", modelProviderName, err)
}

return &oModels, nil
Expand Down
20 changes: 19 additions & 1 deletion pkg/controller/handlers/toolreference/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package toolreference
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"path"
"regexp"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -450,7 +452,23 @@ func (h *Handler) BackPopulateModels(req router.Request, _ router.Response) erro
availableModels, err := availablemodels.ForProvider(req.Ctx, h.dispatcher, req.Namespace, req.Name)
if err != nil {
// Don't error and retry because it will likely fail again. Log the error, and the user can re-sync manually.
log.Errorf("Failed to get available models for model provider %q: %v", toolRef.Name, err)
// Also, the toolRef.Status.Error field will bubble up to the user in the UI.

// Check if the model provider returned a properly formatted error message and set it as status
re := regexp.MustCompile(`\{.*"error":.*}`)
match := re.FindString(err.Error())
if match != "" {
toolRef.Status.Error = match
type errorResponse struct {
Error string `json:"error"`
}
var eR errorResponse
if err := json.Unmarshal([]byte(match), &eR); err == nil {
toolRef.Status.Error = eR.Error
}
}

log.Errorf("%v", err)
return nil
}

Expand Down
61 changes: 56 additions & 5 deletions ui/admin/app/components/model-providers/ModelProviderForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,24 @@ export function ModelProviderForm({
}
);

const validateAndConfigureModelProvider = useAsync(
ModelProviderApiService.validateModelProviderById,
{
onSuccess: async (data, params) => {
// Only configure the model provider if validation was successful
const [modelProviderId, configParams] = params;
await configureModelProvider.execute(
modelProviderId,
configParams
);
},
onError: (error) => {
// Handle validation errors
console.error("Validation failed:", error);
},
}
);

const configureModelProvider = useAsync(
ModelProviderApiService.configureModelProviderById,
{
Expand Down Expand Up @@ -185,7 +203,7 @@ export function ModelProviderForm({
}
);

await configureModelProvider.execute(
await validateAndConfigureModelProvider.execute(
modelProvider.id,
allConfigParams
);
Expand All @@ -197,24 +215,57 @@ export function ModelProviderForm({
modelProvider.id === "azure-openai-model-provider";

const loading =
validateAndConfigureModelProvider.isLoading ||
fetchAvailableModels.isLoading ||
configureModelProvider.isLoading ||
isLoading;

return (
<div className="flex flex-col">
{fetchAvailableModels.error !== null && (
{validateAndConfigureModelProvider.error !== null && (
<div className="px-4">
<Alert variant="destructive">
<CircleAlertIcon className="w-4 h-4" />
<AlertTitle>An error occurred!</AlertTitle>
<AlertDescription>
Your configuration was saved, but we were not able
to connect to the model provider. Please check your
configuration and try again.
Your configuration could not be saved, because it
failed validation:{" "}
<strong>
{(typeof validateAndConfigureModelProvider.error ===
"object" &&
"message" in
validateAndConfigureModelProvider.error &&
(validateAndConfigureModelProvider.error
.message as string)) ??
"Unknown error"}
</strong>
</AlertDescription>
</Alert>
</div>
)}
{validateAndConfigureModelProvider.error === null &&
fetchAvailableModels.error !== null && (
<div className="px-4">
<Alert variant="destructive">
<CircleAlertIcon className="w-4 h-4" />
<AlertTitle>An error occurred!</AlertTitle>
<AlertDescription>
Your configuration was saved, but we were not
able to connect to the model provider. Please
check your configuration and try again:{" "}
<strong>
{(typeof fetchAvailableModels.error ===
"object" &&
"message" in
fetchAvailableModels.error &&
(fetchAvailableModels.error
.message as string)) ??
"Unknown error"}
</strong>
</AlertDescription>
</Alert>
</div>
)}
<ScrollArea className="max-h-[50vh]">
<div className="flex flex-col gap-4 p-4">
<TypographyH4 className="font-semibold text-md">
Expand Down
14 changes: 14 additions & 0 deletions ui/admin/app/hooks/useAsync.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ export function useAsync<TData, TParams extends unknown[]>(
onSuccess?.(data, params);
})
.catch((error) => {
if (
error.response &&
typeof error.response.data === "string"
) {
const errorMessageMatch =
error.response.data.match(/{"error":\s+"(.*?)"}/);
if (errorMessageMatch) {
const errorMessage = JSON.parse(
errorMessageMatch[0]
).error;
console.log("Error: ", errorMessage);
error.message = errorMessage;
}
}
setError(error);
onError?.(error, params);
})
Expand Down
2 changes: 2 additions & 0 deletions ui/admin/app/lib/routers/apiRoutes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ export const ApiRoutes = {
getModelProviders: () => buildUrl("/model-providers"),
getModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}`),
validateModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}/validate`),
configureModelProviderById: (modelProviderKey: string) =>
buildUrl(`/model-providers/${modelProviderKey}/configure`),
revealModelProviderById: (modelProviderKey: string) =>
Expand Down
18 changes: 18 additions & 0 deletions ui/admin/app/lib/service/api/modelProviderApiService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ getModelProviderById.key = (modelProviderId?: string) => {
};
};

const validateModelProviderById = async (
modelProviderKey: string,
modelProviderConfig: ModelProviderConfig
) => {
const res = await request<ModelProvider>({
url: ApiRoutes.modelProviders.validateModelProviderById(
modelProviderKey
).url,
method: "POST",
data: modelProviderConfig,
errorMessage:
"Failed to validate configuration values on the requested modal provider.",
});

return res.data;
};

const configureModelProviderById = async (
modelProviderKey: string,
modelProviderConfig: ModelProviderConfig
Expand Down Expand Up @@ -87,6 +104,7 @@ const deconfigureModelProviderById = async (modelProviderKey: string) => {
export const ModelProviderApiService = {
getModelProviders,
getModelProviderById,
validateModelProviderById,
configureModelProviderById,
revealModelProviderById,
deconfigureModelProviderById,
Expand Down