diff --git a/ui/admin/app/components/composed/DataTable.tsx b/ui/admin/app/components/composed/DataTable.tsx index 62e00ce3..f3d5c73a 100644 --- a/ui/admin/app/components/composed/DataTable.tsx +++ b/ui/admin/app/components/composed/DataTable.tsx @@ -83,28 +83,7 @@ export function DataTable({ rowClassName?.(row.original) )} > - {row.getVisibleCells().map((cell) => ( - { - if ( - !disableClickPropagation?.(cell) - ) { - onRowClick?.(row.original); - } - }} - > - {flexRender( - cell.column.columnDef.cell, - cell.getContext() - )} - - ))} + {row.getVisibleCells().map(renderCell)} )) ) : ( @@ -124,4 +103,22 @@ export function DataTable({ ); + + function renderCell(cell: Cell) { + return ( + { + if (!disableClickPropagation?.(cell)) { + onRowClick?.(cell.row.original); + } + }} + > + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ); + } } diff --git a/ui/admin/app/components/form/controlledInputs.tsx b/ui/admin/app/components/form/controlledInputs.tsx index f39e3456..8364a299 100644 --- a/ui/admin/app/components/form/controlledInputs.tsx +++ b/ui/admin/app/components/form/controlledInputs.tsx @@ -1,14 +1,16 @@ -import { ReactNode } from "react"; +import { ComponentProps, ReactNode } from "react"; import { Control, ControllerFieldState, ControllerRenderProps, FieldPath, FieldValues, + FormState, } from "react-hook-form"; import { cn } from "~/lib/utils"; +import { Checkbox } from "~/components/ui/checkbox"; import { FormControl, FormDescription, @@ -180,6 +182,107 @@ export function ControlledAutosizeTextarea< ); } +export type ControlledCheckboxProps< + TValues extends FieldValues, + TName extends FieldPath, +> = BaseProps & ComponentProps; + +export function ControlledCheckbox< + TValues extends FieldValues, + TName extends FieldPath, +>({ + control, + name, + label, + description, + onCheckedChange, + ...checkboxProps +}: ControlledCheckboxProps) { + return ( + ( + +
+ + { + field.onChange(value); + onCheckedChange?.(value); + }} + className={cn( + getFieldStateClasses(fieldState), + checkboxProps.className + )} + /> + + + {label && {label}} +
+ + + + {description && ( + {description} + )} +
+ )} + /> + ); +} + +export type ControlledCustomInputProps< + TValues extends FieldValues, + TName extends FieldPath, +> = BaseProps & { + children: (props: { + field: ControllerRenderProps; + fieldState: ControllerFieldState; + formState: FormState; + className?: string; + }) => ReactNode; +}; + +export function ControlledCustomInput< + TValues extends FieldValues, + TName extends FieldPath, +>({ + control, + name, + label, + description, + children, +}: ControlledCustomInputProps) { + return ( + ( + + {label && {label}} + + + {children({ + ...args, + className: getFieldStateClasses(args.fieldState), + })} + + + + + {description && ( + {description} + )} + + )} + /> + ); +} + function getFieldStateClasses(fieldState: ControllerFieldState) { return cn({ "focus-visible:ring-destructive border-destructive": fieldState.invalid, diff --git a/ui/admin/app/components/model/CreateModel.tsx b/ui/admin/app/components/model/CreateModel.tsx new file mode 100644 index 00000000..13980b48 --- /dev/null +++ b/ui/admin/app/components/model/CreateModel.tsx @@ -0,0 +1,33 @@ +import { PlusIcon } from "lucide-react"; +import { useState } from "react"; + +import { ModelForm } from "~/components/model/ModelForm"; +import { Button } from "~/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogTitle, + DialogTrigger, +} from "~/components/ui/dialog"; + +export function CreateModel() { + const [open, setOpen] = useState(false); + + return ( + + + + + + + Create Model + + + setOpen(false)} /> + + + ); +} diff --git a/ui/admin/app/components/model/DeleteModel.tsx b/ui/admin/app/components/model/DeleteModel.tsx new file mode 100644 index 00000000..27425abc --- /dev/null +++ b/ui/admin/app/components/model/DeleteModel.tsx @@ -0,0 +1,54 @@ +import { TrashIcon } from "lucide-react"; +import { mutate } from "swr"; + +import { ModelApiService } from "~/lib/service/api/modelApiService"; + +import { ConfirmationDialog } from "~/components/composed/ConfirmationDialog"; +import { Button } from "~/components/ui/button"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "~/components/ui/tooltip"; +import { useAsync } from "~/hooks/useAsync"; + +type DeleteModelProps = { + id: string; +}; + +export function DeleteModel(props: DeleteModelProps) { + const deleteModel = useAsync(ModelApiService.deleteModel, { + onSuccess: () => mutate(ModelApiService.getModels.key()), + }); + + return ( + + + deleteModel.execute(props.id)} + confirmProps={{ + variant: "destructive", + children: "Delete", + }} + > + + + + + + Delete Model + + + ); +} diff --git a/ui/admin/app/components/model/ModelForm.tsx b/ui/admin/app/components/model/ModelForm.tsx new file mode 100644 index 00000000..1f9c135f --- /dev/null +++ b/ui/admin/app/components/model/ModelForm.tsx @@ -0,0 +1,161 @@ +import { zodResolver } from "@hookform/resolvers/zod"; +import { useMemo } from "react"; +import { useForm } from "react-hook-form"; +import { toast } from "sonner"; +import useSWR, { mutate } from "swr"; +import { z } from "zod"; + +import { Model, ModelManifest, ModelManifestSchema } from "~/lib/model/models"; +import { BadRequestError } from "~/lib/service/api/apiErrors"; +import { ModelApiService } from "~/lib/service/api/modelApiService"; + +import { + ControlledCheckbox, + ControlledCustomInput, + ControlledInput, +} from "~/components/form/controlledInputs"; +import { Button } from "~/components/ui/button"; +import { Form } from "~/components/ui/form"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "~/components/ui/select"; +import { useAsync } from "~/hooks/useAsync"; + +type ModelFormProps = { + model?: Model; + onSubmit: (model: ModelManifest) => void; +}; + +type FormValues = z.infer; + +export function ModelForm(props: ModelFormProps) { + const { model, onSubmit } = props; + + const { data: modelProviders } = useSWR( + ModelApiService.getModelProviders.key(true), + ({ onlyConfigured }) => + ModelApiService.getModelProviders(onlyConfigured) + ); + + const updateModel = useAsync(ModelApiService.updateModel, { + onSuccess: (values) => { + toast.success("Model updated"); + mutate(ModelApiService.getModels.key()); + onSubmit(values); + }, + onError, + }); + + const createModel = useAsync(ModelApiService.createModel, { + onSuccess: (values) => { + toast.success("Model created"); + mutate(ModelApiService.getModels.key()); + onSubmit(values); + }, + onError, + }); + + const defaultValues = useMemo(() => { + return { + name: model?.name ?? "", + targetModel: model?.targetModel ?? "", + modelProvider: model?.modelProvider ?? "", + active: model?.active ?? true, + default: model?.default ?? false, + }; + }, [model]); + + const form = useForm({ + resolver: zodResolver(ModelManifestSchema), + defaultValues, + }); + + const { loading, submit } = getSubmitInfo(); + + const handleSubmit = form.handleSubmit(submit); + + return ( +
+ + + + + {({ field: { ref: _, ...field }, className }) => ( + + )} + + + + + + + + + + ); + + function getSubmitInfo() { + if (model) { + return { + isEdit: true, + loading: updateModel.isLoading, + submit: (values: FormValues) => + updateModel.execute(model.id, values), + }; + } + + return { + isEdit: false, + loading: createModel.isLoading, + submit: (values: FormValues) => createModel.execute(values), + }; + } + + function onError(error: unknown) { + if (error instanceof BadRequestError) + form.setError("default", { message: error.message }); + } +} diff --git a/ui/admin/app/components/model/UpdateModel.tsx b/ui/admin/app/components/model/UpdateModel.tsx new file mode 100644 index 00000000..6f0eea12 --- /dev/null +++ b/ui/admin/app/components/model/UpdateModel.tsx @@ -0,0 +1,37 @@ +import { Model } from "~/lib/model/models"; + +import { ModelForm } from "~/components/model/ModelForm"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogTitle, + DialogTrigger, +} from "~/components/ui/dialog"; + +type UpdateModelDialogProps = { + model: Nullish; + open: boolean; + setOpen: (open: boolean) => void; + children?: React.ReactNode; +}; + +export function UpdateModelDialog(props: UpdateModelDialogProps) { + const { model, open, setOpen, children } = props; + + if (!model) return null; + + return ( + + + Update model + + + + setOpen(false)} /> + + + {children && {children}} + + ); +} diff --git a/ui/admin/app/components/sidebar/Sidebar.tsx b/ui/admin/app/components/sidebar/Sidebar.tsx index 78024a9a..131049f6 100644 --- a/ui/admin/app/components/sidebar/Sidebar.tsx +++ b/ui/admin/app/components/sidebar/Sidebar.tsx @@ -1,6 +1,7 @@ import { Link } from "@remix-run/react"; import { BotIcon, + BrainIcon, KeyIcon, MessageSquare, SettingsIcon, @@ -59,6 +60,11 @@ const items = [ url: $path("/oauth-apps"), icon: KeyIcon, }, + { + title: "Models", + url: $path("/models"), + icon: BrainIcon, + }, ]; export function AppSidebar() { diff --git a/ui/admin/app/components/ui/button.tsx b/ui/admin/app/components/ui/button.tsx index 1d6cd070..79dcf3e8 100644 --- a/ui/admin/app/components/ui/button.tsx +++ b/ui/admin/app/components/ui/button.tsx @@ -40,6 +40,7 @@ export interface ButtonProps VariantProps { asChild?: boolean; loading?: boolean; + startContent?: React.ReactNode; } const Button = React.forwardRef( @@ -50,6 +51,7 @@ const Button = React.forwardRef( size, asChild = false, loading = false, + startContent, children, ...props }, @@ -77,7 +79,10 @@ const Button = React.forwardRef( {children} ) : ( - children +
+ {startContent} + {children} +
); } } diff --git a/ui/admin/app/components/ui/dialog.tsx b/ui/admin/app/components/ui/dialog.tsx index 9f5459c3..f3dd578c 100644 --- a/ui/admin/app/components/ui/dialog.tsx +++ b/ui/admin/app/components/ui/dialog.tsx @@ -48,7 +48,10 @@ const DialogContent = React.forwardRef< return ( - + e.stopPropagation()} + /> = { onSuccess?: (data: TData, params: TParams) => void; onError?: (error: unknown, params: TParams) => void; onSettled?: ({ params }: { params: TParams }) => void; + shouldThrow?: (error: unknown) => boolean; }; -const defaultShouldThrow = (error: unknown) => !(error instanceof AxiosError); +const defaultShouldThrow = (error: unknown) => error instanceof BoundaryError; export function useAsync( callback: (...params: TParams) => Promise, config?: Config ) { - const { onSuccess, onError, onSettled } = config || {}; + const { + onSuccess, + onError, + onSettled, + shouldThrow = defaultShouldThrow, + } = config || {}; const [data, setData] = useState(null); const [error, setError] = useState(null); const [isLoading, setIsLoading] = useState(false); const [lastCallParams, setLastCallParams] = useState(null); - if (error && defaultShouldThrow(error)) throw error; + if (error && shouldThrow(error)) throw error; const executeAsync = useCallback( async (...params: TParams) => { diff --git a/ui/admin/app/lib/model/models.ts b/ui/admin/app/lib/model/models.ts new file mode 100644 index 00000000..f4b96fc4 --- /dev/null +++ b/ui/admin/app/lib/model/models.ts @@ -0,0 +1,36 @@ +import { z } from "zod"; + +import { EntityMeta } from "~/lib/model/primitives"; + +export type ModelManifest = { + name?: string; + targetModel?: string; + modelProvider: string; + active: boolean; + default: boolean; +}; + +export type ModelProviderStatus = { + configured: boolean; + missingEnvVars?: string[]; +}; + +export type Model = EntityMeta & ModelManifest & ModelProviderStatus; + +export const ModelManifestSchema = z.object({ + name: z.string(), + targetModel: z.string().min(1, "Required"), + modelProvider: z.string().min(1, "Required"), + active: z.boolean(), + default: z.boolean(), +}); + +export type ModelProvider = EntityMeta & { + description?: string; + builtin: boolean; + active: boolean; + modelProviderStatus: ModelProviderStatus; + name: string; + reference: string; + toolType: "modelProvider"; +}; diff --git a/ui/admin/app/lib/model/toolReferences.ts b/ui/admin/app/lib/model/toolReferences.ts index 5e2286b9..9bbfffa7 100644 --- a/ui/admin/app/lib/model/toolReferences.ts +++ b/ui/admin/app/lib/model/toolReferences.ts @@ -8,7 +8,7 @@ export type ToolReferenceBase = { metadata?: Record; }; -export type ToolReferenceType = "tool" | "stepTemplate"; +export type ToolReferenceType = "tool" | "stepTemplate" | "modelProvider"; export type ToolReference = { error: string; diff --git a/ui/admin/app/lib/routers/apiRoutes.ts b/ui/admin/app/lib/routers/apiRoutes.ts index 71226031..84eb740d 100644 --- a/ui/admin/app/lib/routers/apiRoutes.ts +++ b/ui/admin/app/lib/routers/apiRoutes.ts @@ -129,6 +129,14 @@ export const ApiRoutes = { supportedOauthAppTypes: () => buildUrl("/supported-oauth-app-types"), supportedAuthTypes: () => buildUrl("/supported-auth-types"), }, + models: { + base: () => buildUrl("/models"), + getModels: () => buildUrl("/models"), + getModelById: (modelId: string) => buildUrl(`/models/${modelId}`), + createModel: () => buildUrl(`/models`), + updateModel: (modelId: string) => buildUrl(`/models/${modelId}`), + deleteModel: (modelId: string) => buildUrl(`/models/${modelId}`), + }, }; /** revalidates the cache for all routes that match the filter callback diff --git a/ui/admin/app/lib/service/api/apiErrors.tsx b/ui/admin/app/lib/service/api/apiErrors.tsx index 943f5d47..761fef77 100644 --- a/ui/admin/app/lib/service/api/apiErrors.tsx +++ b/ui/admin/app/lib/service/api/apiErrors.tsx @@ -1,3 +1,8 @@ export class ConflictError extends Error {} -export class ForbiddenError extends Error {} -export class UnauthorizedError extends Error {} +export class BadRequestError extends Error {} + +// Errors that should trigger the error boundary +export class BoundaryError extends Error {} + +export class ForbiddenError extends BoundaryError {} +export class UnauthorizedError extends BoundaryError {} diff --git a/ui/admin/app/lib/service/api/modelApiService.ts b/ui/admin/app/lib/service/api/modelApiService.ts new file mode 100644 index 00000000..a6ec868e --- /dev/null +++ b/ui/admin/app/lib/service/api/modelApiService.ts @@ -0,0 +1,87 @@ +import { Model, ModelManifest, ModelProvider } from "~/lib/model/models"; +import { ApiRoutes } from "~/lib/routers/apiRoutes"; +import { request } from "~/lib/service/api/primitives"; + +async function getModels() { + const { data } = await request<{ items?: Model[] }>({ + url: ApiRoutes.models.getModels().url, + }); + + return data.items ?? []; +} +getModels.key = () => ({ url: ApiRoutes.models.getModels().path }); + +async function getModelById(modelId: string) { + const { data } = await request({ + url: ApiRoutes.models.getModelById(modelId).url, + }); + + return data; +} +getModelById.key = (modelId?: string) => { + if (!modelId) return null; + + return { + url: ApiRoutes.models.getModelById(modelId).path, + modelId, + }; +}; + +async function getModelProviders(onlyConfigured = false) { + const { data } = await request<{ items?: ModelProvider[] }>({ + url: ApiRoutes.toolReferences.base({ type: "modelProvider" }).url, + }); + + if (onlyConfigured) { + return ( + data.items?.filter( + (provider) => provider.modelProviderStatus.configured + ) ?? [] + ); + } + + return data.items ?? []; +} +getModelProviders.key = (onlyConfigured = false) => ({ + url: ApiRoutes.toolReferences.base({ type: "modelProvider" }).path, + onlyConfigured, +}); + +async function createModel(manifest: ModelManifest) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + const { data } = await request({ + url: ApiRoutes.models.createModel().url, + method: "POST", + data: manifest, + }); + + return data; +} + +async function updateModel(modelId: string, manifest: ModelManifest) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + + const { data } = await request({ + url: ApiRoutes.models.updateModel(modelId).url, + method: "PUT", + data: manifest, + }); + + return data; +} + +async function deleteModel(modelId: string) { + await request({ + url: ApiRoutes.models.deleteModel(modelId).url, + method: "DELETE", + }); +} + +export const ModelApiService = { + getModels, + getModelById, + getModelProviders, + createModel, + updateModel, + deleteModel, +}; diff --git a/ui/admin/app/lib/service/api/primitives.ts b/ui/admin/app/lib/service/api/primitives.ts index 70f7f89a..2e5d4f31 100644 --- a/ui/admin/app/lib/service/api/primitives.ts +++ b/ui/admin/app/lib/service/api/primitives.ts @@ -5,6 +5,7 @@ import { AuthDisabledUsername } from "~/lib/model/auth"; import { User } from "~/lib/model/users"; import { ApiRoutes } from "~/lib/routers/apiRoutes"; import { + BadRequestError, ConflictError, ForbiddenError, UnauthorizedError, @@ -35,6 +36,10 @@ export async function request, D = unknown>({ } catch (error) { console.error(errorMessage); + if (isAxiosError(error) && error.response?.status === 400) { + throw new BadRequestError(error.response.data); + } + if (isAxiosError(error) && error.response?.status === 401) { throw new UnauthorizedError(error.response.data); } diff --git a/ui/admin/app/lib/service/routeQueryParams.ts b/ui/admin/app/lib/service/routeService.ts similarity index 95% rename from ui/admin/app/lib/service/routeQueryParams.ts rename to ui/admin/app/lib/service/routeService.ts index f83d5e0f..403b54cf 100644 --- a/ui/admin/app/lib/service/routeQueryParams.ts +++ b/ui/admin/app/lib/service/routeService.ts @@ -12,6 +12,7 @@ const QueryParamSchemaMap = { }), "/debug": z.undefined(), "/home": z.undefined(), + "/models": z.undefined(), "/oauth-apps": z.undefined(), "/thread/:id": z.undefined(), "/threads": z.object({ @@ -59,17 +60,6 @@ function getUnknownQueryParams(pathname: string, search: string) { } satisfies QueryParamInfo<"/threads">; } - if ( - new RegExp($path("/workflows/:workflow", { workflow: "(.*)" })).test( - pathname - ) - ) { - return { - path: "/workflows/:workflow", - query: parseSearchParams("/workflows/:workflow", search), - } satisfies QueryParamInfo<"/workflows/:workflow">; - } - return {}; } @@ -96,6 +86,17 @@ function getUnknownPathParams( } satisfies PathParamInfo<"/thread/:id">; } + if ( + new RegExp($path("/workflows/:workflow", { workflow: "(.*)" })).test( + pathname + ) + ) { + return { + path: "/workflows/:workflow", + pathParams: $params("/workflows/:workflow", params), + } satisfies PathParamInfo<"/workflows/:workflow">; + } + return {}; } diff --git a/ui/admin/app/routes/_auth.agents.$agent.tsx b/ui/admin/app/routes/_auth.agents.$agent.tsx index 675bca48..b6480139 100644 --- a/ui/admin/app/routes/_auth.agents.$agent.tsx +++ b/ui/admin/app/routes/_auth.agents.$agent.tsx @@ -9,7 +9,7 @@ import { $path } from "remix-routes"; import { z } from "zod"; import { AgentService } from "~/lib/service/api/agentService"; -import { RouteService } from "~/lib/service/routeQueryParams"; +import { RouteService } from "~/lib/service/routeService"; import { noop } from "~/lib/utils"; import { Agent } from "~/components/agent"; diff --git a/ui/admin/app/routes/_auth.models.tsx b/ui/admin/app/routes/_auth.models.tsx new file mode 100644 index 00000000..62aaf006 --- /dev/null +++ b/ui/admin/app/routes/_auth.models.tsx @@ -0,0 +1,131 @@ +import { ColumnDef, createColumnHelper } from "@tanstack/react-table"; +import { PenSquareIcon } from "lucide-react"; +import { useMemo, useState } from "react"; +import useSWR, { preload } from "swr"; + +import { Model } from "~/lib/model/models"; +import { ModelApiService } from "~/lib/service/api/modelApiService"; + +import { TypographyH2, TypographySmall } from "~/components/Typography"; +import { DataTable } from "~/components/composed/DataTable"; +import { CreateModel } from "~/components/model/CreateModel"; +import { DeleteModel } from "~/components/model/DeleteModel"; +import { UpdateModelDialog } from "~/components/model/UpdateModel"; +import { Button } from "~/components/ui/button"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "~/components/ui/tooltip"; + +export async function clientLoader() { + await Promise.all([ + preload(ModelApiService.getModels.key(), ModelApiService.getModels), + preload( + ModelApiService.getModelProviders.key(true), + ({ onlyConfigured }) => + ModelApiService.getModelProviders(onlyConfigured) + ), + ]); + return null; +} + +export default function Models() { + const [modelToEdit, setModelToEdit] = useState(null); + + const { data } = useSWR( + ModelApiService.getModels.key(), + ModelApiService.getModels + ); + + const { data: providers } = useSWR( + ModelApiService.getModelProviders.key(true), + ({ onlyConfigured }) => + ModelApiService.getModelProviders(onlyConfigured) + ); + + const providerMap = useMemo(() => { + if (!providers) return {}; + return providers?.reduce( + (acc, provider) => { + acc[provider.id] = provider.name; + return acc; + }, + {} as Record + ); + }, [providers]); + + return ( +
+
+ Models + +
+ + + + setModelToEdit(open ? modelToEdit : null)} + /> +
+ ); + + function getColumns(): ColumnDef[] { + return [ + columnHelper.accessor((model) => model.name ?? model.id, { + id: "id", + header: "Model", + }), + columnHelper.accessor( + (model) => + providerMap[model.modelProvider] ?? model.modelProvider, + { + id: "provider", + header: "Provider", + } + ), + columnHelper.display({ + id: "default", + cell: ({ row }) => { + const value = row.original.default; + + if (!value) return null; + + return ( + +
+ Default + + ); + }, + }), + columnHelper.display({ + id: "actions", + cell: ({ row }) => ( +
+ + + + + + Update Model + + + +
+ ), + }), + ]; + } +} + +const columnHelper = createColumnHelper(); diff --git a/ui/admin/app/routes/_auth.thread.$id.tsx b/ui/admin/app/routes/_auth.thread.$id.tsx index 35475fa2..3769214b 100644 --- a/ui/admin/app/routes/_auth.thread.$id.tsx +++ b/ui/admin/app/routes/_auth.thread.$id.tsx @@ -9,7 +9,7 @@ import { ArrowLeftIcon } from "lucide-react"; import { AgentService } from "~/lib/service/api/agentService"; import { ThreadsService } from "~/lib/service/api/threadsService"; import { WorkflowService } from "~/lib/service/api/workflowService"; -import { RouteService } from "~/lib/service/routeQueryParams"; +import { RouteService } from "~/lib/service/routeService"; import { noop } from "~/lib/utils"; import { Chat } from "~/components/chat"; diff --git a/ui/admin/app/routes/_auth.threads.tsx b/ui/admin/app/routes/_auth.threads.tsx index fa17720d..92b1fcc5 100644 --- a/ui/admin/app/routes/_auth.threads.tsx +++ b/ui/admin/app/routes/_auth.threads.tsx @@ -18,7 +18,7 @@ import { Workflow } from "~/lib/model/workflows"; import { AgentService } from "~/lib/service/api/agentService"; import { ThreadsService } from "~/lib/service/api/threadsService"; import { WorkflowService } from "~/lib/service/api/workflowService"; -import { RouteService } from "~/lib/service/routeQueryParams"; +import { RouteService } from "~/lib/service/routeService"; import { timeSince } from "~/lib/utils"; import { TypographyH2, TypographyP } from "~/components/Typography";