Skip to content

Commit

Permalink
feat: add ability to add/remove user tools from threads
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanhopperlowe committed Dec 17, 2024
1 parent 485fcbe commit b7f661e
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 44 deletions.
4 changes: 3 additions & 1 deletion pkg/api/handlers/threads.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ func (a *ThreadHandler) Update(req api.Context) error {
return err
}
for _, newTool := range newThread.Tools {
if !slices.Contains(agent.Spec.Manifest.AvailableThreadTools, newTool) {
possibleTools := slices.Concat(agent.Spec.Manifest.AvailableThreadTools, agent.Spec.Manifest.DefaultThreadTools)

if !slices.Contains(possibleTools, newTool) {
return types.NewErrBadRequest("tool %s is not available for agent %s", newTool, agent.Name)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import { LibraryIcon, WrenchIcon } from "lucide-react";
import { useMemo } from "react";
import useSWR from "swr";

import { Agent } from "~/lib/model/agents";
import { KnowledgeFile } from "~/lib/model/knowledge";
import { AgentService } from "~/lib/service/api/agentService";
import { ThreadsService } from "~/lib/service/api/threadsService";
import { cn } from "~/lib/utils";

import { TypographyMuted, TypographySmall } from "~/components/Typography";
import { ToolEntry } from "~/components/agent/ToolEntry";
import { useChat } from "~/components/chat/ChatContext";
import {
useOptimisticThread,
useThreadAgents as useThreadAgent,
useThreadKnowledge,
} from "~/components/chat/thread-helpers";
import { Button } from "~/components/ui/button";
import {
Popover,
Expand All @@ -19,23 +21,12 @@ import {
} from "~/components/ui/popover";
import { Switch } from "~/components/ui/switch";

export function ChatHelpers({ className }: { className?: string }) {
export function ChatActions({ className }: { className?: string }) {
const { threadId } = useChat();

const { data: thread } = useSWR(
ThreadsService.getThreadById.key(threadId),
({ threadId }) => ThreadsService.getThreadById(threadId)
);

const { data: knowledge } = useSWR(
ThreadsService.getKnowledge.key(threadId),
({ threadId }) => ThreadsService.getKnowledge(threadId)
);

const { data: agent } = useSWR(
AgentService.getAgentById.key(thread?.agentID),
({ agentId }) => AgentService.getAgentById(agentId)
);
const { data: knowledge } = useThreadKnowledge(threadId);
const { data: agent } = useThreadAgent(threadId);
const { thread, updateThread } = useOptimisticThread(threadId);

const tools = thread?.tools;

Expand All @@ -44,6 +35,7 @@ export function ChatHelpers({ className }: { className?: string }) {
<div className="flex items-center gap-2">
<ToolsInfo
tools={tools ?? []}
onChange={(tools) => updateThread({ tools })}
agent={agent}
disabled={!thread}
/>
Expand All @@ -54,18 +46,26 @@ export function ChatHelpers({ className }: { className?: string }) {
);
}

type ToolItem = {
tool: string;
isToggleable: boolean;
isEnabled: boolean;
};

function ToolsInfo({
tools,
className,
agent,
disabled,
onChange,
}: {
tools: string[];
className?: string;
agent: Nullish<Agent>;
disabled?: boolean;
onChange: (tools: string[]) => void;
}) {
const toolItems = useMemo(() => {
const toolItems = useMemo<ToolItem[]>(() => {
if (!agent)
return tools.map((tool) => ({
tool,
Expand Down Expand Up @@ -93,6 +93,11 @@ function ToolsInfo({
return [...agentTools, ...toggleableTools];
}, [tools, agent]);

const handleToggleTool = (tool: string, checked: boolean) => {
console.log("toggle tool", tool, checked);
onChange(checked ? [...tools, tool] : tools.filter((t) => t !== tool));
};

return (
<Popover>
<PopoverTrigger asChild>
Expand All @@ -114,27 +119,7 @@ function ToolsInfo({
Available Tools
</TypographySmall>
<div className="space-y-1">
{toolItems.map(
({ tool, isToggleable, isEnabled }) => (
<ToolEntry
key={tool}
tool={tool}
actions={
isToggleable ? (
<Switch
checked={isEnabled}
disabled
onCheckedChange={() => {}}
/>
) : (
<TypographyMuted>
On
</TypographyMuted>
)
}
/>
)
)}
{toolItems.map(renderToolItem)}
</div>
</div>
) : (
Expand All @@ -143,6 +128,27 @@ function ToolsInfo({
</PopoverContent>
</Popover>
);

function renderToolItem({ isEnabled, isToggleable, tool }: ToolItem) {
return (
<ToolEntry
key={tool}
tool={tool}
actions={
isToggleable ? (
<Switch
checked={isEnabled}
onCheckedChange={(checked) =>
handleToggleTool(tool, checked)
}
/>
) : (
<TypographyMuted>On</TypographyMuted>
)
}
/>
);
}
}

function KnowledgeInfo({
Expand Down
4 changes: 2 additions & 2 deletions ui/admin/app/components/chat/Chatbar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import { useState } from "react";

import { cn } from "~/lib/utils";

import { ChatActions } from "~/components/chat/ChatActions";
import { useChat } from "~/components/chat/ChatContext";
import { ChatHelpers } from "~/components/chat/ChatHelpers";
import { LoadingSpinner } from "~/components/ui/LoadingSpinner";
import { Button } from "~/components/ui/button";
import { AutosizeTextarea } from "~/components/ui/textarea";
Expand Down Expand Up @@ -64,7 +64,7 @@ export function Chatbar({ className }: ChatbarProps) {
)}
</Button>

<ChatHelpers className="p-2" />
<ChatActions className="p-2" />
</div>
}
/>
Expand Down
58 changes: 58 additions & 0 deletions ui/admin/app/components/chat/thread-helpers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { toast } from "sonner";
import useSWR from "swr";

import { UpdateThread } from "~/lib/model/threads";
import { AgentService } from "~/lib/service/api/agentService";
import { ThreadsService } from "~/lib/service/api/threadsService";

import { useAsync } from "~/hooks/useAsync";

function useThread(threadId?: Nullish<string>) {
return useSWR(ThreadsService.getThreadById.key(threadId), ({ threadId }) =>
ThreadsService.getThreadById(threadId)
);
}

export function useOptimisticThread(threadId?: Nullish<string>) {
const { data: thread, mutate } = useThread(threadId);

const handleUpdateThread = useAsync(ThreadsService.updateThreadById);

const updateThread = async (updates: Partial<UpdateThread>) => {
if (!thread) return;

const updatedThread = { ...thread, ...updates };

// optimistic update
mutate((thread) => (thread ? updatedThread : thread), false);

const { error, data } = await handleUpdateThread.executeAsync(
thread.id,
updatedThread
);

if (data) return;

if (error instanceof Error) toast.error(error.message);

// revert optimistic update
mutate(thread);
};

return { thread, updateThread };
}

export function useThreadKnowledge(threadId?: Nullish<string>) {
return useSWR(ThreadsService.getKnowledge.key(threadId), ({ threadId }) =>
ThreadsService.getKnowledge(threadId)
);
}

export function useThreadAgents(threadId?: Nullish<string>) {
const { data: thread } = useThread(threadId);

return useSWR(
AgentService.getAgentById.key(thread?.agentID),
({ agentId }) => AgentService.getAgentById(agentId)
);
}
2 changes: 2 additions & 0 deletions ui/admin/app/lib/model/threads.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ export type Thread = EntityMeta &
| { agentID: string; workflowID?: never }
| { agentID?: never; workflowID: string }
);

export type UpdateThread = ThreadBase;
1 change: 1 addition & 0 deletions ui/admin/app/lib/routers/apiRoutes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ export const ApiRoutes = {
threads: {
base: () => buildUrl("/threads"),
getById: (threadId: string) => buildUrl(`/threads/${threadId}`),
updateById: (threadId: string) => buildUrl(`/threads/${threadId}`),
getByAgent: (agentId: string) => buildUrl(`/agents/${agentId}/threads`),
events: (
threadId: string,
Expand Down
14 changes: 13 additions & 1 deletion ui/admin/app/lib/service/api/threadsService.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ChatEvent } from "~/lib/model/chatEvents";
import { KnowledgeFile } from "~/lib/model/knowledge";
import { Thread } from "~/lib/model/threads";
import { Thread, UpdateThread } from "~/lib/model/threads";
import { WorkspaceFile } from "~/lib/model/workspace";
import { ApiRoutes, revalidateWhere } from "~/lib/routers/apiRoutes";
import { request } from "~/lib/service/api/primitives";
Expand Down Expand Up @@ -29,6 +29,17 @@ getThreadById.key = (threadId?: Nullish<string>) => {
return { url: ApiRoutes.threads.getById(threadId).path, threadId };
};

const updateThreadById = async (threadId: string, thread: UpdateThread) => {
const { data } = await request<Thread>({
url: ApiRoutes.threads.updateById(threadId).url,
method: "PUT",
data: thread,
errorMessage: "Failed to update thread",
});

return data;
};

const getThreadsByAgent = async (agentId: string) => {
const res = await request<{ items: Thread[] }>({
url: ApiRoutes.threads.getByAgent(agentId).url,
Expand Down Expand Up @@ -125,6 +136,7 @@ export const ThreadsService = {
getThreadsByAgent,
getThreadEvents,
getThreadEventSource,
updateThreadById,
deleteThread,
revalidateThreads,
getKnowledge,
Expand Down

0 comments on commit b7f661e

Please sign in to comment.