Skip to content

Commit

Permalink
feat: add support for authenticating on agents and workflows
Browse files Browse the repository at this point in the history
Agents and workflows can have tools that require credentials. To
authenticate these credentials, this change adds the ability to track
which tools need which credentials, whether those credentials exist,
and the ability to create and delete them.

Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Dec 23, 2024
1 parent f03c070 commit 5de7695
Show file tree
Hide file tree
Showing 34 changed files with 795 additions and 137 deletions.
6 changes: 6 additions & 0 deletions apiclient/types/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Agent struct {
AgentManifest
AliasAssigned *bool `json:"aliasAssigned,omitempty"`
AuthStatus map[string]OAuthAppLoginAuthStatus `json:"authStatus,omitempty"`
ToolInfo *map[string]ToolInfo `json:"toolInfo,omitempty"`
TextEmbeddingModel string `json:"textEmbeddingModel,omitempty"`
}

Expand Down Expand Up @@ -56,3 +57,8 @@ func (m AgentManifest) GetParams() *openapi3.Schema {

return gptscript.ObjectSchema(args...)
}

type ToolInfo struct {
CredentialNames []string `json:"credentialNames,omitempty"`
Authorized bool `json:"authorized"`
}
2 changes: 1 addition & 1 deletion apiclient/types/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type ToolReference struct {
Error string `json:"error,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Description string `json:"description,omitempty"`
Credential string `json:"credential,omitempty"`
Credentials []string `json:"credential,omitempty"`
Params map[string]string `json:"params,omitempty"`
}

Expand Down
1 change: 1 addition & 0 deletions apiclient/types/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type Workflow struct {
WorkflowManifest
AliasAssigned *bool `json:"aliasAssigned,omitempty"`
AuthStatus map[string]OAuthAppLoginAuthStatus `json:"authStatus,omitempty"`
ToolInfo *map[string]ToolInfo `json:"toolInfo,omitempty"`
TextEmbeddingModel string `json:"textEmbeddingModel,omitempty"`
}

Expand Down
47 changes: 47 additions & 0 deletions apiclient/types/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

151 changes: 136 additions & 15 deletions pkg/api/handlers/agent.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handlers

import (
"context"
"errors"
"fmt"
"net/http"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/alias"
"github.com/obot-platform/obot/pkg/api"
"github.com/obot-platform/obot/pkg/invoke"
"github.com/obot-platform/obot/pkg/render"
v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1"
"github.com/obot-platform/obot/pkg/system"
Expand All @@ -21,19 +23,97 @@ import (

type AgentHandler struct {
gptscript *gptscript.GPTScript
invoker *invoke.Invoker
serverURL string
// This is currently a hack to access the workflow handler
workflowHandler *WorkflowHandler
}

func NewAgentHandler(gClient *gptscript.GPTScript, serverURL string) *AgentHandler {
func NewAgentHandler(gClient *gptscript.GPTScript, invoker *invoke.Invoker, serverURL string) *AgentHandler {
return &AgentHandler{
serverURL: serverURL,
gptscript: gClient,
workflowHandler: NewWorkflowHandler(gClient, serverURL, nil),
invoker: invoker,
workflowHandler: NewWorkflowHandler(gClient, serverURL, invoker),
}
}

func (a *AgentHandler) Authenticate(req api.Context) (err error) {
var (
id = req.PathValue("id")
agent v1.Agent
tools []string
)

if err := req.Read(&tools); err != nil {
return fmt.Errorf("failed to read tools from request body: %w", err)
}

if len(tools) == 0 {
return types.NewErrBadRequest("no tools provided for authentication")
}

if err := req.Get(&agent, id); err != nil {
return err
}

resp, err := runAuthForAgent(req.Context(), req.Storage, a.invoker, agent.DeepCopy(), tools)
defer func() {
resp.Close()
if kickErr := kickAgent(req.Context(), req.Storage, &agent); kickErr != nil && err == nil {
err = fmt.Errorf("failed to update agent status: %w", kickErr)
}
}()

req.ResponseWriter.Header().Set("X-Otto-Thread-Id", resp.Thread.Name)
return req.WriteEvents(resp.Events)
}

func (a *AgentHandler) DeAuthenticate(req api.Context) error {
var (
id = req.PathValue("id")
agent v1.Agent
tools []string
)

if err := req.Read(&tools); err != nil {
return fmt.Errorf("failed to read tools from request body: %w", err)
}

if len(tools) == 0 {
return types.NewErrBadRequest("no tools provided for de-authentication")
}

if err := req.Get(&agent, id); err != nil {
return err
}

var (
errs []error
toolRef v1.ToolReference
)
for _, tool := range tools {
if err := req.Get(&toolRef, tool); err != nil {
errs = append(errs, err)
continue
}

if toolRef.Status.Tool != nil {
for _, cred := range toolRef.Status.Tool.CredentialNames {
if err := a.gptscript.DeleteCredential(req.Context(), id, cred); err != nil && !strings.HasSuffix(err.Error(), "credential not found") {
errs = append(errs, err)
}
}
}
}

if err := kickAgent(req.Context(), req.Storage, &agent); err != nil {
errs = append(errs, fmt.Errorf("failed to update agent status: %w", err))
}

return errors.Join(errs...)
}

func (a *AgentHandler) Update(req api.Context) error {
var (
id = req.PathValue("id")
Expand Down Expand Up @@ -140,16 +220,21 @@ func convertAgent(agent v1.Agent, textEmbeddingModel, baseURL string) (*types.Ag
links = []string{"invoke", baseURL + "/invoke/" + alias}
}

var aliasAssigned *bool
if agent.Generation == agent.Status.AliasObservedGeneration {
var (
aliasAssigned *bool
toolInfos *map[string]types.ToolInfo
)
if agent.Generation == agent.Status.ObservedGeneration {
aliasAssigned = &agent.Status.AliasAssigned
toolInfos = &agent.Status.ToolInfo
}

return &types.Agent{
Metadata: MetadataFrom(&agent, links...),
AgentManifest: agent.Spec.Manifest,
AliasAssigned: aliasAssigned,
AuthStatus: agent.Status.AuthStatus,
ToolInfo: toolInfos,
TextEmbeddingModel: textEmbeddingModel,
}, nil
}
Expand Down Expand Up @@ -218,6 +303,7 @@ func (a *AgentHandler) ByID(req api.Context) error {
if err != nil {
return err
}

return req.WriteCreated(resp)
}

Expand Down Expand Up @@ -658,21 +744,12 @@ func (a *AgentHandler) EnsureCredentialForKnowledgeSource(req api.Context) error
return req.WriteCreated(resp)
}

// if auth is already authenticated, then don't continue.
if authStatus.Authenticated {
resp, err := convertAgent(agent, knowledgeSet.Status.TextEmbeddingModel, req.APIBaseURL)
if err != nil {
return err
}
return req.WriteCreated(resp)
}

credentialTool, err := v1.CredentialTool(req.Context(), req.Storage, req.Namespace(), ref)
credentialTools, err := v1.CredentialTools(req.Context(), req.Storage, req.Namespace(), ref)
if err != nil {
return err
}

if credentialTool == "" {
if len(credentialTools) == 0 {
// The only way to get here is if the controller hasn't set the field yet.
if agent.Status.AuthStatus == nil {
agent.Status.AuthStatus = make(map[string]types.OAuthAppLoginAuthStatus)
Expand Down Expand Up @@ -770,3 +847,47 @@ func MetadataFrom(obj kclient.Object, linkKV ...string) types.Metadata {
}
return m
}

func runAuthForAgent(ctx context.Context, c kclient.WithWatch, invoker *invoke.Invoker, agent *v1.Agent, tools []string) (*invoke.Response, error) {
credentials := make([]string, 0, len(tools))

var toolRef v1.ToolReference
for _, tool := range tools {
if err := c.Get(ctx, kclient.ObjectKey{Namespace: agent.Namespace, Name: tool}, &toolRef); err != nil {
return nil, err
}

if toolRef.Status.Tool == nil {
return nil, types.NewErrHttp(http.StatusTooEarly, fmt.Sprintf("tool %q is not ready", tool))
}

credentials = append(credentials, toolRef.Status.Tool.Credentials...)

// Reset the fields we care about so that we can use the same variable for the whole loop.
toolRef.Status.Tool = nil
}

agent.Spec.Manifest.Prompt = "#!sys.echo\nDONE"
agent.Spec.Manifest.Tools = tools
agent.Spec.Manifest.AvailableThreadTools = nil
agent.Spec.Manifest.DefaultThreadTools = nil
agent.Spec.Credentials = credentials

return invoker.Agent(ctx, c, agent, "", invoke.Options{
Synchronous: true,
ThreadCredentialScope: new(bool),
})
}

func kickAgent(ctx context.Context, c kclient.Client, agent *v1.Agent) error {
if agent.Annotations[v1.AgentSyncAnnotation] != "" {
delete(agent.Annotations, v1.AgentSyncAnnotation)
} else {
if agent.Annotations == nil {
agent.Annotations = make(map[string]string)
}
agent.Annotations[v1.AgentSyncAnnotation] = "true"
}

return c.Update(ctx, agent)
}
2 changes: 1 addition & 1 deletion pkg/api/handlers/emailreceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func convertEmailReceiver(emailReceiver v1.EmailReceiver, hostname string) *type
manifest := emailReceiver.Spec.EmailReceiverManifest

var aliasAssigned *bool
if emailReceiver.Generation == emailReceiver.Status.AliasObservedGeneration {
if emailReceiver.Generation == emailReceiver.Status.ObservedGeneration {
aliasAssigned = &emailReceiver.Status.AliasAssigned
}
er := &types.EmailReceiver{
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/handlers/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func convertModel(ctx context.Context, c kclient.Client, model v1.Model) (types.
}

var aliasAssigned *bool
if model.Generation == model.Status.AliasObservedGeneration {
if model.Generation == model.Status.ObservedGeneration {
aliasAssigned = &model.Status.AliasAssigned
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/api/handlers/toolreferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func convertToolReference(toolRef v1.ToolReference) types.ToolReference {
tf.Name = toolRef.Status.Tool.Name
tf.Description = toolRef.Status.Tool.Description
tf.Metadata.Metadata = toolRef.Status.Tool.Metadata
tf.Credential = toolRef.Status.Tool.Credential
tf.Credentials = toolRef.Status.Tool.Credentials
}

return tf
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/handlers/webhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func convertWebhook(webhook v1.Webhook, urlPrefix string) *types.Webhook {
}

var aliasAssigned *bool
if webhook.Generation == webhook.Status.AliasObservedGeneration {
if webhook.Generation == webhook.Status.ObservedGeneration {
aliasAssigned = &webhook.Status.AliasAssigned
}

Expand Down
Loading

0 comments on commit 5de7695

Please sign in to comment.