From 235c9eff091931d3fe76ed00cfcdbcb2cf3c0816 Mon Sep 17 00:00:00 2001 From: Ryan Hopper-Lowe Date: Thu, 7 Nov 2024 17:49:46 -0600 Subject: [PATCH] UI feat: Create crud interface for models --- .../app/components/composed/DataTable.tsx | 41 +++-- .../app/components/form/controlledInputs.tsx | 105 +++++++++++- ui/admin/app/components/model/CreateModel.tsx | 33 ++++ ui/admin/app/components/model/DeleteModel.tsx | 53 ++++++ ui/admin/app/components/model/ModelForm.tsx | 161 ++++++++++++++++++ ui/admin/app/components/model/UpdateModel.tsx | 60 +++++++ ui/admin/app/components/sidebar/Sidebar.tsx | 6 + ui/admin/app/components/ui/button.tsx | 7 +- ui/admin/app/hooks/useAsync.tsx | 14 +- ui/admin/app/lib/model/models.ts | 36 ++++ ui/admin/app/lib/model/toolReferences.ts | 2 +- ui/admin/app/lib/routers/apiRoutes.ts | 8 + ui/admin/app/lib/service/api/apiErrors.tsx | 9 +- .../app/lib/service/api/modelApiService.ts | 87 ++++++++++ ui/admin/app/lib/service/api/primitives.ts | 5 + ui/admin/app/routes/_auth.models.tsx | 111 ++++++++++++ 16 files changed, 707 insertions(+), 31 deletions(-) create mode 100644 ui/admin/app/components/model/CreateModel.tsx create mode 100644 ui/admin/app/components/model/DeleteModel.tsx create mode 100644 ui/admin/app/components/model/ModelForm.tsx create mode 100644 ui/admin/app/components/model/UpdateModel.tsx create mode 100644 ui/admin/app/lib/model/models.ts create mode 100644 ui/admin/app/lib/service/api/modelApiService.ts create mode 100644 ui/admin/app/routes/_auth.models.tsx diff --git a/ui/admin/app/components/composed/DataTable.tsx b/ui/admin/app/components/composed/DataTable.tsx index 62e00ce38..f3d5c73a0 100644 --- a/ui/admin/app/components/composed/DataTable.tsx +++ b/ui/admin/app/components/composed/DataTable.tsx @@ -83,28 +83,7 @@ export function DataTable({ rowClassName?.(row.original) )} > - {row.getVisibleCells().map((cell) => ( - { - if ( - !disableClickPropagation?.(cell) - ) { - onRowClick?.(row.original); - } - }} - > - {flexRender( - cell.column.columnDef.cell, - cell.getContext() - )} - - ))} + {row.getVisibleCells().map(renderCell)} )) ) : ( @@ -124,4 +103,22 @@ export function DataTable({ ); + + function renderCell(cell: Cell) { + return ( + { + if (!disableClickPropagation?.(cell)) { + onRowClick?.(cell.row.original); + } + }} + > + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ); + } } diff --git a/ui/admin/app/components/form/controlledInputs.tsx b/ui/admin/app/components/form/controlledInputs.tsx index f39e34567..8364a2999 100644 --- a/ui/admin/app/components/form/controlledInputs.tsx +++ b/ui/admin/app/components/form/controlledInputs.tsx @@ -1,14 +1,16 @@ -import { ReactNode } from "react"; +import { ComponentProps, ReactNode } from "react"; import { Control, ControllerFieldState, ControllerRenderProps, FieldPath, FieldValues, + FormState, } from "react-hook-form"; import { cn } from "~/lib/utils"; +import { Checkbox } from "~/components/ui/checkbox"; import { FormControl, FormDescription, @@ -180,6 +182,107 @@ export function ControlledAutosizeTextarea< ); } +export type ControlledCheckboxProps< + TValues extends FieldValues, + TName extends FieldPath, +> = BaseProps & ComponentProps; + +export function ControlledCheckbox< + TValues extends FieldValues, + TName extends FieldPath, +>({ + control, + name, + label, + description, + onCheckedChange, + ...checkboxProps +}: ControlledCheckboxProps) { + return ( + ( + +
+ + { + field.onChange(value); + onCheckedChange?.(value); + }} + className={cn( + getFieldStateClasses(fieldState), + checkboxProps.className + )} + /> + + + {label && {label}} +
+ + + + {description && ( + {description} + )} +
+ )} + /> + ); +} + +export type ControlledCustomInputProps< + TValues extends FieldValues, + TName extends FieldPath, +> = BaseProps & { + children: (props: { + field: ControllerRenderProps; + fieldState: ControllerFieldState; + formState: FormState; + className?: string; + }) => ReactNode; +}; + +export function ControlledCustomInput< + TValues extends FieldValues, + TName extends FieldPath, +>({ + control, + name, + label, + description, + children, +}: ControlledCustomInputProps) { + return ( + ( + + {label && {label}} + + + {children({ + ...args, + className: getFieldStateClasses(args.fieldState), + })} + + + + + {description && ( + {description} + )} + + )} + /> + ); +} + function getFieldStateClasses(fieldState: ControllerFieldState) { return cn({ "focus-visible:ring-destructive border-destructive": fieldState.invalid, diff --git a/ui/admin/app/components/model/CreateModel.tsx b/ui/admin/app/components/model/CreateModel.tsx new file mode 100644 index 000000000..13980b483 --- /dev/null +++ b/ui/admin/app/components/model/CreateModel.tsx @@ -0,0 +1,33 @@ +import { PlusIcon } from "lucide-react"; +import { useState } from "react"; + +import { ModelForm } from "~/components/model/ModelForm"; +import { Button } from "~/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogTitle, + DialogTrigger, +} from "~/components/ui/dialog"; + +export function CreateModel() { + const [open, setOpen] = useState(false); + + return ( + + + + + + + Create Model + + + setOpen(false)} /> + + + ); +} diff --git a/ui/admin/app/components/model/DeleteModel.tsx b/ui/admin/app/components/model/DeleteModel.tsx new file mode 100644 index 000000000..dd330822f --- /dev/null +++ b/ui/admin/app/components/model/DeleteModel.tsx @@ -0,0 +1,53 @@ +import { TrashIcon } from "lucide-react"; +import { mutate } from "swr"; + +import { ModelApiService } from "~/lib/service/api/modelApiService"; + +import { ConfirmationDialog } from "~/components/composed/ConfirmationDialog"; +import { Button } from "~/components/ui/button"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "~/components/ui/tooltip"; +import { useAsync } from "~/hooks/useAsync"; + +type DeleteModelProps = { + id: string; +}; + +export function DeleteModel(props: DeleteModelProps) { + const deleteModel = useAsync(ModelApiService.deleteModel, { + onSuccess: () => mutate(ModelApiService.getModels.key()), + }); + + return ( + + + deleteModel.execute(props.id)} + confirmProps={{ + variant: "destructive", + children: "Delete", + }} + > + + + + + + Delete Model + + + ); +} diff --git a/ui/admin/app/components/model/ModelForm.tsx b/ui/admin/app/components/model/ModelForm.tsx new file mode 100644 index 000000000..1f9c135ff --- /dev/null +++ b/ui/admin/app/components/model/ModelForm.tsx @@ -0,0 +1,161 @@ +import { zodResolver } from "@hookform/resolvers/zod"; +import { useMemo } from "react"; +import { useForm } from "react-hook-form"; +import { toast } from "sonner"; +import useSWR, { mutate } from "swr"; +import { z } from "zod"; + +import { Model, ModelManifest, ModelManifestSchema } from "~/lib/model/models"; +import { BadRequestError } from "~/lib/service/api/apiErrors"; +import { ModelApiService } from "~/lib/service/api/modelApiService"; + +import { + ControlledCheckbox, + ControlledCustomInput, + ControlledInput, +} from "~/components/form/controlledInputs"; +import { Button } from "~/components/ui/button"; +import { Form } from "~/components/ui/form"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "~/components/ui/select"; +import { useAsync } from "~/hooks/useAsync"; + +type ModelFormProps = { + model?: Model; + onSubmit: (model: ModelManifest) => void; +}; + +type FormValues = z.infer; + +export function ModelForm(props: ModelFormProps) { + const { model, onSubmit } = props; + + const { data: modelProviders } = useSWR( + ModelApiService.getModelProviders.key(true), + ({ onlyConfigured }) => + ModelApiService.getModelProviders(onlyConfigured) + ); + + const updateModel = useAsync(ModelApiService.updateModel, { + onSuccess: (values) => { + toast.success("Model updated"); + mutate(ModelApiService.getModels.key()); + onSubmit(values); + }, + onError, + }); + + const createModel = useAsync(ModelApiService.createModel, { + onSuccess: (values) => { + toast.success("Model created"); + mutate(ModelApiService.getModels.key()); + onSubmit(values); + }, + onError, + }); + + const defaultValues = useMemo(() => { + return { + name: model?.name ?? "", + targetModel: model?.targetModel ?? "", + modelProvider: model?.modelProvider ?? "", + active: model?.active ?? true, + default: model?.default ?? false, + }; + }, [model]); + + const form = useForm({ + resolver: zodResolver(ModelManifestSchema), + defaultValues, + }); + + const { loading, submit } = getSubmitInfo(); + + const handleSubmit = form.handleSubmit(submit); + + return ( +
+ + + + + {({ field: { ref: _, ...field }, className }) => ( + + )} + + + + + + + + + + ); + + function getSubmitInfo() { + if (model) { + return { + isEdit: true, + loading: updateModel.isLoading, + submit: (values: FormValues) => + updateModel.execute(model.id, values), + }; + } + + return { + isEdit: false, + loading: createModel.isLoading, + submit: (values: FormValues) => createModel.execute(values), + }; + } + + function onError(error: unknown) { + if (error instanceof BadRequestError) + form.setError("default", { message: error.message }); + } +} diff --git a/ui/admin/app/components/model/UpdateModel.tsx b/ui/admin/app/components/model/UpdateModel.tsx new file mode 100644 index 000000000..4617093b5 --- /dev/null +++ b/ui/admin/app/components/model/UpdateModel.tsx @@ -0,0 +1,60 @@ +import { PenSquareIcon } from "lucide-react"; +import { useState } from "react"; + +import { Model } from "~/lib/model/models"; + +import { ModelForm } from "~/components/model/ModelForm"; +import { Button } from "~/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogTitle, + DialogTrigger, +} from "~/components/ui/dialog"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "~/components/ui/tooltip"; + +type UpdateModelProps = { + model: Model; +}; + +export function UpdateModel(props: UpdateModelProps) { + const { model } = props; + const [open, setOpen] = useState(false); + + return ( + + + + + Update model + + + + setOpen(false)} + /> + + + + + + + + + + Update Model + + + ); +} diff --git a/ui/admin/app/components/sidebar/Sidebar.tsx b/ui/admin/app/components/sidebar/Sidebar.tsx index 78024a9ad..131049f69 100644 --- a/ui/admin/app/components/sidebar/Sidebar.tsx +++ b/ui/admin/app/components/sidebar/Sidebar.tsx @@ -1,6 +1,7 @@ import { Link } from "@remix-run/react"; import { BotIcon, + BrainIcon, KeyIcon, MessageSquare, SettingsIcon, @@ -59,6 +60,11 @@ const items = [ url: $path("/oauth-apps"), icon: KeyIcon, }, + { + title: "Models", + url: $path("/models"), + icon: BrainIcon, + }, ]; export function AppSidebar() { diff --git a/ui/admin/app/components/ui/button.tsx b/ui/admin/app/components/ui/button.tsx index 1d6cd0708..79dcf3e8e 100644 --- a/ui/admin/app/components/ui/button.tsx +++ b/ui/admin/app/components/ui/button.tsx @@ -40,6 +40,7 @@ export interface ButtonProps VariantProps { asChild?: boolean; loading?: boolean; + startContent?: React.ReactNode; } const Button = React.forwardRef( @@ -50,6 +51,7 @@ const Button = React.forwardRef( size, asChild = false, loading = false, + startContent, children, ...props }, @@ -77,7 +79,10 @@ const Button = React.forwardRef( {children} ) : ( - children +
+ {startContent} + {children} +
); } } diff --git a/ui/admin/app/hooks/useAsync.tsx b/ui/admin/app/hooks/useAsync.tsx index f7904174e..d02461073 100644 --- a/ui/admin/app/hooks/useAsync.tsx +++ b/ui/admin/app/hooks/useAsync.tsx @@ -1,28 +1,34 @@ -import { AxiosError } from "axios"; import { useCallback, useState } from "react"; +import { BoundaryError } from "~/lib/service/api/apiErrors"; import { handlePromise } from "~/lib/service/async"; type Config = { onSuccess?: (data: TData, params: TParams) => void; onError?: (error: unknown, params: TParams) => void; onSettled?: ({ params }: { params: TParams }) => void; + shouldThrow?: (error: unknown) => boolean; }; -const defaultShouldThrow = (error: unknown) => !(error instanceof AxiosError); +const defaultShouldThrow = (error: unknown) => error instanceof BoundaryError; export function useAsync( callback: (...params: TParams) => Promise, config?: Config ) { - const { onSuccess, onError, onSettled } = config || {}; + const { + onSuccess, + onError, + onSettled, + shouldThrow = defaultShouldThrow, + } = config || {}; const [data, setData] = useState(null); const [error, setError] = useState(null); const [isLoading, setIsLoading] = useState(false); const [lastCallParams, setLastCallParams] = useState(null); - if (error && defaultShouldThrow(error)) throw error; + if (error && shouldThrow(error)) throw error; const executeAsync = useCallback( async (...params: TParams) => { diff --git a/ui/admin/app/lib/model/models.ts b/ui/admin/app/lib/model/models.ts new file mode 100644 index 000000000..f4b96fc4e --- /dev/null +++ b/ui/admin/app/lib/model/models.ts @@ -0,0 +1,36 @@ +import { z } from "zod"; + +import { EntityMeta } from "~/lib/model/primitives"; + +export type ModelManifest = { + name?: string; + targetModel?: string; + modelProvider: string; + active: boolean; + default: boolean; +}; + +export type ModelProviderStatus = { + configured: boolean; + missingEnvVars?: string[]; +}; + +export type Model = EntityMeta & ModelManifest & ModelProviderStatus; + +export const ModelManifestSchema = z.object({ + name: z.string(), + targetModel: z.string().min(1, "Required"), + modelProvider: z.string().min(1, "Required"), + active: z.boolean(), + default: z.boolean(), +}); + +export type ModelProvider = EntityMeta & { + description?: string; + builtin: boolean; + active: boolean; + modelProviderStatus: ModelProviderStatus; + name: string; + reference: string; + toolType: "modelProvider"; +}; diff --git a/ui/admin/app/lib/model/toolReferences.ts b/ui/admin/app/lib/model/toolReferences.ts index 5e2286b90..9bbfffa76 100644 --- a/ui/admin/app/lib/model/toolReferences.ts +++ b/ui/admin/app/lib/model/toolReferences.ts @@ -8,7 +8,7 @@ export type ToolReferenceBase = { metadata?: Record; }; -export type ToolReferenceType = "tool" | "stepTemplate"; +export type ToolReferenceType = "tool" | "stepTemplate" | "modelProvider"; export type ToolReference = { error: string; diff --git a/ui/admin/app/lib/routers/apiRoutes.ts b/ui/admin/app/lib/routers/apiRoutes.ts index 712260318..84eb740da 100644 --- a/ui/admin/app/lib/routers/apiRoutes.ts +++ b/ui/admin/app/lib/routers/apiRoutes.ts @@ -129,6 +129,14 @@ export const ApiRoutes = { supportedOauthAppTypes: () => buildUrl("/supported-oauth-app-types"), supportedAuthTypes: () => buildUrl("/supported-auth-types"), }, + models: { + base: () => buildUrl("/models"), + getModels: () => buildUrl("/models"), + getModelById: (modelId: string) => buildUrl(`/models/${modelId}`), + createModel: () => buildUrl(`/models`), + updateModel: (modelId: string) => buildUrl(`/models/${modelId}`), + deleteModel: (modelId: string) => buildUrl(`/models/${modelId}`), + }, }; /** revalidates the cache for all routes that match the filter callback diff --git a/ui/admin/app/lib/service/api/apiErrors.tsx b/ui/admin/app/lib/service/api/apiErrors.tsx index 943f5d471..761fef773 100644 --- a/ui/admin/app/lib/service/api/apiErrors.tsx +++ b/ui/admin/app/lib/service/api/apiErrors.tsx @@ -1,3 +1,8 @@ export class ConflictError extends Error {} -export class ForbiddenError extends Error {} -export class UnauthorizedError extends Error {} +export class BadRequestError extends Error {} + +// Errors that should trigger the error boundary +export class BoundaryError extends Error {} + +export class ForbiddenError extends BoundaryError {} +export class UnauthorizedError extends BoundaryError {} diff --git a/ui/admin/app/lib/service/api/modelApiService.ts b/ui/admin/app/lib/service/api/modelApiService.ts new file mode 100644 index 000000000..a6ec868e7 --- /dev/null +++ b/ui/admin/app/lib/service/api/modelApiService.ts @@ -0,0 +1,87 @@ +import { Model, ModelManifest, ModelProvider } from "~/lib/model/models"; +import { ApiRoutes } from "~/lib/routers/apiRoutes"; +import { request } from "~/lib/service/api/primitives"; + +async function getModels() { + const { data } = await request<{ items?: Model[] }>({ + url: ApiRoutes.models.getModels().url, + }); + + return data.items ?? []; +} +getModels.key = () => ({ url: ApiRoutes.models.getModels().path }); + +async function getModelById(modelId: string) { + const { data } = await request({ + url: ApiRoutes.models.getModelById(modelId).url, + }); + + return data; +} +getModelById.key = (modelId?: string) => { + if (!modelId) return null; + + return { + url: ApiRoutes.models.getModelById(modelId).path, + modelId, + }; +}; + +async function getModelProviders(onlyConfigured = false) { + const { data } = await request<{ items?: ModelProvider[] }>({ + url: ApiRoutes.toolReferences.base({ type: "modelProvider" }).url, + }); + + if (onlyConfigured) { + return ( + data.items?.filter( + (provider) => provider.modelProviderStatus.configured + ) ?? [] + ); + } + + return data.items ?? []; +} +getModelProviders.key = (onlyConfigured = false) => ({ + url: ApiRoutes.toolReferences.base({ type: "modelProvider" }).path, + onlyConfigured, +}); + +async function createModel(manifest: ModelManifest) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + const { data } = await request({ + url: ApiRoutes.models.createModel().url, + method: "POST", + data: manifest, + }); + + return data; +} + +async function updateModel(modelId: string, manifest: ModelManifest) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + + const { data } = await request({ + url: ApiRoutes.models.updateModel(modelId).url, + method: "PUT", + data: manifest, + }); + + return data; +} + +async function deleteModel(modelId: string) { + await request({ + url: ApiRoutes.models.deleteModel(modelId).url, + method: "DELETE", + }); +} + +export const ModelApiService = { + getModels, + getModelById, + getModelProviders, + createModel, + updateModel, + deleteModel, +}; diff --git a/ui/admin/app/lib/service/api/primitives.ts b/ui/admin/app/lib/service/api/primitives.ts index 70f7f89a7..2e5d4f318 100644 --- a/ui/admin/app/lib/service/api/primitives.ts +++ b/ui/admin/app/lib/service/api/primitives.ts @@ -5,6 +5,7 @@ import { AuthDisabledUsername } from "~/lib/model/auth"; import { User } from "~/lib/model/users"; import { ApiRoutes } from "~/lib/routers/apiRoutes"; import { + BadRequestError, ConflictError, ForbiddenError, UnauthorizedError, @@ -35,6 +36,10 @@ export async function request, D = unknown>({ } catch (error) { console.error(errorMessage); + if (isAxiosError(error) && error.response?.status === 400) { + throw new BadRequestError(error.response.data); + } + if (isAxiosError(error) && error.response?.status === 401) { throw new UnauthorizedError(error.response.data); } diff --git a/ui/admin/app/routes/_auth.models.tsx b/ui/admin/app/routes/_auth.models.tsx new file mode 100644 index 000000000..6a0291138 --- /dev/null +++ b/ui/admin/app/routes/_auth.models.tsx @@ -0,0 +1,111 @@ +import { ColumnDef, createColumnHelper } from "@tanstack/react-table"; +import { useMemo } from "react"; +import useSWR, { preload } from "swr"; + +import { Model } from "~/lib/model/models"; +import { ModelApiService } from "~/lib/service/api/modelApiService"; + +import { TypographyH2, TypographySmall } from "~/components/Typography"; +import { DataTable } from "~/components/composed/DataTable"; +import { CreateModel } from "~/components/model/CreateModel"; +import { DeleteModel } from "~/components/model/DeleteModel"; +import { UpdateModel } from "~/components/model/UpdateModel"; + +export async function clientLoader() { + await Promise.all([ + preload(ModelApiService.getModels.key(), ModelApiService.getModels), + preload( + ModelApiService.getModelProviders.key(true), + ({ onlyConfigured }) => + ModelApiService.getModelProviders(onlyConfigured) + ), + ]); + return null; +} + +export default function Models() { + const { data } = useSWR( + ModelApiService.getModels.key(), + ModelApiService.getModels + ); + + const { data: providers } = useSWR( + ModelApiService.getModelProviders.key(true), + ({ onlyConfigured }) => + ModelApiService.getModelProviders(onlyConfigured) + ); + + const providerMap = useMemo(() => { + if (!providers) return {}; + return providers?.reduce( + (acc, provider) => { + acc[provider.id] = provider.name; + return acc; + }, + {} as Record + ); + }, [providers]); + + return ( +
+
+ Models + +
+ + cell.id.includes("actions")} + /> +
+ ); + + function getColumns(): ColumnDef[] { + return [ + columnHelper.accessor((model) => model.name ?? model.id, { + id: "id", + header: "Model", + }), + columnHelper.accessor( + (model) => + providerMap[model.modelProvider] ?? model.modelProvider, + { + id: "provider", + header: "Provider", + } + ), + columnHelper.display({ + id: "default", + cell: ({ row }) => { + const value = row.original.default; + + if (!value) return null; + + return ( + +
+ Default + + ); + }, + }), + columnHelper.display({ + id: "actions", + cell: ({ row }) => ( +
+ + +
+ ), + }), + ]; + } +} + +const columnHelper = createColumnHelper();