Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for credential auth for non-tool-references #1139

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 69 additions & 29 deletions pkg/api/handlers/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/alias"
"github.com/obot-platform/obot/pkg/api"
"github.com/obot-platform/obot/pkg/controller/creds"
"github.com/obot-platform/obot/pkg/invoke"
"github.com/obot-platform/obot/pkg/render"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
Expand Down Expand Up @@ -60,7 +61,7 @@ func (a *AgentHandler) Authenticate(req api.Context) (err error) {
return err
}

resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, agent.DeepCopy(), tools)
resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, a.gptscript, agent.DeepCopy(), tools)
if err != nil {
return err
}
Expand Down Expand Up @@ -94,24 +95,7 @@ func (a *AgentHandler) DeAuthenticate(req api.Context) error {
return err
}

var (
errs []error
toolRef v1.ToolReference
)
for _, tool := range tools {
if err := req.Get(&toolRef, tool); err != nil {
errs = append(errs, err)
continue
}

if toolRef.Status.Tool != nil {
for _, cred := range toolRef.Status.Tool.CredentialNames {
if err := a.gptscript.DeleteCredential(req.Context(), id, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
errs = append(errs, err)
}
}
}
}
errs := removeToolCredentials(req.Context(), req.Storage, a.gptscript, id, agent.Namespace, tools)

if err := kickAgent(req.Context(), req.Storage, &agent); err != nil {
errs = append(errs, fmt.Errorf("failed to update agent status: %w", err))
Expand Down Expand Up @@ -929,23 +913,35 @@ func MetadataFrom(obj kclient.Object, linkKV ...string) types.Metadata {
return m
}

func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.Invoker, agent *v1.Agent, tools []string) (*invoke.Response, error) {
func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.Invoker, gClient *gptscript.GPTScript, agent *v1.Agent, tools []string) (*invoke.Response, error) {
credentials := make([]string, 0, len(tools))

var toolRef v1.ToolReference
for _, tool := range tools {
if err := c.Get(ctx, kclient.ObjectKey{Namespace: agent.Namespace, Name: tool}, &toolRef); err != nil {
return nil, err
}
if strings.ContainsAny(tool, "./") {
prg, err := gClient.LoadFile(ctx, tool)
if err != nil {
return nil, err
}

if toolRef.Status.Tool == nil {
return nil, types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("tool %q is not ready", tool))
}
credentails, _, err := creds.DetermineCredsAndCredNames(prg, prg.ToolSet[prg.EntryToolID], tool)
if err != nil {
return nil, err
}

credentials = append(credentials, toolRef.Status.Tool.Credentials...)
credentials = append(credentials, credentails...)
} else if err := c.Get(ctx, kclient.ObjectKey{Namespace: agent.Namespace, Name: tool}, &toolRef); err == nil {
if toolRef.Status.Tool == nil {
return nil, types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("tool %q is not ready", tool))
}

// Reset the fields we care about so that we can use the same variable for the whole loop.
toolRef.Status.Tool = nil
credentials = append(credentials, toolRef.Status.Tool.Credentials...)

// Reset the fields we care about so that we can use the same variable for the whole loop.
toolRef.Status.Tool = nil
} else {
return nil, err
}
}

agent.Spec.Manifest.Prompt = "#!sys.echo\nDONE"
Expand All @@ -962,6 +958,50 @@ func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.I
})
}

func removeToolCredentials(ctx context.Context, client kclient.Client, gClient *gptscript.GPTScript, credCtx, namespace string, tools []string) []error {
var (
errs []error
toolRef v1.ToolReference
credentialNames []string
)
for _, tool := range tools {
if strings.ContainsAny(tool, "./") {
prg, err := gClient.LoadFile(ctx, tool)
if err != nil {
errs = append(errs, err)
continue
}

_, names, err := creds.DetermineCredsAndCredNames(prg, prg.ToolSet[prg.EntryToolID], tool)
if err != nil {
errs = append(errs, err)
continue
}

credentialNames = append(credentialNames, names...)
} else if err := client.Get(ctx, kclient.ObjectKey{Namespace: namespace, Name: tool}, &toolRef); err == nil {
if toolRef.Status.Tool != nil {
credentialNames = append(credentialNames, toolRef.Status.Tool.CredentialNames...)
}
} else {
errs = append(errs, err)
continue
}

// Reset the value we care about so the same variable can be used.
// This ensures that the value we read on the next iteration is pulled from the database.
toolRef.Status.Tool = nil

for _, cred := range credentialNames {
if err := gClient.DeleteCredential(ctx, credCtx, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
errs = append(errs, err)
}
}
}

return errs
}

func kickAgent(ctx context.Context, c kclient.Client, agent *v1.Agent) error {
if agent.Annotations[v1.AgentSyncAnnotation] != "" {
delete(agent.Annotations, v1.AgentSyncAnnotation)
Expand Down
25 changes: 2 additions & 23 deletions pkg/api/handlers/workflows.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (a *WorkflowHandler) Authenticate(req api.Context) error {
return err
}

resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, agent, tools)
resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, a.gptscript, agent, tools)
if err != nil {
return err
}
Expand Down Expand Up @@ -92,28 +92,7 @@ func (a *WorkflowHandler) DeAuthenticate(req api.Context) error {
return err
}

var (
errs []error
toolRef v1.ToolReference
)
for _, tool := range tools {
if err := req.Get(&toolRef, tool); err != nil {
errs = append(errs, err)
continue
}

if toolRef.Status.Tool != nil {
for _, cred := range toolRef.Status.Tool.CredentialNames {
if err := a.gptscript.DeleteCredential(req.Context(), id, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
errs = append(errs, err)
}
}

// Reset the value we care about so the same variable can be used.
// This ensures that the value we read on the next iteration is pulled from the database.
toolRef.Status.Tool = nil
}
}
errs := removeToolCredentials(req.Context(), req.Storage, a.gptscript, id, wf.Namespace, tools)

if err := kickWorkflow(req.Context(), req.Storage, &wf); err != nil {
errs = append(errs, fmt.Errorf("failed to update workflow status: %w", err))
Expand Down
185 changes: 185 additions & 0 deletions pkg/controller/creds/creds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package creds

import (
"fmt"
"net/url"
"path"
"slices"
"strings"

"github.com/gptscript-ai/go-gptscript"
gtypes "github.com/gptscript-ai/gptscript/pkg/types"
"github.com/obot-platform/obot/pkg/system"
)

func DetermineCredsAndCredNames(prg *gptscript.Program, tool gptscript.Tool, name string) ([]string, []string, error) {
seen := make(map[string]struct{})
// The available tool references from this tool are the tool itself and any tool this tool exports.
toolRefs := make([]toolRef, 0, len(tool.Export)+len(tool.Tools)+1)
toolRefs = append(toolRefs, toolRef{
ToolReference: gptscript.ToolReference{
Reference: name,
ToolID: prg.EntryToolID,
},
name: name,
})
toolRefs = append(toolRefs, toolRefsFromTools(tool, toolRefs[0], tool.Tools, seen)...)

credentials := make([]string, 0, len(tool.Credentials)+len(tool.Export)+len(tool.Tools))
credentialNames := make([]string, 0, len(tool.Credentials)+len(tool.Export)+len(tool.Tools))
for len(toolRefs) > 0 {
ref := toolRefs[0]
toolRefs = toolRefs[1:]

if _, ok := seen[ref.ToolID]; ok {
continue
}
seen[ref.ToolID] = struct{}{}

t := prg.ToolSet[ref.ToolID]

// Add the tools that this tool exports if we haven't already seen them.
toolRefs = append(toolRefs, toolRefsFromTools(t, ref, t.Export, seen)...)

for _, cred := range append(t.Credentials, t.ExportCredentials...) {
if parsedCred := fullToolPathName(ref, cred); parsedCred != "" && !slices.Contains(credentials, parsedCred) {
credentials = append(credentials, parsedCred)
}

credNames, err := determineCredentialNames(prg, prg.ToolSet[ref.ToolID], cred)
if err != nil {
return credentials, credentialNames, err
}

for _, n := range credNames {
if !slices.Contains(credentialNames, n) {
credentialNames = append(credentialNames, n)
}
}
}
}

return credentials, credentialNames, nil
}

func determineCredentialNames(prg *gptscript.Program, tool gptscript.Tool, toolName string) ([]string, error) {
if toolName == system.ModelProviderCredential {
return []string{system.ModelProviderCredential}, nil
}

var subTool string
parsedToolName, alias, args, err := gtypes.ParseCredentialArgs(toolName, "")
if err != nil {
parsedToolName, subTool = gtypes.SplitToolRef(toolName)
parsedToolName, alias, args, err = gtypes.ParseCredentialArgs(parsedToolName, "")
if err != nil {
return nil, err
}
}

if alias != "" {
return []string{alias}, nil
}

if args == nil {
// This is a tool and not the credential format. Parse the tool from the program to determine the alias
toolNames := make([]string, 0, len(tool.Credentials))
if subTool == "" {
toolName = parsedToolName
}
for _, cred := range tool.Credentials {
if cred == toolName {
if len(tool.ToolMapping[cred]) == 0 {
return nil, fmt.Errorf("cannot find credential name for tool %q", toolName)
}

for _, ref := range tool.ToolMapping[cred] {
for _, c := range prg.ToolSet[ref.ToolID].ExportCredentials {
names, err := determineCredentialNames(prg, prg.ToolSet[ref.ToolID], c)
if err != nil {
return nil, err
}

toolNames = append(toolNames, names...)
}
}
}
}

if len(toolNames) > 0 {
return toolNames, nil
}

return nil, fmt.Errorf("tool %q not found in program", toolName)
}

return []string{toolName}, nil
}

type toolRef struct {
gptscript.ToolReference
name string
}

func toolRefsFromTools(parentTool gptscript.Tool, parentRef toolRef, tools []string, seen map[string]struct{}) []toolRef {
var toolRefs []toolRef
for _, e := range tools {
name := e
if _, ok := parentTool.LocalTools[strings.ToLower(e)]; ok {
name, _ = gtypes.SplitToolRef(parentRef.name)
name = fmt.Sprintf("%s from %s", e, name)
}
name = fullToolPathName(parentRef, name)
if name == "" {
continue
}

for _, r := range parentTool.ToolMapping[e] {
if _, ok := seen[r.ToolID]; !ok {
toolRefs = append(toolRefs, toolRef{
ToolReference: r,
name: name,
})
}
}
}

return toolRefs
}

func fullToolPathName(parentRef toolRef, name string) string {
toolName, subTool := gtypes.SplitToolRef(name)
if strings.HasPrefix(toolName, ".") {
parentToolName, _ := gtypes.SplitToolRef(parentRef.Reference)
if !path.IsAbs(parentToolName) {
if !strings.HasPrefix(parentToolName, ".") {
parentToolName, _ = gtypes.SplitToolRef(parentRef.name)
} else {
parentToolName = path.Join(parentRef.name, parentToolName)
}
}

refURL, err := url.Parse(parentToolName)
if err != nil {
return ""
}

if strings.HasSuffix(refURL.Path, ".gpt") {
refURL.Path = path.Dir(refURL.Path)
}

refURL.Path = path.Join(refURL.Path, toolName)
name = refURL.String()
if refURL.Host == "" {
// This is only a path, so url unescape it.
// No need to check the error here, we would have errored when parsing.
name, _ = url.PathUnescape(name)
}

if subTool != "" {
name = fmt.Sprintf("%s from %s", subTool, name)
}
}

return name
}
Loading
Loading