Skip to content

Commit

Permalink
Merge pull request #138 from supabase-community/feat/byo-llm
Browse files Browse the repository at this point in the history
feat: bring your own llm
  • Loading branch information
gregnr authored Nov 28, 2024
2 parents 3e31c4e + 66dac4f commit 06b4210
Show file tree
Hide file tree
Showing 27 changed files with 1,837 additions and 825 deletions.
1 change: 1 addition & 0 deletions apps/postgres-new/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public/sw.mjs
51 changes: 5 additions & 46 deletions apps/postgres-new/app/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import { createOpenAI } from '@ai-sdk/openai'
import { Ratelimit } from '@upstash/ratelimit'
import { kv } from '@vercel/kv'
import { convertToCoreMessages, streamText, ToolInvocation, ToolResultPart } from 'ai'
import { codeBlock } from 'common-tags'
import { convertToCoreTools, maxMessageContext, maxRowLimit, tools } from '~/lib/tools'
import { getSystemPrompt } from '~/lib/system-prompt'
import { convertToCoreTools, maxMessageContext, tools } from '~/lib/tools'
import { createClient } from '~/utils/supabase/server'
import { ChatInferenceEventToolResult, logEvent } from '~/utils/telemetry'

Expand Down Expand Up @@ -72,49 +72,8 @@ export async function POST(req: Request) {
const coreMessages = convertToCoreMessages(trimmedMessageContext)
const coreTools = convertToCoreTools(tools)

const result = await streamText({
system: codeBlock`
You are a helpful database assistant. Under the hood you have access to an in-browser Postgres database called PGlite (https://github.com/electric-sql/pglite).
Some special notes about this database:
- foreign data wrappers are not supported
- the following extensions are available:
- plpgsql [pre-enabled]
- vector (https://github.com/pgvector/pgvector) [pre-enabled]
- use <=> for cosine distance (default to this)
- use <#> for negative inner product
- use <-> for L2 distance
- use <+> for L1 distance
- note queried vectors will be truncated/redacted due to their size - export as CSV if the full vector is required
When generating tables, do the following:
- For primary keys, always use "id bigint primary key generated always as identity" (not serial)
- Prefer 'text' over 'varchar'
- Keep explanations brief but helpful
- Don't repeat yourself after creating the table
When creating sample data:
- Make the data realistic, including joined data
- Check for existing records/conflicts in the table
When querying data, limit to 5 by default. The maximum number of rows you're allowed to fetch is ${maxRowLimit} (to protect AI from token abuse).
If the user needs to fetch more than ${maxRowLimit} rows at once, they can export the query as a CSV.
When performing FTS, always use 'simple' (languages aren't available).
When importing CSVs try to solve the problem yourself (eg. use a generic text column, then refine)
vs. asking the user to change the CSV. No need to select rows after importing.
You also know math. All math equations and expressions must be written in KaTex and must be wrapped in double dollar \`$$\`:
- Inline: $$\\sqrt{26}$$
- Multiline:
$$
\\sqrt{26}
$$
No images are allowed. Do not try to generate or link images, including base64 data URLs.
Feel free to suggest corrections for suspected typos.
`,
const result = streamText({
system: getSystemPrompt(),
model: openai(chatModel),
messages: coreMessages,
tools: coreTools,
Expand Down Expand Up @@ -158,7 +117,7 @@ export async function POST(req: Request) {
},
})

return result.toAIStreamResponse()
return result.toDataStreamResponse()
}

function getEventToolResult(toolResult: ToolResultPart): ChatInferenceEventToolResult | undefined {
Expand Down
16 changes: 16 additions & 0 deletions apps/postgres-new/components/app-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
import { legacyDomainHostname } from '~/lib/util'
import { parse, serialize } from '~/lib/websocket-protocol'
import { createClient } from '~/utils/supabase/client'
import { useModelProvider } from './model-provider/use-model-provider'

export type AppProps = PropsWithChildren

Expand Down Expand Up @@ -252,6 +253,9 @@ export default function AppProvider({ children }: AppProps) {
const [isLegacyDomain, setIsLegacyDomain] = useState(false)
const [isLegacyDomainRedirect, setIsLegacyDomainRedirect] = useState(false)

const [modelProviderError, setModelProviderError] = useState<string>()
const [isModelProviderDialogOpen, setIsModelProviderDialogOpen] = useState(false)

useEffect(() => {
const isLegacyDomain = window.location.hostname === legacyDomainHostname
const urlParams = new URLSearchParams(window.location.search)
Expand All @@ -263,12 +267,17 @@ export default function AppProvider({ children }: AppProps) {
setIsRenameDialogOpen(isLegacyDomain || isLegacyDomainRedirect)
}, [])

const modelProvider = useModelProvider()

return (
<AppContext.Provider
value={{
user,
isLoadingUser,
liveShare,
modelProvider,
modelProviderError,
setModelProviderError,
signIn,
signOut,
isSignInDialogOpen,
Expand All @@ -277,6 +286,8 @@ export default function AppProvider({ children }: AppProps) {
setIsRenameDialogOpen,
isRateLimited,
setIsRateLimited,
isModelProviderDialogOpen,
setIsModelProviderDialogOpen,
focusRef,
dbManager,
pgliteVersion,
Expand Down Expand Up @@ -305,6 +316,8 @@ export type AppContextValues = {
setIsRenameDialogOpen: (open: boolean) => void
isRateLimited: boolean
setIsRateLimited: (limited: boolean) => void
isModelProviderDialogOpen: boolean
setIsModelProviderDialogOpen: (open: boolean) => void
focusRef: RefObject<FocusHandle>
dbManager?: DbManager
pgliteVersion?: string
Expand All @@ -316,6 +329,9 @@ export type AppContextValues = {
clientIp: string | null
isLiveSharing: boolean
}
modelProvider: ReturnType<typeof useModelProvider>
modelProviderError?: string
setModelProviderError: (error: string | undefined) => void
isLegacyDomain: boolean
isLegacyDomainRedirect: boolean
}
Expand Down
24 changes: 24 additions & 0 deletions apps/postgres-new/components/byo-llm-button.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { Brain } from 'lucide-react'
import { useApp } from '~/components/app-provider'
import { Button } from '~/components/ui/button'

export type ByoLlmButtonProps = {
onClick?: () => void
}

export default function ByoLlmButton({ onClick }: ByoLlmButtonProps) {
const { setIsModelProviderDialogOpen } = useApp()

return (
<Button
className="gap-2 text-base"
onClick={() => {
onClick?.()
setIsModelProviderDialogOpen(true)
}}
>
<Brain size={18} strokeWidth={2} />
Bring your own LLM
</Button>
)
}
76 changes: 61 additions & 15 deletions apps/postgres-new/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import { Message, generateId } from 'ai'
import { useChat } from 'ai/react'
import { AnimatePresence, m } from 'framer-motion'
import { ArrowDown, ArrowUp, Flame, Paperclip, PlugIcon, Square } from 'lucide-react'
import { AlertCircle, ArrowDown, ArrowUp, Flame, Paperclip, PlugIcon, Square } from 'lucide-react'
import {
FormEventHandler,
useCallback,
Expand All @@ -22,6 +22,7 @@ import { requestFileUpload } from '~/lib/util'
import { cn } from '~/lib/utils'
import { AiIconAnimation } from './ai-icon-animation'
import { useApp } from './app-provider'
import ByoLlmButton from './byo-llm-button'
import ChatMessage from './chat-message'
import { CopyableField } from './copyable-field'
import SignInButton from './sign-in-button'
Expand Down Expand Up @@ -51,8 +52,17 @@ export function getInitialMessages(tables: TablesData): Message[] {
}

export default function Chat() {
const { user, isLoadingUser, focusRef, setIsSignInDialogOpen, isRateLimited, liveShare } =
useApp()
const {
user,
isLoadingUser,
focusRef,
setIsSignInDialogOpen,
isRateLimited,
liveShare,
modelProvider,
modelProviderError,
setIsModelProviderDialogOpen,
} = useApp()
const [inputFocusState, setInputFocusState] = useState(false)

const {
Expand Down Expand Up @@ -155,7 +165,7 @@ export default function Chat() {
cursor: dropZoneCursor,
} = useDropZone({
async onDrop(files) {
if (!user) {
if (isAuthRequired) {
return
}

Expand Down Expand Up @@ -223,8 +233,10 @@ export default function Chat() {

const [isMessageAnimationComplete, setIsMessageAnimationComplete] = useState(false)

const isAuthRequired = user === undefined && modelProvider.state?.enabled !== true

const isChatEnabled =
!isLoadingMessages && !isLoadingSchema && user !== undefined && !liveShare.isLiveSharing
!isLoadingMessages && !isLoadingSchema && !isAuthRequired && !liveShare.isLiveSharing

const isSubmitEnabled = isChatEnabled && Boolean(input.trim())

Expand Down Expand Up @@ -293,6 +305,42 @@ export default function Chat() {
isLast={i === messages.length - 1}
/>
))}
<AnimatePresence initial={false}>
{modelProviderError && !isLoading && (
<m.div
layout="position"
className="flex flex-col gap-4 justify-start items-center max-w-96 p-4 bg-destructive rounded-md text-sm"
variants={{
hidden: { scale: 0 },
show: { scale: 1, transition: { delay: 0.5 } },
}}
initial="hidden"
animate="show"
exit="hidden"
>
<AlertCircle size={64} strokeWidth={1} />
<div className="flex flex-col items-center text-start gap-4">
<h3 className="font-bold">Whoops!</h3>
<p className="text-center">
There was an error connecting to your custom model provider:{' '}
{modelProviderError}
</p>
<p>
Double check your{' '}
<a
className="underline cursor-pointer"
onClick={() => {
setIsModelProviderDialogOpen(true)
}}
>
API info
</a>
.
</p>
</div>
</m.div>
)}
</AnimatePresence>
<AnimatePresence initial={false}>
{isRateLimited && !isLoading && (
<m.div
Expand Down Expand Up @@ -357,7 +405,7 @@ export default function Chat() {
</div>
) : (
<div className="h-full w-full max-w-4xl flex flex-col gap-10 justify-center items-center">
{user ? (
{!isAuthRequired ? (
<>
<LiveShareOverlay databaseId={databaseId} />
<m.h3
Expand All @@ -384,11 +432,10 @@ export default function Chat() {
animate="show"
>
<SignInButton />
<p className="font-lighter text-center">
To prevent abuse we ask you to sign in before chatting with AI.
</p>
or
<ByoLlmButton />
<p
className="underline cursor-pointer text-primary/50"
className="underline cursor-pointer text-sm text-primary/50"
onClick={() => {
setIsSignInDialogOpen(true)
}}
Expand Down Expand Up @@ -427,7 +474,7 @@ export default function Chat() {
</div>
<div className="flex flex-col items-center gap-3 pb-1 relative">
<AnimatePresence>
{!user && !isLoadingUser && isConversationStarted && (
{isAuthRequired && !isLoadingUser && isConversationStarted && (
<m.div
className="flex flex-col items-center gap-4 max-w-lg my-4"
variants={{
Expand All @@ -438,9 +485,8 @@ export default function Chat() {
exit="hidden"
>
<SignInButton />
<p className="font-lighter text-center text-sm">
To prevent abuse we ask you to sign in before chatting with AI.
</p>
or
<ByoLlmButton />
<p
className="underline cursor-pointer text-sm text-primary/50"
onClick={() => {
Expand Down Expand Up @@ -487,7 +533,7 @@ export default function Chat() {
onClick={async (e) => {
e.preventDefault()

if (!user) {
if (isAuthRequired) {
return
}

Expand Down
Loading

0 comments on commit 06b4210

Please sign in to comment.