diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0b1b4e6..2e88081 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -44,6 +44,22 @@ Steps for downloading and setting up AI Warp for local development. node ../dist/cli/start.js ``` +### Testing a model with OpenAI + +To test a remote model with OpenAI, you can use the following `aiProvider` +configuration: + +```json + "aiProvider": { + "openai": { + "model": "gpt-3.5-turbo", + "apiKey": "{PLT_OPENAI_API_KEY}" + } + } +``` + +Make sure to add your OpenAI API key as `PLT_OPENAI_API_KEY` in your `.env` file. + ### Testing a local model with llama2 To test a local model with with llama2, you can use the following to diff --git a/ai-providers/llama2.ts b/ai-providers/llama2.ts index 907b178..891638f 100644 --- a/ai-providers/llama2.ts +++ b/ai-providers/llama2.ts @@ -1,4 +1,5 @@ import { ReadableByteStreamController, ReadableStream, UnderlyingByteSource } from 'stream/web' +import { FastifyLoggerInstance } from 'fastify' import { LLamaChatPromptOptions, LlamaChatSession, @@ -60,13 +61,16 @@ class Llama2ByteSource implements UnderlyingByteSource { backloggedChunks: ChunkQueue = new ChunkQueue() finished: boolean = false controller?: ReadableByteStreamController + abortController: AbortController - constructor (session: LlamaChatSession, prompt: string, chunkCallback?: StreamChunkCallback) { + constructor (session: LlamaChatSession, prompt: string, logger: FastifyLoggerInstance, chunkCallback?: StreamChunkCallback) { this.session = session this.chunkCallback = chunkCallback + this.abortController = new AbortController() session.prompt(prompt, { - onToken: this.onToken + onToken: this.onToken, + signal: this.abortController.signal }).then(() => { this.finished = true // Don't close the stream if we still have chunks to send @@ -75,13 +79,21 @@ class Llama2ByteSource implements UnderlyingByteSource { } }).catch((err: any) => { this.finished = true - if (this.controller !== undefined) { - 
this.controller.close() + logger.info({ err }) + if (!this.abortController.signal.aborted && this.controller !== undefined) { + try { + this.controller.close() + } catch (err) { + logger.info({ err }) + } } - throw err }) } + cancel (): void { + this.abortController.abort() + } + onToken: LLamaChatPromptOptions['onToken'] = async (chunk) => { if (this.controller === undefined) { // Stream hasn't started yet, added it to the backlog queue @@ -89,8 +101,14 @@ class Llama2ByteSource implements UnderlyingByteSource { return } - await this.clearBacklog() - await this.enqueueChunk(chunk) + try { + await this.clearBacklog() + await this.enqueueChunk(chunk) + // Ignore all errors, we can't do anything about them + // TODO: Log these errors + } catch (err) { + console.error(err) + } } private async enqueueChunk (chunk: number[]): Promise { @@ -103,6 +121,10 @@ class Llama2ByteSource implements UnderlyingByteSource { response = await this.chunkCallback(response) } + if (response === '') { + response = '\n' // It seems empty chunks are newlines + } + const eventData: AiStreamEvent = { event: 'content', data: { @@ -139,14 +161,17 @@ class Llama2ByteSource implements UnderlyingByteSource { interface Llama2ProviderCtorOptions { modelPath: string + logger: FastifyLoggerInstance } export class Llama2Provider implements AiProvider { context: LlamaContext + logger: FastifyLoggerInstance - constructor ({ modelPath }: Llama2ProviderCtorOptions) { + constructor ({ modelPath, logger }: Llama2ProviderCtorOptions) { const model = new LlamaModel({ modelPath }) this.context = new LlamaContext({ model }) + this.logger = logger } async ask (prompt: string): Promise { @@ -159,6 +184,6 @@ export class Llama2Provider implements AiProvider { async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise { const session = new LlamaChatSession({ context: this.context }) - return new ReadableStream(new Llama2ByteSource(session, prompt, chunkCallback)) + return new ReadableStream(new 
Llama2ByteSource(session, prompt, this.logger, chunkCallback)) } } diff --git a/plugins/warp.ts b/plugins/warp.ts index 6877e84..040b995 100644 --- a/plugins/warp.ts +++ b/plugins/warp.ts @@ -1,5 +1,6 @@ // eslint-disable-next-line /// +import { FastifyLoggerInstance } from 'fastify' import fastifyPlugin from 'fastify-plugin' import { OpenAiProvider } from '../ai-providers/open-ai.js' import { MistralProvider } from '../ai-providers/mistral.js' @@ -12,7 +13,7 @@ import { Llama2Provider } from '../ai-providers/llama2.js' const UnknownAiProviderError = createError('UNKNOWN_AI_PROVIDER', 'Unknown AI Provider') -function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider { +function build (aiProvider: AiWarpConfig['aiProvider'], logger: FastifyLoggerInstance): AiProvider { if ('openai' in aiProvider) { return new OpenAiProvider(aiProvider.openai) } else if ('mistral' in aiProvider) { @@ -22,7 +23,10 @@ function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider { } else if ('azure' in aiProvider) { return new AzureProvider(aiProvider.azure) } else if ('llama2' in aiProvider) { - return new Llama2Provider(aiProvider.llama2) + return new Llama2Provider({ + ...aiProvider.llama2, + logger + }) } else { throw new UnknownAiProviderError() } @@ -30,7 +34,7 @@ function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider { export default fastifyPlugin(async (fastify) => { const { config } = fastify.platformatic - const provider = build(config.aiProvider) + const provider = build(config.aiProvider, fastify.log) fastify.decorate('ai', { warp: async (request, prompt) => { diff --git a/static/chat.html b/static/chat.html new file mode 100644 index 0000000..b61b4ef --- /dev/null +++ b/static/chat.html @@ -0,0 +1,34 @@ + + + + + + Chat - AI Warp + + + + + + + +
+ +
+ + +
+ + + + + + diff --git a/static/images/avatars/platformatic.svg b/static/images/avatars/platformatic.svg new file mode 100644 index 0000000..7d2aa8f --- /dev/null +++ b/static/images/avatars/platformatic.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/static/images/avatars/you.svg b/static/images/avatars/you.svg new file mode 100644 index 0000000..a43cf22 --- /dev/null +++ b/static/images/avatars/you.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/static/images/icons/arrow-long-right.svg b/static/images/icons/arrow-long-right.svg new file mode 100644 index 0000000..8edddae --- /dev/null +++ b/static/images/icons/arrow-long-right.svg @@ -0,0 +1,3 @@ + + + diff --git a/static/images/icons/checkmark.svg b/static/images/icons/checkmark.svg new file mode 100644 index 0000000..a823d62 --- /dev/null +++ b/static/images/icons/checkmark.svg @@ -0,0 +1,4 @@ + + + + diff --git a/static/images/icons/copy.svg b/static/images/icons/copy.svg new file mode 100644 index 0000000..255a481 --- /dev/null +++ b/static/images/icons/copy.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/static/images/icons/edit.svg b/static/images/icons/edit.svg new file mode 100644 index 0000000..74bd085 --- /dev/null +++ b/static/images/icons/edit.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/static/images/icons/error.svg b/static/images/icons/error.svg new file mode 100755 index 0000000..00c2daf --- /dev/null +++ b/static/images/icons/error.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/static/images/icons/regenerate.svg b/static/images/icons/regenerate.svg new file mode 100644 index 0000000..b6a5104 --- /dev/null +++ b/static/images/icons/regenerate.svg @@ -0,0 +1,4 @@ + + + + diff --git a/static/images/main-illustration.svg b/static/images/main-illustration.svg new file mode 100644 index 0000000..0f2c8b8 --- /dev/null +++ b/static/images/main-illustration.svg @@ -0,0 +1,246 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/static/images/platformatic-logo.svg b/static/images/platformatic-logo.svg new file mode 100644 index 0000000..645b334 --- /dev/null +++ b/static/images/platformatic-logo.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/static/index.html b/static/index.html index f967d67..8140579 100644 --- a/static/index.html +++ b/static/index.html @@ -4,83 +4,102 @@ AI Warp + + + -
- - - + -

- - diff --git a/static/scripts/chat.js b/static/scripts/chat.js new file mode 100644 index 0000000..c0bb4cb --- /dev/null +++ b/static/scripts/chat.js @@ -0,0 +1,296 @@ +// For temporarily disabling further prompts if we're already responding to one +let isAlreadyResponding = false +const promptInput = document.getElementById('prompt-input') +const promptButton = document.getElementById('prompt-button') +const messagesElement = document.getElementById('messages') + +/** + * List of completed messages to easily keep track of them instead of making + * calls to the DOM + * + * { type: 'prompt' | 'response' | 'error', message?: string } + */ +const messages = [] + +promptButton.onclick = () => { + const prompt = promptInput.value + if (prompt === '' || isAlreadyResponding) { + return + } + + createMessageElement('prompt', prompt) + promptAiWarp(prompt).catch(err => { + throw err + }) + + promptInput.value = '' +} + +/** + * @param {KeyboardEvent} event + */ +promptInput.onkeydown = (event) => { + if (event.key === 'Enter') { + promptButton.onclick() + } +} + +const searchParams = new URL(document.location.toString()).searchParams +if (searchParams.has('prompt')) { + const prompt = searchParams.get('prompt') + createMessageElement('prompt', prompt) + promptAiWarp(prompt).catch(err => { + throw err + }) +} + +/** + * @param {string} prompt + */ +async function promptAiWarp (prompt) { + isAlreadyResponding = true + + try { + const res = await fetch('/api/v1/stream', { + method: 'POST', + headers: { + 'content-type': 'application/json' + }, + body: JSON.stringify({ prompt }) + }) + if (res.status !== 200) { + const { message, code } = await res.json() + throw new Error(`AI Warp error: ${message} (${code})`) + } + + createMessageElement('response', res.body) + } catch (err) { + createMessageElement('error') + console.error(err) + isAlreadyResponding = false + } +} + +/** + * @param {'prompt' | 'response' | 'error'} type + * @param {string | ReadableStream | undefined} 
message String if it's a prompt, ReadableStream if it's a response + */ +function createMessageElement (type, message) { + const messageElement = document.createElement('div') + messageElement.classList.add('message') + messagesElement.appendChild(messageElement) + + const avatarElement = document.createElement('div') + avatarElement.classList.add('message-avatar') + messageElement.appendChild(avatarElement) + + const avatarImg = document.createElement('img') + avatarElement.appendChild(avatarImg) + + const contentsElement = document.createElement('div') + contentsElement.classList.add('message-contents') + messageElement.appendChild(contentsElement) + + const authorElement = document.createElement('p') + authorElement.classList.add('message-author') + contentsElement.appendChild(authorElement) + + if (type === 'prompt') { + avatarImg.setAttribute('src', '/images/avatars/you.svg') + authorElement.innerHTML = 'You' + } else { + avatarImg.setAttribute('src', '/images/avatars/platformatic.svg') + authorElement.innerHTML = 'Platformatic Ai-Warp' + } + + if (type === 'error' && message === undefined) { + // Display error message + const textElement = document.createElement('p') + textElement.classList.add('message-error') + textElement.innerHTML = 'Error Something went wrong. 
If this issue persists please contact us at support@platformatic.dev' + contentsElement.appendChild(textElement) + textElement.scrollIntoView() + + messages.push({ type: 'error' }) + } else if (typeof message === 'string') { + // Echo prompt back to user + const textElement = document.createElement('p') + textElement.innerHTML = message + contentsElement.appendChild(textElement) + contentsElement.appendChild(createMessageOptionsElement('prompt', message)) + textElement.scrollIntoView() + + messages.push({ type: 'prompt', message }) + } else { + // Parse response from api + parseResponse(contentsElement, message) + .then(() => { + isAlreadyResponding = false + }) + .catch(err => { + createMessageElement('error') + console.error(err) + }) + } +} + +/** + * @param {string} response + * @returns {HTMLButtonElement} + */ +function createCopyResponseButton (response) { + const element = document.createElement('button') + element.onclick = () => { + navigator.clipboard.writeText(response) + } + + const icon = document.createElement('img') + icon.setAttribute('src', '/images/icons/copy.svg') + icon.setAttribute('alt', 'Copy') + element.appendChild(icon) + + return element +} + +/** + * @param {string} response + * @returns {HTMLButtonElement} + */ +function createRegenerateResponseButton (response) { + const element = document.createElement('button') + element.onclick = () => { + // TODO + } + + const icon = document.createElement('img') + icon.setAttribute('src', '/images/icons/regenerate.svg') + icon.setAttribute('alt', 'Regenerate') + element.appendChild(icon) + + return element +} + +/** + * @param {'prompt' | 'response'} type + * @param {string} message + * @param {number} messageIndex Index of the message in {@link messages} + * @returns {HTMLDivElement} + */ +function createMessageOptionsElement (type, message, messageIndex) { + const messageOptions = document.createElement('div') + messageOptions.classList.add('message-options') + + if (type === 'prompt') { + // 
TODO + } else if (type === 'response') { + messageOptions.appendChild(createRegenerateResponseButton(message)) + messageOptions.appendChild(createCopyResponseButton(message)) + } + + return messageOptions +} + +/** + * @param {HTMLDivElement} parentElement Parent + * @param {ReadableStream} stream To read from + */ +async function parseResponse (parentElement, stream) { + let isFirstPass = true + + let currentElement = document.createElement('p') + currentElement.innerHTML = 'Platformatic Ai-Warp is typing...' + parentElement.appendChild(currentElement) + + let fullResponse = '' + + const parser = new SSEParser(stream) + while (true) { + if (isFirstPass) { + currentElement.innerHTML = '' + isFirstPass = false + } + + const tokens = await parser.pull() + if (tokens === undefined) { + break + } + + const tokenString = tokens.join('') + fullResponse += tokenString + + const lines = tokenString.split('\n') + for (let i = 0; i < lines.length; i++) { + currentElement.innerHTML += lines[i] + + if (i + 1 < lines.length) { + currentElement = document.createElement('p') + parentElement.appendChild(currentElement) + currentElement.scrollIntoView() + } + } + } + + parentElement.appendChild(createMessageOptionsElement('response', fullResponse)) + + messages.push({ type: 'response', message: fullResponse }) +} + +/** + * Parser for server sent events returned by the streaming endpoint + */ +class SSEParser { + /** + * @param {ReadableStream} stream + */ + constructor (stream) { + this.reader = stream.getReader() + this.decoder = new TextDecoder() + } + + /** + * @returns {string[] | undefined} Undefined at the end of the stream + */ + async pull () { + const { done, value } = await this.reader.read() + if (done) { + return undefined + } + + const decodedValue = this.decoder.decode(value) + const lines = decodedValue.split('\n') + + const tokens = [] + let i = 0 + while (i < lines.length) { + const line = lines[i] + if (line.length === 0) { + i++ + continue + } + + if 
(!line.startsWith('event: ')) { + throw new Error(`Unexpected event type line: ${line}`) + } + + const dataLine = lines[i + 1] + if (!dataLine.startsWith('data: ')) { + throw new Error(`Unexpected data line: ${dataLine}`) + } + + const eventType = line.substring('event: '.length) + const data = dataLine.substring('data: '.length) + const json = JSON.parse(data) + if (eventType === 'content') { + const { response } = json + tokens.push(response) + } else if (eventType === 'error') { + const { message, code } = data + throw new Error(`AI Warp Error: ${message} (${code})`) + } + + i += 2 + } + + return tokens + } +} diff --git a/static/styles/chat.css b/static/styles/chat.css new file mode 100644 index 0000000..1a7cc0b --- /dev/null +++ b/static/styles/chat.css @@ -0,0 +1,79 @@ +#messages { + width: 50%; + height: 550px; + max-height: 550px; + margin-left: 25%; + margin-top: 50px; + overflow: scroll; + overflow-anchor: auto; +} + +.message { + display: flex; +} + +.message-avatar { + width: 5%; + padding-right: 15px; +} + +.message-contents { + width: 100%; +} + +.message-author { + margin-top: 0; + font-weight: 600; +} + +.message-error { + background-color: rgba(250, 33, 33, 0.3); + border: 1px solid #FA2121; + border-radius: 4px; + color: #FA21214D; + padding: 4px 8px 4px 8px; +} + +.message-options { + width: 100%; + background-color: rgb(234, 231, 231, 0.5); +} + +.message-options button { + border: 0; + background-color: rgba(0, 0, 0, 0); + cursor: pointer; + float: right; +} + +#prompt { + width: 50%; + margin-left: 25%; + margin-top: 50px; + display: flex; + justify-content: center; +} + +#prompt-input { + width: 95%; + background-color: rgba(0, 0, 0, 0); + color: #FFFFFF; + padding: 10px; + border: 1px solid #FFFFFFB2; + border-right: 0; + border-top-left-radius: 4px; + border-bottom-left-radius: 4px; +} + +#prompt-button { + width: 5%; + background-color: rgba(0, 0, 0, 0); + border: 1px solid #FFFFFFB2; + border-left: 0; + border-top-right-radius: 4px; + 
border-bottom-right-radius: 4px; +} + +#prompt-button:hover { + cursor: pointer; +} diff --git a/static/styles/common.css b/static/styles/common.css new file mode 100644 index 0000000..0016267 --- /dev/null +++ b/static/styles/common.css @@ -0,0 +1,43 @@ +@import url('https://fonts.googleapis.com/css2?family=Inter:wght@100..900&display=swap'); + +body { + background-color: #00050B; + padding: 0; + margin: 0; + color: white; + font-family: "Inter", sans-serif; + font-optical-sizing: auto; + font-weight: 100; + font-style: normal; + font-variation-settings: "slnt" 0; +} + +#navbar { + padding-top: 5px; + padding-bottom: 5px; + width: 100%; + border-bottom: 2px solid #FFFFFF26; +} + +#navbar-logo { + padding-left: 60px; +} + +#bottom-links { + margin-top: 30px; + /* 30px on the bottom so there's a little space between it and the end of the page */ + margin-bottom: 30px; + text-align: center; +} + +#bottom-links a { + padding-top: 12px; + padding-bottom: 12px; + padding-left: 8px; + padding-right: 8px; + margin-right: 8px; + border: 1px solid #FFFFFFB2; + border-radius: 4px; + text-decoration: none; + color: #FFFFFF; +} diff --git a/static/styles/index.css b/static/styles/index.css new file mode 100644 index 0000000..2c1a02a --- /dev/null +++ b/static/styles/index.css @@ -0,0 +1,84 @@ +#main-illustration { + margin-top: 60px; + width: 100%; + text-align: center; + position: absolute; +} + +#greeting { + padding-top: 450px; + text-align: center; +} + +#prompt-suggestions { + width: 50%; + margin-left: 25%; + padding-top: 100px; + display: flex; +} + +.prompt-suggestion-column { + width: 100%; + margin-right: 12px; + display: flex; + flex-direction: column; + align-content: right; +} + +.prompt-suggestion { + width: 100%; + margin-top: 8px; + border: 1px solid #FFFFFF4D; + border-radius: 4px; +} + +.prompt-suggestion p { + margin-left: 8px; + margin-right: 8px; +} + +.prompt-suggestion-title { + margin-top: 8px; + font-weight: 600; +} + +.prompt-suggestion-subtitle { + 
color: #FFFFFFB2; + margin-bottom: 8px; +} + +.prompt-suggestion-subtitle a { + float: right; +} + +#custom-prompt { + width: 50%; + margin-left: 25%; + margin-top: 50px; + display: flex; + justify-content: center; +} + +#custom-prompt-input { + width: 95%; + background-color: rgba(0, 0, 0, 0); + color: #FFFFFF; + padding: 10px; + border: 1px solid #FFFFFFB2; + border-right: 0; + border-top-left-radius: 4px; + border-bottom-left-radius: 4px; +} + +#custom-prompt-button { + width: 5%; + background-color: rgba(0, 0, 0, 0); + border: 1px solid #FFFFFFB2; + border-left: 0; + border-top-right-radius: 4px; + border-bottom-right-radius: 4px; +} + +#custom-prompt-button:hover { + cursor: pointer; +}