Commit e62799b: Support providing chat history

flakey5 committed May 11, 2024
1 parent 0b83a71 commit e62799b
Showing 10 changed files with 115 additions and 31 deletions.
24 changes: 20 additions & 4 deletions ai-providers/azure.ts

@@ -1,7 +1,7 @@
 import { ReadableStream, ReadableByteStreamController, UnderlyingByteSource } from 'stream/web'
-import { AiProvider, NoContentError, StreamChunkCallback } from './provider.js'
+import { AiProvider, ChatHistory, NoContentError, StreamChunkCallback } from './provider.js'
 import { AiStreamEvent, encodeEvent } from './event.js'
-import { AzureKeyCredential, ChatCompletions, EventStream, OpenAIClient } from '@azure/openai'
+import { AzureKeyCredential, ChatCompletions, ChatRequestMessageUnion, EventStream, OpenAIClient } from '@azure/openai'

 type AzureStreamResponse = EventStream<ChatCompletions>

@@ -95,8 +95,9 @@ export class AzureProvider implements AiProvider {
     )
   }

-  async ask (prompt: string): Promise<string> {
+  async ask (prompt: string, chatHistory?: ChatHistory): Promise<string> {
     const { choices } = await this.client.getChatCompletions(this.deploymentName, [
+      ...this.chatHistoryToMessages(chatHistory),
       { role: 'user', content: prompt }
     ])

@@ -112,11 +113,26 @@ export class AzureProvider implements AiProvider {
     return message.content
   }

-  async askStream (prompt: string, chunkCallback?: StreamChunkCallback | undefined): Promise<ReadableStream> {
+  async askStream (prompt: string, chunkCallback?: StreamChunkCallback, chatHistory?: ChatHistory): Promise<ReadableStream> {
     const response = await this.client.streamChatCompletions(this.deploymentName, [
+      ...this.chatHistoryToMessages(chatHistory),
       { role: 'user', content: prompt }
     ])

     return new ReadableStream(new AzureByteSource(response, chunkCallback))
   }
+
+  private chatHistoryToMessages(chatHistory?: ChatHistory): ChatRequestMessageUnion[] {
+    if (chatHistory === undefined) {
+      return []
+    }
+
+    const messages: ChatRequestMessageUnion[] = []
+    for (const previousInteraction of chatHistory) {
+      messages.push({ role: 'user', content: previousInteraction.prompt })
+      messages.push({ role: 'assistant', content: previousInteraction.response })
+    }
+
+    return messages
+  }
 }

Check failure (GitHub Actions / Linting) at ai-providers/azure.ts line 125: Missing space before function parentheses
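
The chatHistoryToMessages helper added here (and duplicated in the Mistral, Ollama, and OpenAI providers below) flattens each prior prompt/response pair into alternating user/assistant messages, oldest first, so the model sees the conversation in order. A minimal standalone sketch of the same transformation, assuming only the ChatHistory shape from ai-providers/provider.ts:

import { ChatHistory } from './provider.js'

interface Message {
  role: 'user' | 'assistant'
  content: string
}

// Each { prompt, response } pair expands to two messages, preserving order.
function chatHistoryToMessages (chatHistory?: ChatHistory): Message[] {
  if (chatHistory === undefined) {
    return []
  }
  return chatHistory.flatMap<Message>(({ prompt, response }) => [
    { role: 'user', content: prompt },
    { role: 'assistant', content: response }
  ])
}

// chatHistoryToMessages([{ prompt: 'What is 2 + 2?', response: '4' }])
// -> [ { role: 'user', content: 'What is 2 + 2?' },
//      { role: 'assistant', content: '4' } ]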
17 changes: 12 additions & 5 deletions ai-providers/llama2.ts

@@ -1,12 +1,13 @@
 import { ReadableByteStreamController, ReadableStream, UnderlyingByteSource } from 'stream/web'
 import { FastifyLoggerInstance } from 'fastify'
 import {
+  ConversationInteraction,
   LLamaChatPromptOptions,
   LlamaChatSession,
   LlamaContext,
   LlamaModel
 } from 'node-llama-cpp'
-import { AiProvider, StreamChunkCallback } from './provider.js'
+import { AiProvider, ChatHistory, StreamChunkCallback } from './provider.js'
 import { AiStreamEvent, encodeEvent } from './event.js'

 interface ChunkQueueNode {

@@ -174,15 +175,21 @@ export class Llama2Provider implements AiProvider {
     this.logger = logger
   }

-  async ask (prompt: string): Promise<string> {
-    const session = new LlamaChatSession({ context: this.context })
+  async ask (prompt: string, chatHistory?: ChatHistory): Promise<string> {
+    const session = new LlamaChatSession({
+      context: this.context,
+      conversationHistory: chatHistory
+    })
     const response = await session.prompt(prompt)

     return response
   }

-  async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise<ReadableStream> {
-    const session = new LlamaChatSession({ context: this.context })
+  async askStream (prompt: string, chunkCallback?: StreamChunkCallback, chatHistory?: ChatHistory): Promise<ReadableStream> {
+    const session = new LlamaChatSession({
+      context: this.context,
+      conversationHistory: chatHistory
+    })

     return new ReadableStream(new Llama2ByteSource(session, prompt, this.logger, chunkCallback))
   }

Check failure (GitHub Actions / Linting) at ai-providers/llama2.ts line 4: 'ConversationInteraction' is defined but never used
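
Unlike the other providers, Llama2 needs no message-flattening helper: node-llama-cpp's LlamaChatSession takes the history directly through its conversationHistory option, and ChatHistory's { prompt, response } entries already appear to satisfy its ConversationInteraction type, which is presumably why the newly imported ConversationInteraction is never referenced and trips the lint rule above. A hedged sketch of the pass-through (model and context setup omitted):

import { LlamaChatSession, LlamaContext } from 'node-llama-cpp'

// Sketch only: assumes `context` was already built from a loaded LlamaModel.
async function askWithHistory (context: LlamaContext, prompt: string): Promise<string> {
  const session = new LlamaChatSession({
    context,
    // Prior turns go straight in; no role/content conversion is required.
    conversationHistory: [
      { prompt: 'Name a prime number.', response: '7' }
    ]
  })
  return await session.prompt(prompt)
}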
22 changes: 19 additions & 3 deletions ai-providers/mistral.ts

@@ -1,6 +1,6 @@
 import { ReadableStream, UnderlyingByteSource, ReadableByteStreamController } from 'node:stream/web'
 import MistralClient, { ChatCompletionResponseChunk } from '@platformatic/mistral-client'
-import { AiProvider, NoContentError, StreamChunkCallback } from './provider.js'
+import { AiProvider, ChatHistory, NoContentError, StreamChunkCallback } from './provider.js'
 import { AiStreamEvent, encodeEvent } from './event.js'

 type MistralStreamResponse = AsyncGenerator<ChatCompletionResponseChunk, void, unknown>

@@ -66,10 +66,11 @@ export class MistralProvider implements AiProvider {
     this.client = new MistralClient(apiKey)
   }

-  async ask (prompt: string): Promise<string> {
+  async ask (prompt: string, chatHistory?: ChatHistory): Promise<string> {
     const response = await this.client.chat({
       model: this.model,
       messages: [
+        ...this.chatHistoryToMessages(chatHistory),
         { role: 'user', content: prompt }
       ]
     })

@@ -81,13 +82,28 @@ export class MistralProvider implements AiProvider {
     return response.choices[0].message.content
   }

-  async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise<ReadableStream> {
+  async askStream (prompt: string, chunkCallback?: StreamChunkCallback, chatHistory?: ChatHistory): Promise<ReadableStream> {
     const response = this.client.chatStream({
       model: this.model,
       messages: [
+        ...this.chatHistoryToMessages(chatHistory),
         { role: 'user', content: prompt }
       ]
     })
     return new ReadableStream(new MistralByteSource(response, chunkCallback))
   }
+
+  private chatHistoryToMessages(chatHistory?: ChatHistory): { role: string, content: string }[] {
+    if (chatHistory === undefined) {
+      return []
+    }
+
+    const messages: { role: string, content: string }[] = []
+    for (const previousInteraction of chatHistory) {
+      messages.push({ role: 'user', content: previousInteraction.prompt })
+      messages.push({ role: 'assistant', content: previousInteraction.response })
+    }
+
+    return messages
+  }
 }

Check failures (GitHub Actions / Linting) in ai-providers/mistral.ts:
  line 96: Missing space before function parentheses
  line 96: Array type using 'T[]' is forbidden for non-simple types. Use 'Array<T>' instead
  line 101: Array type using 'T[]' is forbidden for non-simple types. Use 'Array<T>' instead
24 changes: 20 additions & 4 deletions ai-providers/ollama.ts

@@ -1,6 +1,6 @@
 import { ReadableStream, UnderlyingByteSource, ReadableByteStreamController } from 'stream/web'
-import { Ollama, ChatResponse } from 'ollama'
-import { AiProvider, StreamChunkCallback } from './provider.js'
+import { Ollama, ChatResponse, Message } from 'ollama'
+import { AiProvider, ChatHistory, StreamChunkCallback } from './provider.js'
 import { AiStreamEvent, encodeEvent } from './event.js'

 type OllamaStreamResponse = AsyncGenerator<ChatResponse>

@@ -51,26 +51,42 @@ export class OllamaProvider implements AiProvider {
     this.client = new Ollama({ host })
   }

-  async ask (prompt: string): Promise<string> {
+  async ask (prompt: string, chatHistory?: ChatHistory): Promise<string> {
     const response = await this.client.chat({
       model: this.model,
       messages: [
+        ...this.chatHistoryToMessages(chatHistory),
         { role: 'user', content: prompt }
       ]
     })

     return response.message.content
   }

-  async askStream (prompt: string, chunkCallback?: StreamChunkCallback | undefined): Promise<ReadableStream> {
+  async askStream (prompt: string, chunkCallback?: StreamChunkCallback, chatHistory?: ChatHistory): Promise<ReadableStream> {
     const response = await this.client.chat({
       model: this.model,
       messages: [
+        ...this.chatHistoryToMessages(chatHistory),
         { role: 'user', content: prompt }
       ],
       stream: true
     })

     return new ReadableStream(new OllamaByteSource(response, chunkCallback))
   }
+
+  private chatHistoryToMessages(chatHistory?: ChatHistory): Message[] {
+    if (chatHistory === undefined) {
+      return []
+    }
+
+    const messages: Message[] = []
+    for (const previousInteraction of chatHistory) {
+      messages.push({ role: 'user', content: previousInteraction.prompt })
+      messages.push({ role: 'assistant', content: previousInteraction.response })
+    }
+
+    return messages
+  }
 }

Check failure (GitHub Actions / Linting) at ai-providers/ollama.ts line 79: Missing space before function parentheses
24 changes: 20 additions & 4 deletions ai-providers/open-ai.ts

@@ -1,9 +1,9 @@
 import { ReadableStream, UnderlyingByteSource, ReadableByteStreamController } from 'node:stream/web'
 import OpenAI from 'openai'
-import { AiProvider, NoContentError, StreamChunkCallback } from './provider.js'
+import { AiProvider, ChatHistory, NoContentError, StreamChunkCallback } from './provider.js'
 import { ReadableStream as ReadableStreamPolyfill } from 'web-streams-polyfill'
 import { fetch } from 'undici'
-import { ChatCompletionChunk } from 'openai/resources/index'
+import { ChatCompletionChunk, ChatCompletionMessageParam } from 'openai/resources/index'
 import { AiStreamEvent, encodeEvent } from './event.js'
 import createError from '@fastify/error'

@@ -96,10 +96,11 @@ export class OpenAiProvider implements AiProvider {
     this.client = new OpenAI({ apiKey, fetch })
   }

-  async ask (prompt: string): Promise<string> {
+  async ask (prompt: string, chatHistory?: ChatHistory): Promise<string> {
     const response = await this.client.chat.completions.create({
       model: this.model,
       messages: [
+        ...this.chatHistoryToMessages(chatHistory),
         { role: 'user', content: prompt }
       ],
       stream: false

@@ -117,14 +118,29 @@ export class OpenAiProvider implements AiProvider {
     return content
   }

-  async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise<ReadableStream> {
+  async askStream (prompt: string, chunkCallback?: StreamChunkCallback, chatHistory?: ChatHistory): Promise<ReadableStream> {
     const response = await this.client.chat.completions.create({
       model: this.model,
       messages: [
+        ...this.chatHistoryToMessages(chatHistory),
         { role: 'user', content: prompt }
       ],
       stream: true
     })
     return new ReadableStream(new OpenAiByteSource(response.toReadableStream() as ReadableStreamPolyfill, chunkCallback))
   }
+
+  private chatHistoryToMessages(chatHistory?: ChatHistory): ChatCompletionMessageParam[] {
+    if (chatHistory === undefined) {
+      return []
+    }
+
+    const messages: ChatCompletionMessageParam[] = []
+    for (const previousInteraction of chatHistory) {
+      messages.push({ role: 'user', content: previousInteraction.prompt })
+      messages.push({ role: 'assistant', content: previousInteraction.response })
+    }
+
+    return messages
+  }
 }

Check failure (GitHub Actions / Linting) at ai-providers/open-ai.ts line 133: Missing space before function parentheses
6 changes: 4 additions & 2 deletions ai-providers/provider.ts

@@ -1,11 +1,13 @@
 import { ReadableStream } from 'node:stream/web'
 import createError from '@fastify/error'

+export type ChatHistory = { prompt: string, response: string }[]
+
 export type StreamChunkCallback = (response: string) => Promise<string>

 export interface AiProvider {
-  ask: (prompt: string) => Promise<string>
-  askStream: (prompt: string, chunkCallback?: StreamChunkCallback) => Promise<ReadableStream>
+  ask: (prompt: string, chatHistory?: ChatHistory) => Promise<string>
+  askStream: (prompt: string, chunkCallback?: StreamChunkCallback, chatHistory?: ChatHistory) => Promise<ReadableStream>
 }

 export const NoContentError = createError<[string]>('NO_CONTENT', '%s didn\'t return any content')

Check failure (GitHub Actions / Linting) at ai-providers/provider.ts line 4: Array type using 'T[]' is forbidden for non-simple types. Use 'Array<T>' instead
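
Every provider now takes the optional history as a trailing parameter, so existing single-turn callers stay source-compatible. A usage sketch against the interface (the concrete provider instance and its construction are assumed):

import { AiProvider, ChatHistory } from './provider.js'

// Earlier turns ride along with the new prompt, letting the provider
// rebuild the conversation before querying the model.
async function continueConversation (provider: AiProvider, prompt: string): Promise<string> {
  const chatHistory: ChatHistory = [
    { prompt: 'What is the capital of France?', response: 'Paris.' }
  ]
  // With 'Paris.' in context, a follow-up like 'How many people live there?'
  // can resolve "there" correctly.
  return await provider.ask(prompt, chatHistory)
}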
9 changes: 7 additions & 2 deletions index.d.ts

@@ -3,19 +3,24 @@
 import { PlatformaticApp } from '@platformatic/service'
 import { errorResponseBuilderContext } from '@fastify/rate-limit'
 import { AiWarpConfig } from './config.js'

+type ChatHistory = {
+  prompt: string,
+  response: string
+}[]
+
 declare module 'fastify' {
   interface FastifyInstance {
     platformatic: PlatformaticApp<AiWarpConfig>
     ai: {
       /**
        * Send a prompt to the AI provider and receive the full response.
        */
-      warp: (request: FastifyRequest, prompt: string) => Promise<string>
+      warp: (request: FastifyRequest, prompt: string, chatHistory?: ChatHistory) => Promise<string>

       /**
        * Send a prompt to the AI provider and receive a streamed response.
        */
-      warpStream: (request: FastifyRequest, prompt: string) => Promise<ReadableStream>
+      warpStream: (request: FastifyRequest, prompt: string, chatHistory?: ChatHistory) => Promise<ReadableStream>

       /**
        * A function to be called before warp() returns it's result. It can

Check failures (GitHub Actions / Linting) in index.d.ts:
  line 6: Array type using 'T[]' is forbidden for non-simple types. Use 'Array<T>' instead
  line 7: Unexpected separator
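
Downstream code can now thread history through the decorator. An illustrative handler, assuming it runs inside an AI Warp app where the ai decoration declared above is registered (the route path and history values are made up):

import fastify from 'fastify'

const app = fastify()

// Assumes the AI Warp plugin has been registered, so app.ai.warp exists.
app.post('/follow-up', async (request) => {
  const chatHistory = [
    { prompt: 'What is the capital of France?', response: 'Paris.' }
  ]
  return await app.ai.warp(request, 'How many people live there?', chatHistory)
})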
6 changes: 5 additions & 1 deletion plugins/api.ts

@@ -16,7 +16,11 @@ const plugin: FastifyPluginAsyncTypebox = async (fastify) => {
     method: 'POST',
     schema: {
       body: Type.Object({
-        prompt: Type.String()
+        prompt: Type.String(),
+        chatHistory: Type.Optional(Type.Array(Type.Object({
+          prompt: Type.String(),
+          response: Type.String()
+        })))
       }),
       response: {
         200: Type.Object({
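
With the schema extended, a client may attach prior turns to the prompt request. A hedged example call (the localhost port and the /api/v1/prompt path are assumptions; the route path is not visible in this hunk):

// chatHistory must be an array of { prompt, response } objects
// to pass the Typebox schema above.
const res = await fetch('http://localhost:3042/api/v1/prompt', {
  method: 'POST',
  headers: { 'content-type': 'application/json' },
  body: JSON.stringify({
    prompt: 'And in TypeScript?',
    chatHistory: [
      { prompt: 'Show me hello world in JavaScript.', response: "console.log('hello world')" }
    ]
  })
})
console.log(await res.json())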
8 changes: 4 additions & 4 deletions plugins/warp.ts

@@ -37,21 +37,21 @@ export default fastifyPlugin(async (fastify) => {
   const provider = await build(config.aiProvider, fastify.log)

   fastify.decorate('ai', {
-    warp: async (request, prompt) => {
+    warp: async (request, prompt, chatHistory) => {
       let decoratedPrompt = prompt
       if (config.promptDecorators !== undefined) {
         const { prefix, suffix } = config.promptDecorators
         decoratedPrompt = (prefix ?? '') + decoratedPrompt + (suffix ?? '')
       }

-      let response = await provider.ask(decoratedPrompt)
+      let response = await provider.ask(decoratedPrompt, chatHistory)
       if (fastify.ai.preResponseCallback !== undefined) {
         response = await fastify.ai.preResponseCallback(request, response) ?? response
       }

       return response
     },
-    warpStream: async (request, prompt) => {
+    warpStream: async (request, prompt, chatHistory) => {
       let decoratedPrompt = prompt
       if (config.promptDecorators !== undefined) {
         const { prefix, suffix } = config.promptDecorators

@@ -68,7 +68,7 @@ export default fastifyPlugin(async (fastify) => {
         }
       }

-      const response = await provider.askStream(decoratedPrompt, chunkCallback)
+      const response = await provider.askStream(decoratedPrompt, chunkCallback, chatHistory)
       return response
     },
     rateLimiting: {}
6 changes: 4 additions & 2 deletions static/scripts/chat.js

@@ -66,8 +66,10 @@ async function promptAiWarp (message) {
     headers: {
       'content-type': 'application/json'
     },
-    // TODO chat history
-    body: JSON.stringify({ prompt: message.prompt })
+    body: JSON.stringify({
+      prompt: message.prompt,
+      chatHistory: messages
+    })
   })
   if (res.status !== 200) {
     const { message, code } = await res.json()
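
This removes the client-side TODO: the front end now ships its accumulated messages array as chatHistory. For the request to satisfy the schema in plugins/api.ts, each entry of messages is assumed to already be a { prompt, response } pair, roughly:

// Assumed client-side bookkeeping; the real structure of `messages`
// lives elsewhere in chat.js and is not shown in this diff.
const messages: Array<{ prompt: string, response: string }> = []

function recordInteraction (prompt: string, response: string): void {
  messages.push({ prompt, response })
}

recordInteraction('Hi!', 'Hello! How can I help you today?')
// Next request body: JSON.stringify({ prompt, chatHistory: messages })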
