Skip to content

Commit

Permalink
chore: add more custom tools in user ui
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Jan 7, 2025
1 parent 51275aa commit 5179ae4
Show file tree
Hide file tree
Showing 27 changed files with 857 additions and 397 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ ENV PGDATA=/data/postgresql
COPY --from=build-pgvector /usr/lib/postgresql17/vector.so /usr/lib/postgresql17/
COPY --from=build-pgvector /usr/share/postgresql17/extension/vector* /usr/share/postgresql17/extension/

RUN apk add --no-cache git python-3.13 py3.13-pip openssh-server npm bash tini procps libreoffice
RUN apk add --no-cache git python-3.13 py3.13-pip openssh-server npm bash tini procps libreoffice docker
COPY --chmod=0755 /tools/package-chrome.sh /
RUN /package-chrome.sh && rm /package-chrome.sh
RUN sed -E 's/^#(PermitRootLogin)no/\1yes/' /etc/ssh/sshd_config -i
Expand Down
1 change: 1 addition & 0 deletions apiclient/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type ToolManifest struct {
Description string `json:"description,omitempty"`
Icon string `json:"icon,omitempty"`
ToolType ToolType `json:"toolType,omitempty"`
Image string `json:"image,omitempty"`
Context string `json:"context,omitempty"`
Instructions string `json:"instructions,omitempty"`
Params map[string]string `json:"params,omitempty"`
Expand Down
46 changes: 46 additions & 0 deletions pkg/api/handlers/assistants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"maps"
"net/http"
"regexp"
"slices"
Expand Down Expand Up @@ -310,6 +311,51 @@ func (a *AssistantHandler) DeleteFile(req api.Context) error {
return deleteFileFromWorkspaceID(req.Context(), req, a.gptScript, thread.Status.WorkspaceID, "files/")
}

func (a *AssistantHandler) SetEnv(req api.Context) error {
var (
id = req.PathValue("id")
)

thread, err := getUserThread(req, id)
if err != nil {
return err
}

var envs map[string]string
if err := req.Read(&envs); err != nil {
return err
}

if err := setEnvMap(req, a.gptScript, thread.Name, thread.Name, envs); err != nil {
return err
}

thread.Spec.Env = slices.Collect(maps.Keys(envs))
if err := req.Update(thread); err != nil {
return err
}

return req.Write(envs)
}

func (a *AssistantHandler) GetEnv(req api.Context) error {
var (
id = req.PathValue("id")
)

thread, err := getUserThread(req, id)
if err != nil {
return err
}

data, err := getEnvMap(req, a.gptScript, thread.Name, thread.Name)
if err != nil {
return err
}

return req.Write(data)
}

func (a *AssistantHandler) Knowledge(req api.Context) error {
var (
id = req.PathValue("id")
Expand Down
127 changes: 100 additions & 27 deletions pkg/api/handlers/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,49 @@ package handlers

import (
"errors"
"maps"
"regexp"
"slices"
"strings"

"github.com/gptscript-ai/go-gptscript"
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/api"
"github.com/obot-platform/obot/pkg/invoke"
"github.com/obot-platform/obot/pkg/render"
v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1"
"github.com/obot-platform/obot/pkg/system"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

type ToolHandler struct {
gptScript *gptscript.GPTScript
invoke *invoke.Invoker
}

func NewToolHandler(gptScript *gptscript.GPTScript) *ToolHandler {
return &ToolHandler{gptScript: gptScript}
func NewToolHandler(gptScript *gptscript.GPTScript, invoke *invoke.Invoker) *ToolHandler {
return &ToolHandler{
gptScript: gptScript,
invoke: invoke,
}
}

var invalidEnv = regexp.MustCompile("^(OBOT|GPTSCRIPT)")

func setEnvMap(req api.Context, gptScript *gptscript.GPTScript, threadName, toolName string, env map[string]string) error {
for k := range env {
if invalidEnv.MatchString(k) {
return types.NewErrBadRequest("invalid env key %s", k)
}
}

return gptScript.CreateCredential(req.Context(), gptscript.Credential{
Context: threadName,
ToolName: toolName,
Type: gptscript.CredentialTypeTool,
Env: env,
})
}

func (t *ToolHandler) SetEnv(req api.Context) error {
toolID := req.PathValue("tool_id")
env := map[string]string{}
Expand All @@ -46,29 +67,11 @@ func (t *ToolHandler) SetEnv(req api.Context) error {
return types.NewErrNotFound("tool %s not found", toolID)
}

for k := range env {
if invalidEnv.MatchString(k) {
return types.NewErrBadRequest("invalid env key %s", k)
}
}

err = t.gptScript.CreateCredential(req.Context(), gptscript.Credential{
Context: thread.Name,
ToolName: tool.Name,
Type: gptscript.CredentialTypeTool,
Env: env,
})
if err != nil {
if err := setEnvMap(req, t.gptScript, thread.Name, tool.Name, env); err != nil {
return err
}

var envs []string
for k, v := range env {
if strings.TrimSpace(v) != "" {
envs = append(envs, k)
}
}
tool.Spec.Envs = envs
tool.Spec.Envs = slices.Collect(maps.Keys(env))
if err := req.Update(&tool); err != nil {
return err
}
Expand All @@ -93,14 +96,23 @@ func (t *ToolHandler) GetEnv(req api.Context) error {
return types.NewErrNotFound("tool %s not found", toolID)
}

cred, err := t.gptScript.RevealCredential(req.Context(), []string{thread.Name}, tool.Name)
data, err := getEnvMap(req, t.gptScript, thread.Name, tool.Name)
if err != nil {
return err
}

return req.Write(data)
}

func getEnvMap(req api.Context, gptScript *gptscript.GPTScript, threadName, toolName string) (map[string]string, error) {
cred, err := gptScript.RevealCredential(req.Context(), []string{threadName}, toolName)
if errors.As(err, &gptscript.ErrNotFound{}) {
return req.Write(map[string]string{})
return map[string]string{}, nil
} else if err != nil {
return err
return nil, err
}

return req.Write(cred.Env)
return cred.Env, nil
}

func (t *ToolHandler) Get(req api.Context) error {
Expand All @@ -123,6 +135,67 @@ func (t *ToolHandler) Get(req api.Context) error {
return req.Write(convertTool(tool, slices.Contains(thread.Spec.Manifest.Tools, tool.Name)))
}

type TestInput struct {
Input map[string]string `json:"input"`
Tool *types.AssistantTool `json:"tool"`
Env map[string]string `json:"env,omitempty"`
}

func (t *ToolHandler) Test(req api.Context) error {
toolID := req.PathValue("tool_id")

thread, err := getThreadForScope(req)
if err != nil {
return err
}

var tool v1.Tool
if err := req.Get(&tool, toolID); err != nil {
return err
}

if tool.Spec.ThreadName != thread.Name {
return types.NewErrNotFound("tool %s not found", toolID)
}

env, err := getEnvMap(req, t.gptScript, thread.Name, tool.Name)
if err != nil {
return err
}

var envList []string
for k, v := range env {
envList = append(envList, k+"="+v)
}

var input TestInput
if err := req.Read(&input); err != nil {
return err
}

for k, v := range input.Env {
envList = append(envList, k+"="+v)
}

if input.Tool != nil {
tool.Spec.Manifest = input.Tool.ToolManifest
}

tools, err := render.CustomTool(req.Context(), req.Storage, tool)
if err != nil {
return err
}

result, err := t.invoke.EphemeralThreadTask(req.Context(), thread, tools, input.Input, invoke.SystemTaskOptions{
Env: envList,
})
if err != nil {
return err
}

return req.Write(map[string]string{"output": result})
}

func (t *ToolHandler) Create(req api.Context) error {
thread, err := getThreadForScope(req)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion pkg/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func Router(services *services.Services) (http.Handler, error) {

agents := handlers.NewAgentHandler(services.GPTClient, services.Invoker, services.ServerURL)
assistants := handlers.NewAssistantHandler(services.Invoker, services.Events, services.GPTClient)
tools := handlers.NewToolHandler(services.GPTClient)
tools := handlers.NewToolHandler(services.GPTClient, services.Invoker)
tasks := handlers.NewTaskHandler(services.Invoker, services.Events)
workflows := handlers.NewWorkflowHandler(services.GPTClient, services.ServerURL, services.Invoker)
invoker := handlers.NewInvokeHandler(services.Invoker)
Expand Down Expand Up @@ -72,6 +72,9 @@ func Router(services *services.Services) (http.Handler, error) {
mux.HandleFunc("GET /api/assistants/{id}/knowledge", assistants.Knowledge)
mux.HandleFunc("POST /api/assistants/{id}/knowledge/{file}", assistants.UploadKnowledge)
mux.HandleFunc("DELETE /api/assistants/{id}/knowledge/{file...}", assistants.DeleteKnowledge)
// Env
mux.HandleFunc("GET /api/assistants/{id}/env", assistants.GetEnv)
mux.HandleFunc("PUT /api/assistants/{id}/env", assistants.SetEnv)

if services.SupportDocker {
shell, err := handlers.NewShellHandler(services.Invoker)
Expand All @@ -83,6 +86,7 @@ func Router(services *services.Services) (http.Handler, error) {
// Tools
mux.HandleFunc("POST /api/assistants/{assistant_id}/tools", tools.Create)
mux.HandleFunc("PUT /api/assistants/{assistant_id}/tools/{tool_id}/env", tools.SetEnv)
mux.HandleFunc("POST /api/assistants/{assistant_id}/tools/{tool_id}/test", tools.Test)
}
mux.HandleFunc("GET /api/assistants/{assistant_id}/tools/{tool_id}", tools.Get)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tools/{tool_id}/env", tools.GetEnv)
Expand Down
13 changes: 4 additions & 9 deletions pkg/invoke/invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"maps"
"slices"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -251,7 +250,7 @@ func CreateThreadForAgent(ctx context.Context, c kclient.WithWatch, agent *v1.Ag
return &thread, c.Create(ctx, &thread)
}

func (i *Invoker) updateThreadFields(ctx context.Context, c kclient.WithWatch, agent *v1.Agent, thread *v1.Thread, extraEnv []string, opt Options) error {
func (i *Invoker) updateThreadFields(ctx context.Context, c kclient.WithWatch, agent *v1.Agent, thread *v1.Thread, opt Options) error {
var updated bool
if opt.AgentAlias != "" && thread.Spec.AgentAlias != opt.AgentAlias {
thread.Spec.AgentAlias = opt.AgentAlias
Expand All @@ -261,10 +260,6 @@ func (i *Invoker) updateThreadFields(ctx context.Context, c kclient.WithWatch, a
thread.Spec.AgentName = agent.Name
updated = true
}
if !slices.Equal(thread.Spec.Env, extraEnv) {
thread.Spec.Env = extraEnv
updated = true
}
if updated {
return c.Status().Update(ctx, thread)
}
Expand Down Expand Up @@ -302,7 +297,7 @@ func (i *Invoker) Agent(ctx context.Context, c kclient.WithWatch, agent *v1.Agen
return nil, err
}

if err := i.updateThreadFields(ctx, c, agent, thread, extraEnv, opt); err != nil {
if err := i.updateThreadFields(ctx, c, agent, thread, opt); err != nil {
return nil, err
}

Expand Down Expand Up @@ -599,7 +594,7 @@ func (i *Invoker) Resume(ctx context.Context, c kclient.WithWatch, thread *v1.Th
func (i *Invoker) saveState(ctx context.Context, c kclient.Client, prevThreadName string, thread *v1.Thread, run *v1.Run, runResp *gptscript.Run, retErr error) error {
if isEphemeral(run) {
// Ephemeral run, don't save state
return nil
return retErr
}

var err error
Expand Down Expand Up @@ -886,7 +881,7 @@ func (i *Invoker) stream(ctx context.Context, c kclient.WithWatch, prevThreadNam
timeout = run.Spec.Timeout.Duration
}
go timeoutAfter(runCtx, cancelRun, timeout)
if run.Name != "" {
if !isEphemeral(run) {
// Don't watch thread abort for ephemeral runs
go watchThreadAbort(runCtx, c, thread, cancelRun)
}
Expand Down
19 changes: 17 additions & 2 deletions pkg/render/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/gz"
v1 "github.com/obot-platform/obot/pkg/storage/apis/otto.otto8.ai/v1"
"github.com/obot-platform/obot/pkg/system"
apierror "k8s.io/apimachinery/pkg/api/errors"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)
Expand Down Expand Up @@ -81,9 +82,21 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU
if err != nil {
return nil, nil, err
}
mainTool.Tools = append(mainTool.Tools, name)
if name != "" {
mainTool.Tools = append(mainTool.Tools, name)
}
otherTools = append(otherTools, tools...)
}

credTool, err := ResolveToolReference(ctx, db, types.ToolReferenceTypeSystem, opts.Thread.Namespace, system.ExistingCredTool)
if err != nil {
return nil, nil, err
}

mainTool.Credentials = append(mainTool.Credentials, credTool+" as "+opts.Thread.Name)
if len(opts.Thread.Spec.Env) > 0 {
extraEnv = append(extraEnv, fmt.Sprintf("OBOT_THREAD_ENVS=%s", strings.Join(opts.Thread.Spec.Env, ",")))
}
}

for _, tool := range agent.Spec.Manifest.Tools {
Expand All @@ -94,7 +107,9 @@ func Agent(ctx context.Context, db kclient.Client, agent *v1.Agent, oauthServerU
if err != nil {
return nil, nil, err
}
mainTool.Tools = append(mainTool.Tools, name)
if name != "" {
mainTool.Tools = append(mainTool.Tools, name)
}
otherTools = append(otherTools, tools...)
}

Expand Down
Loading

0 comments on commit 5179ae4

Please sign in to comment.