diff --git a/pkg/api/handlers/agent.go b/pkg/api/handlers/agent.go index 608359fc4..9fa0fee41 100644 --- a/pkg/api/handlers/agent.go +++ b/pkg/api/handlers/agent.go @@ -13,6 +13,7 @@ import ( "github.com/obot-platform/obot/apiclient/types" "github.com/obot-platform/obot/pkg/alias" "github.com/obot-platform/obot/pkg/api" + "github.com/obot-platform/obot/pkg/controller/creds" "github.com/obot-platform/obot/pkg/invoke" "github.com/obot-platform/obot/pkg/render" v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" @@ -60,7 +61,7 @@ func (a *AgentHandler) Authenticate(req api.Context) (err error) { return err } - resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, agent.DeepCopy(), tools) + resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, a.gptscript, agent.DeepCopy(), tools) if err != nil { return err } @@ -94,24 +95,7 @@ func (a *AgentHandler) DeAuthenticate(req api.Context) error { return err } - var ( - errs []error - toolRef v1.ToolReference - ) - for _, tool := range tools { - if err := req.Get(&toolRef, tool); err != nil { - errs = append(errs, err) - continue - } - - if toolRef.Status.Tool != nil { - for _, cred := range toolRef.Status.Tool.CredentialNames { - if err := a.gptscript.DeleteCredential(req.Context(), id, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") { - errs = append(errs, err) - } - } - } - } + errs := removeToolCredentials(req.Context(), req.Storage, a.gptscript, id, agent.Namespace, tools) if err := kickAgent(req.Context(), req.Storage, &agent); err != nil { errs = append(errs, fmt.Errorf("failed to update agent status: %w", err)) @@ -929,23 +913,35 @@ func MetadataFrom(obj kclient.Object, linkKV ...string) types.Metadata { return m } -func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.Invoker, agent *v1.Agent, tools []string) (*invoke.Response, error) { +func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.Invoker, gClient *gptscript.GPTScript, agent *v1.Agent, tools []string) (*invoke.Response, error) { credentials := make([]string, 0, len(tools)) var toolRef v1.ToolReference for _, tool := range tools { - if err := c.Get(ctx, kclient.ObjectKey{Namespace: agent.Namespace, Name: tool}, &toolRef); err != nil { - return nil, err - } + if strings.ContainsAny(tool, "./") { + prg, err := gClient.LoadFile(ctx, tool) + if err != nil { + return nil, err + } - if toolRef.Status.Tool == nil { - return nil, types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("tool %q is not ready", tool)) - } + credentails, _, err := creds.DetermineCredsAndCredNames(prg, prg.ToolSet[prg.EntryToolID], tool) + if err != nil { + return nil, err + } - credentials = append(credentials, toolRef.Status.Tool.Credentials...) + credentials = append(credentials, credentails...) + } else if err := c.Get(ctx, kclient.ObjectKey{Namespace: agent.Namespace, Name: tool}, &toolRef); err == nil { + if toolRef.Status.Tool == nil { + return nil, types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("tool %q is not ready", tool)) + } - // Reset the fields we care about so that we can use the same variable for the whole loop. - toolRef.Status.Tool = nil + credentials = append(credentials, toolRef.Status.Tool.Credentials...) + + // Reset the fields we care about so that we can use the same variable for the whole loop. + toolRef.Status.Tool = nil + } else { + return nil, err + } } agent.Spec.Manifest.Prompt = "#!sys.echo\nDONE" @@ -962,6 +958,50 @@ func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.I }) } +func removeToolCredentials(ctx context.Context, client kclient.Client, gClient *gptscript.GPTScript, credCtx, namespace string, tools []string) []error { + var ( + errs []error + toolRef v1.ToolReference + credentialNames []string + ) + for _, tool := range tools { + if strings.ContainsAny(tool, "./") { + prg, err := gClient.LoadFile(ctx, tool) + if err != nil { + errs = append(errs, err) + continue + } + + _, names, err := creds.DetermineCredsAndCredNames(prg, prg.ToolSet[prg.EntryToolID], tool) + if err != nil { + errs = append(errs, err) + continue + } + + credentialNames = append(credentialNames, names...) + } else if err := client.Get(ctx, kclient.ObjectKey{Namespace: namespace, Name: tool}, &toolRef); err == nil { + if toolRef.Status.Tool != nil { + credentialNames = append(credentialNames, toolRef.Status.Tool.CredentialNames...) + } + } else { + errs = append(errs, err) + continue + } + + // Reset the value we care about so the same variable can be used. + // This ensures that the value we read on the next iteration is pulled from the database. + toolRef.Status.Tool = nil + + for _, cred := range credentialNames { + if err := gClient.DeleteCredential(ctx, credCtx, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") { + errs = append(errs, err) + } + } + } + + return errs +} + func kickAgent(ctx context.Context, c kclient.Client, agent *v1.Agent) error { if agent.Annotations[v1.AgentSyncAnnotation] != "" { delete(agent.Annotations, v1.AgentSyncAnnotation) diff --git a/pkg/api/handlers/workflows.go b/pkg/api/handlers/workflows.go index 87d5f3c7b..ca20d2eb3 100644 --- a/pkg/api/handlers/workflows.go +++ b/pkg/api/handlers/workflows.go @@ -58,7 +58,7 @@ func (a *WorkflowHandler) Authenticate(req api.Context) error { return err } - resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, agent, tools) + resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, a.gptscript, agent, tools) if err != nil { return err } @@ -92,28 +92,7 @@ func (a *WorkflowHandler) DeAuthenticate(req api.Context) error { return err } - var ( - errs []error - toolRef v1.ToolReference - ) - for _, tool := range tools { - if err := req.Get(&toolRef, tool); err != nil { - errs = append(errs, err) - continue - } - - if toolRef.Status.Tool != nil { - for _, cred := range toolRef.Status.Tool.CredentialNames { - if err := a.gptscript.DeleteCredential(req.Context(), id, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") { - errs = append(errs, err) - } - } - - // Reset the value we care about so the same variable can be used. - // This ensures that the value we read on the next iteration is pulled from the database. - toolRef.Status.Tool = nil - } - } + errs := removeToolCredentials(req.Context(), req.Storage, a.gptscript, id, wf.Namespace, tools) if err := kickWorkflow(req.Context(), req.Storage, &wf); err != nil { errs = append(errs, fmt.Errorf("failed to update workflow status: %w", err)) diff --git a/pkg/controller/creds/creds.go b/pkg/controller/creds/creds.go new file mode 100644 index 000000000..84923405e --- /dev/null +++ b/pkg/controller/creds/creds.go @@ -0,0 +1,185 @@ +package creds + +import ( + "fmt" + "net/url" + "path" + "slices" + "strings" + + "github.com/gptscript-ai/go-gptscript" + gtypes "github.com/gptscript-ai/gptscript/pkg/types" + "github.com/obot-platform/obot/pkg/system" +) + +func DetermineCredsAndCredNames(prg *gptscript.Program, tool gptscript.Tool, name string) ([]string, []string, error) { + seen := make(map[string]struct{}) + // The available tool references from this tool are the tool itself and any tool this tool exports. + toolRefs := make([]toolRef, 0, len(tool.Export)+len(tool.Tools)+1) + toolRefs = append(toolRefs, toolRef{ + ToolReference: gptscript.ToolReference{ + Reference: name, + ToolID: prg.EntryToolID, + }, + name: name, + }) + toolRefs = append(toolRefs, toolRefsFromTools(tool, toolRefs[0], tool.Tools, seen)...) + + credentials := make([]string, 0, len(tool.Credentials)+len(tool.Export)+len(tool.Tools)) + credentialNames := make([]string, 0, len(tool.Credentials)+len(tool.Export)+len(tool.Tools)) + for len(toolRefs) > 0 { + ref := toolRefs[0] + toolRefs = toolRefs[1:] + + if _, ok := seen[ref.ToolID]; ok { + continue + } + seen[ref.ToolID] = struct{}{} + + t := prg.ToolSet[ref.ToolID] + + // Add the tools that this tool exports if we haven't already seen them. + toolRefs = append(toolRefs, toolRefsFromTools(t, ref, t.Export, seen)...) + + for _, cred := range append(t.Credentials, t.ExportCredentials...) { + if parsedCred := fullToolPathName(ref, cred); parsedCred != "" && !slices.Contains(credentials, parsedCred) { + credentials = append(credentials, parsedCred) + } + + credNames, err := determineCredentialNames(prg, prg.ToolSet[ref.ToolID], cred) + if err != nil { + return credentials, credentialNames, err + } + + for _, n := range credNames { + if !slices.Contains(credentialNames, n) { + credentialNames = append(credentialNames, n) + } + } + } + } + + return credentials, credentialNames, nil +} + +func determineCredentialNames(prg *gptscript.Program, tool gptscript.Tool, toolName string) ([]string, error) { + if toolName == system.ModelProviderCredential { + return []string{system.ModelProviderCredential}, nil + } + + var subTool string + parsedToolName, alias, args, err := gtypes.ParseCredentialArgs(toolName, "") + if err != nil { + parsedToolName, subTool = gtypes.SplitToolRef(toolName) + parsedToolName, alias, args, err = gtypes.ParseCredentialArgs(parsedToolName, "") + if err != nil { + return nil, err + } + } + + if alias != "" { + return []string{alias}, nil + } + + if args == nil { + // This is a tool and not the credential format. Parse the tool from the program to determine the alias + toolNames := make([]string, 0, len(tool.Credentials)) + if subTool == "" { + toolName = parsedToolName + } + for _, cred := range tool.Credentials { + if cred == toolName { + if len(tool.ToolMapping[cred]) == 0 { + return nil, fmt.Errorf("cannot find credential name for tool %q", toolName) + } + + for _, ref := range tool.ToolMapping[cred] { + for _, c := range prg.ToolSet[ref.ToolID].ExportCredentials { + names, err := determineCredentialNames(prg, prg.ToolSet[ref.ToolID], c) + if err != nil { + return nil, err + } + + toolNames = append(toolNames, names...) + } + } + } + } + + if len(toolNames) > 0 { + return toolNames, nil + } + + return nil, fmt.Errorf("tool %q not found in program", toolName) + } + + return []string{toolName}, nil +} + +type toolRef struct { + gptscript.ToolReference + name string +} + +func toolRefsFromTools(parentTool gptscript.Tool, parentRef toolRef, tools []string, seen map[string]struct{}) []toolRef { + var toolRefs []toolRef + for _, e := range tools { + name := e + if _, ok := parentTool.LocalTools[strings.ToLower(e)]; ok { + name, _ = gtypes.SplitToolRef(parentRef.name) + name = fmt.Sprintf("%s from %s", e, name) + } + name = fullToolPathName(parentRef, name) + if name == "" { + continue + } + + for _, r := range parentTool.ToolMapping[e] { + if _, ok := seen[r.ToolID]; !ok { + toolRefs = append(toolRefs, toolRef{ + ToolReference: r, + name: name, + }) + } + } + } + + return toolRefs +} + +func fullToolPathName(parentRef toolRef, name string) string { + toolName, subTool := gtypes.SplitToolRef(name) + if strings.HasPrefix(toolName, ".") { + parentToolName, _ := gtypes.SplitToolRef(parentRef.Reference) + if !path.IsAbs(parentToolName) { + if !strings.HasPrefix(parentToolName, ".") { + parentToolName, _ = gtypes.SplitToolRef(parentRef.name) + } else { + parentToolName = path.Join(parentRef.name, parentToolName) + } + } + + refURL, err := url.Parse(parentToolName) + if err != nil { + return "" + } + + if strings.HasSuffix(refURL.Path, ".gpt") { + refURL.Path = path.Dir(refURL.Path) + } + + refURL.Path = path.Join(refURL.Path, toolName) + name = refURL.String() + if refURL.Host == "" { + // This is only a path, so url unescape it. + // No need to check the error here, we would have errored when parsing. + name, _ = url.PathUnescape(name) + } + + if subTool != "" { + name = fmt.Sprintf("%s from %s", subTool, name) + } + } + + return name +} diff --git a/pkg/controller/handlers/toolinfo/toolinfo.go b/pkg/controller/handlers/toolinfo/toolinfo.go index b8c67ab5b..941720404 100644 --- a/pkg/controller/handlers/toolinfo/toolinfo.go +++ b/pkg/controller/handlers/toolinfo/toolinfo.go @@ -1,11 +1,14 @@ package toolinfo import ( + "context" "fmt" + "strings" "github.com/gptscript-ai/go-gptscript" "github.com/obot-platform/nah/pkg/router" "github.com/obot-platform/obot/apiclient/types" + "github.com/obot-platform/obot/pkg/controller/creds" v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" "github.com/obot-platform/obot/pkg/system" apierror "k8s.io/apimachinery/pkg/api/errors" @@ -51,28 +54,47 @@ func (h *Handler) SetToolInfoStatus(req router.Request, resp router.Response) (e tools := obj.GetTools() toolInfos := make(map[string]types.ToolInfo, len(tools)) - var toolRef v1.ToolReference + var ( + toolRef v1.ToolReference + credNames []string + ) for _, tool := range tools { - if err = req.Get(&toolRef, req.Namespace, tool); apierror.IsNotFound(err) { + if strings.ContainsAny(tool, "/.") { + credNames, err = h.credentialNamesForNonToolReferences(req.Ctx, tool) + if err != nil { + return err + } + } else if err = req.Get(&toolRef, req.Namespace, tool); apierror.IsNotFound(err) { continue } else if err != nil { return err } else if toolRef.Status.Tool == nil { return fmt.Errorf("cannot determine credential status for tool %s: no tool status found", tool) + } else if err == nil { + credNames = toolRef.Status.Tool.CredentialNames + // Clear the field we care about in this loop. + // This allows us to use the same variable for the whole loop + // while ensuring that the value we care about is loaded correctly. + toolRef.Status.Tool.CredentialNames = nil } toolInfos[tool] = types.ToolInfo{ - CredentialNames: toolRef.Status.Tool.CredentialNames, - Authorized: credsSet.HasAll(toolRef.Status.Tool.CredentialNames...), + CredentialNames: credNames, + Authorized: credsSet.HasAll(credNames...), } - - // Clear the field we care about in this loop. - // This allows us to use the same variable for the whole loop - // while ensuring that the value we care about is loaded correctly. - toolRef.Status.Tool.CredentialNames = nil } obj.SetToolInfos(toolInfos) return nil } + +func (h *Handler) credentialNamesForNonToolReferences(ctx context.Context, name string) ([]string, error) { + prg, err := h.gptscript.LoadFile(ctx, name) + if err != nil { + return nil, err + } + + _, credNames, err := creds.DetermineCredsAndCredNames(prg, prg.ToolSet[prg.EntryToolID], name) + return credNames, err +} diff --git a/pkg/controller/handlers/toolreference/toolreference.go b/pkg/controller/handlers/toolreference/toolreference.go index 25959ca89..eba55b93d 100644 --- a/pkg/controller/handlers/toolreference/toolreference.go +++ b/pkg/controller/handlers/toolreference/toolreference.go @@ -5,21 +5,18 @@ import ( "crypto/sha256" "errors" "fmt" - "net/url" "os" - "path" - "slices" "strings" "time" "github.com/gptscript-ai/go-gptscript" - gtypes "github.com/gptscript-ai/gptscript/pkg/types" "github.com/obot-platform/nah/pkg/apply" "github.com/obot-platform/nah/pkg/name" "github.com/obot-platform/nah/pkg/router" "github.com/obot-platform/obot/apiclient/types" "github.com/obot-platform/obot/logger" "github.com/obot-platform/obot/pkg/availablemodels" + "github.com/obot-platform/obot/pkg/controller/creds" "github.com/obot-platform/obot/pkg/gateway/server/dispatcher" v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1" "github.com/obot-platform/obot/pkg/system" @@ -265,66 +262,9 @@ func (h *Handler) Populate(req router.Request, resp router.Response) error { } } - // The available tool references from this tool are the tool itself and any tool this tool exports. - toolRefs := make([]gptscript.ToolReference, 0, len(tool.Export)+1) - toolRefs = append(toolRefs, gptscript.ToolReference{ - Reference: toolRef.Spec.Reference, - ToolID: prg.EntryToolID, - }) - for _, exportedTool := range tool.Export { - toolRefs = append(toolRefs, tool.ToolMapping[exportedTool]...) - } - - toolRef.Status.Tool.Credentials = make([]string, 0, len(tool.Credentials)+len(tool.Export)) - toolRef.Status.Tool.CredentialNames = make([]string, 0, len(tool.Credentials)+len(tool.Export)) - for _, ref := range toolRefs { - t := prg.ToolSet[ref.ToolID] - for _, cred := range t.Credentials { - parsedCred := cred - credToolName, credSubTool := gtypes.SplitToolRef(cred) - if strings.HasPrefix(credToolName, ".") { - toolName, _ := gtypes.SplitToolRef(ref.Reference) - if !path.IsAbs(toolName) { - if !strings.HasPrefix(toolName, ".") { - toolName, _ = gtypes.SplitToolRef(toolRef.Spec.Reference) - } else { - toolName = path.Join(toolRef.Spec.Reference, toolName) - } - } - - refURL, err := url.Parse(toolName) - if err != nil { - continue - } - - refURL.Path = path.Join(refURL.Path, credToolName) - parsedCred = refURL.String() - if refURL.Host == "" { - // This is only a path, so url unescape it. - // No need to check the error here, we would have errored when parsing. - parsedCred, _ = url.PathUnescape(parsedCred) - } - - if credSubTool != "" { - parsedCred = fmt.Sprintf("%s from %s", credSubTool, parsedCred) - } - } - - if parsedCred != "" && !slices.Contains(toolRef.Status.Tool.Credentials, parsedCred) { - toolRef.Status.Tool.Credentials = append(toolRef.Status.Tool.Credentials, parsedCred) - } - - credNames, err := determineCredentialNames(prg, prg.ToolSet[ref.ToolID], cred) - if err != nil { - toolRef.Status.Error = err.Error() - } - - for _, n := range credNames { - if !slices.Contains(toolRef.Status.Tool.CredentialNames, n) { - toolRef.Status.Tool.CredentialNames = append(toolRef.Status.Tool.CredentialNames, n) - } - } - } + toolRef.Status.Tool.Credentials, toolRef.Status.Tool.CredentialNames, err = creds.DetermineCredsAndCredNames(prg, tool, toolRef.Spec.Reference) + if err != nil { + toolRef.Status.Error = err.Error() } return nil @@ -522,57 +462,3 @@ func (h *Handler) CleanupModelProvider(req router.Request, _ router.Response) er func modelName(modelProviderName, modelName string) string { return name.SafeConcatName(system.ModelPrefix, modelProviderName, fmt.Sprintf("%x", sha256.Sum256([]byte(modelName)))) } - -func determineCredentialNames(prg *gptscript.Program, tool gptscript.Tool, toolName string) ([]string, error) { - if toolName == system.ModelProviderCredential { - return []string{system.ModelProviderCredential}, nil - } - - var subTool string - parsedToolName, alias, args, err := gtypes.ParseCredentialArgs(toolName, "") - if err != nil { - parsedToolName, subTool = gtypes.SplitToolRef(toolName) - parsedToolName, alias, args, err = gtypes.ParseCredentialArgs(parsedToolName, "") - if err != nil { - return nil, err - } - } - - if alias != "" { - return []string{alias}, nil - } - - if args == nil { - // This is a tool and not the credential format. Parse the tool from the program to determine the alias - toolNames := make([]string, 0, len(tool.Credentials)) - if subTool == "" { - toolName = parsedToolName - } - for _, cred := range tool.Credentials { - if cred == toolName { - if len(tool.ToolMapping[cred]) == 0 { - return nil, fmt.Errorf("cannot find credential name for tool %q", toolName) - } - - for _, ref := range tool.ToolMapping[cred] { - for _, c := range prg.ToolSet[ref.ToolID].ExportCredentials { - names, err := determineCredentialNames(prg, prg.ToolSet[ref.ToolID], c) - if err != nil { - return nil, err - } - - toolNames = append(toolNames, names...) - } - } - } - } - - if len(toolNames) > 0 { - return toolNames, nil - } - - return nil, fmt.Errorf("tool %q not found in program", toolName) - } - - return []string{toolName}, nil -}