Streaming
flakey5 committed Mar 26, 2024
1 parent a7c1f36 commit a617e64
Showing 9 changed files with 263 additions and 9 deletions.
37 changes: 37 additions & 0 deletions ai-providers/event.ts
@@ -0,0 +1,37 @@
import { FastifyError } from 'fastify'
import fastJson from 'fast-json-stringify'

const stringifyEventData = fastJson({
  title: 'Stream Event Data',
  type: 'object',
  properties: {
    // Success
    response: { type: 'string' },
    // Error
    code: { type: 'string' },
    message: { type: 'string' }
  }
})

export interface AiStreamEventContent {
  response: string
}

export type AiStreamEvent = {
  event: 'content'
  data: AiStreamEventContent
} | {
  event: 'error'
  data: FastifyError
}

/**
 * Encode an event to the Event Stream format
 * @see https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
 */
export function encodeEvent ({ event, data }: AiStreamEvent): Uint8Array {
  const jsonString = stringifyEventData(data)
  const eventString = `event: ${event}\ndata: ${jsonString}\n\n`

  return new TextEncoder().encode(eventString)
}
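For illustration (not part of the commit), a minimal sketch of what encodeEvent emits; the expected output assumes fast-json-stringify omits properties that are undefined:

import { encodeEvent } from './event'

// Encode a success event and inspect the raw SSE frame.
const frame = encodeEvent({ event: 'content', data: { response: 'Hello' } })
console.log(new TextDecoder().decode(frame))
// event: content
// data: {"response":"Hello"}
// (a blank line terminates the frame)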
69 changes: 68 additions & 1 deletion ai-providers/mistral.ts
@@ -1,4 +1,56 @@
-import { AiProvider, NoContentError } from './provider'
+import { ReadableStream, UnderlyingByteSource, ReadableByteStreamController } from 'node:stream/web'
+import { ChatCompletionResponseChunk } from '@mistralai/mistralai'
+import { AiProvider, NoContentError, StreamChunkCallback } from './provider'
+import { AiStreamEvent, encodeEvent } from './event'

type MistralStreamResponse = AsyncGenerator<ChatCompletionResponseChunk, void, unknown>

class MistralByteSource implements UnderlyingByteSource {
  type: 'bytes' = 'bytes'
  response: MistralStreamResponse
  chunkCallback?: StreamChunkCallback

  constructor (response: MistralStreamResponse, chunkCallback?: StreamChunkCallback) {
    this.response = response
    this.chunkCallback = chunkCallback
  }

  async pull (controller: ReadableByteStreamController): Promise<void> {
    const { done, value } = await this.response.next()
    if (done !== undefined && done) {
      controller.close()
      return
    }

    if (value.choices.length === 0) {
      const error = new NoContentError('Mistral (Stream)')

      const eventData: AiStreamEvent = {
        event: 'error',
        data: error
      }
      controller.enqueue(encodeEvent(eventData))
      controller.close()

      return
    }

    const { content } = value.choices[0].delta

    let response = content ?? ''
    if (this.chunkCallback !== undefined) {
      response = await this.chunkCallback(response)
    }

    const eventData: AiStreamEvent = {
      event: 'content',
      data: {
        response
      }
    }
    controller.enqueue(encodeEvent(eventData))
  }
}

export class MistralProvider implements AiProvider {
  model: string
@@ -29,4 +81,19 @@ export class MistralProvider implements AiProvider {

    return response.choices[0].message.content
  }

  async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise<ReadableStream> {
    if (this.client === undefined) {
      const { default: MistralClient } = await import('@mistralai/mistralai')
      this.client = new MistralClient(this.apiKey)
    }

    const response = this.client.chatStream({
      model: this.model,
      messages: [
        { role: 'user', content: prompt }
      ]
    })
    return new ReadableStream(new MistralByteSource(response, chunkCallback))
  }
}
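A minimal consumption sketch (not part of the commit; the MistralProvider constructor arguments are an assumption, inferred from the model and apiKey fields rather than shown in this diff). Node's web ReadableStream is async-iterable, so the encoded SSE frames can be read directly:

import { MistralProvider } from './mistral'

// Assumed constructor shape: (model, apiKey).
const provider = new MistralProvider('open-mistral-7b', process.env.MISTRAL_API_KEY ?? '')

const stream = await provider.askStream('Write a haiku about streams')
const decoder = new TextDecoder()
for await (const frame of stream) {
  // Each enqueued chunk is one SSE frame: "event: content\ndata: {...}\n\n"
  process.stdout.write(decoder.decode(frame))
}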
96 changes: 93 additions & 3 deletions ai-providers/open-ai.ts
@@ -1,5 +1,84 @@
+import { ReadableStream, UnderlyingByteSource, ReadableByteStreamController } from 'node:stream/web'
 import OpenAI from 'openai'
-import { AiProvider, NoContentError } from './provider'
+import { AiProvider, NoContentError, StreamChunkCallback } from './provider'
+import { ReadableStream as ReadableStreamPolyfill } from 'web-streams-polyfill'
+import { ChatCompletionChunk } from 'openai/resources/index.mjs'
+import { AiStreamEvent, encodeEvent } from './event'
+import createError from '@fastify/error'

const InvalidTypeError = createError<string>('DESERIALIZING_ERROR', 'Deserializing error: %s', 500)

class OpenAiByteSource implements UnderlyingByteSource {
  type: 'bytes' = 'bytes'
  polyfillStream: ReadableStreamPolyfill
  reader?: ReadableStreamDefaultReader
  chunkCallback?: StreamChunkCallback

  constructor (polyfillStream: ReadableStreamPolyfill, chunkCallback?: StreamChunkCallback) {
    this.polyfillStream = polyfillStream
    this.chunkCallback = chunkCallback
  }

  start (): void {
    this.reader = this.polyfillStream.getReader()
  }

  async pull (controller: ReadableByteStreamController): Promise<void> {
    // start() defines this.reader and is always called before pull()
    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
    const { done, value } = await this.reader!.read()

    if (done !== undefined && done) {
      controller.close()
      return
    }

    if (!(value instanceof Uint8Array)) {
      // This really shouldn't happen, but guard anyway (and it satisfies TypeScript)
      const error = new InvalidTypeError('OpenAI stream value not a Uint8Array')

      const eventData: AiStreamEvent = {
        event: 'error',
        data: error
      }
      controller.enqueue(encodeEvent(eventData))
      controller.close()

      return
    }

    const jsonString = Buffer.from(value).toString('utf8')
    const chunk: ChatCompletionChunk = JSON.parse(jsonString)

    if (chunk.choices.length === 0) {
      const error = new NoContentError('OpenAI stream')

      const eventData: AiStreamEvent = {
        event: 'error',
        data: error
      }
      controller.enqueue(encodeEvent(eventData))
      controller.close()

      return
    }

    const { content } = chunk.choices[0].delta

    let response = content ?? ''
    if (this.chunkCallback !== undefined) {
      response = await this.chunkCallback(response)
    }

    const eventData: AiStreamEvent = {
      event: 'content',
      data: {
        response
      }
    }
    controller.enqueue(encodeEvent(eventData))
  }
}

export class OpenAiProvider implements AiProvider {
  model: string
@@ -15,11 +94,11 @@ export class OpenAiProvider implements AiProvider {
       model: this.model,
       messages: [
         { role: 'user', content: prompt }
-      ]
+      ],
+      stream: false
     })

     if (response.choices.length === 0) {
-      // TODO: figure out error handling strategy
       throw new NoContentError('OpenAI')
     }

@@ -30,4 +109,15 @@

    return content
  }

  async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise<ReadableStream> {
    const response = await this.client.chat.completions.create({
      model: this.model,
      messages: [
        { role: 'user', content: prompt }
      ],
      stream: true
    })
    return new ReadableStream(new OpenAiByteSource(response.toReadableStream(), chunkCallback))
  }
}
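For orientation (illustrative, not part of the commit): the SDK's toReadableStream() emits each ChatCompletionChunk as one JSON-encoded line, which is why pull() can decode a Uint8Array and JSON.parse it directly. A made-up, abridged example of one such line:

import { ChatCompletionChunk } from 'openai/resources/index.mjs'

const line =
  '{"id":"chatcmpl-abc","object":"chat.completion.chunk","created":1711400000,' +
  '"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}'

const chunk: ChatCompletionChunk = JSON.parse(line)
console.log(chunk.choices[0].delta.content) // "Hi"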
4 changes: 4 additions & 0 deletions ai-providers/provider.ts
@@ -1,7 +1,11 @@
import { ReadableStream } from 'node:stream/web'
import createError from '@fastify/error'

export type StreamChunkCallback = (response: string) => Promise<string>

export interface AiProvider {
  ask: (prompt: string) => Promise<string>
  askStream: (prompt: string, chunkCallback?: StreamChunkCallback) => Promise<ReadableStream>
}

export const NoContentError = createError<[string]>('NO_CONTENT', '%s didn\'t return any content')
3 changes: 3 additions & 0 deletions index.d.ts
@@ -1,3 +1,4 @@
import { ReadableStream } from 'node:stream/web'
import { PlatformaticApp } from '@platformatic/service'
import { AiWarpConfig } from './config'

@@ -6,7 +7,9 @@ declare module 'fastify' {
    platformatic: PlatformaticApp<AiWarpConfig>
    ai: {
      warp: (request: FastifyRequest, prompt: string) => Promise<string>
      warpStream: (request: FastifyRequest, prompt: string) => Promise<ReadableStream>
      preResponseCallback?: ((request: FastifyRequest, response: string) => string) | ((request: FastifyRequest, response: string) => Promise<string>)
      preResponseChunkCallback?: ((request: FastifyRequest, response: string) => string) | ((request: FastifyRequest, response: string) => Promise<string>)
      rateLimiting: {
        max?: ((req: FastifyRequest, key: string) => number) | ((req: FastifyRequest, key: string) => Promise<number>)
        allowList?: (req: FastifyRequest, key: string) => boolean | Promise<boolean>
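To make the new decorations concrete, a hypothetical user plugin (not part of the commit) could register a chunk callback like this; the name and signature come from the typings above:

import fastifyPlugin from 'fastify-plugin'

export default fastifyPlugin(async (fastify) => {
  // Runs on every streamed chunk before it is encoded into an SSE frame.
  fastify.ai.preResponseChunkCallback = async (_request, response) => {
    return response.toUpperCase()
  }
})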
1 change: 1 addition & 0 deletions package-lock.json


1 change: 1 addition & 0 deletions package.json
@@ -28,6 +28,7 @@
"@platformatic/config": "^1.24.0",
"@platformatic/generators": "^1.24.0",
"@platformatic/service": "^1.24.0",
"fast-json-stringify": "^5.13.0",
"fastify-user": "^0.3.3",
"json-schema-to-typescript": "^13.0.0",
"openai": "^4.28.4",
39 changes: 35 additions & 4 deletions plugins/api.ts
@@ -29,19 +29,50 @@ const plugin: FastifyPluginAsyncTypebox = async (fastify) => {
       }
     },
     handler: async (request) => {
-      let response: string
       try {
         const { prompt } = request.body
-        response = await fastify.ai.warp(request, prompt)
+        const response = await fastify.ai.warp(request, prompt)
+
+        return { response }
       } catch (exception) {
         if (exception instanceof Object && isAFastifyError(exception)) {
           return exception
         } else {
-          return new InternalServerError()
+          const err = new InternalServerError()
+          err.cause = exception
+          throw err
         }
       }
-
-      return { response }
     }
   })
+
+  fastify.route({
+    url: '/api/v1/stream',
+    method: 'POST',
+    schema: {
+      produces: ['text/event-stream; charset=utf-16'],
+      body: Type.Object({
+        prompt: Type.String()
+      })
+    },
+    handler: async (request, reply) => {
+      try {
+        const { prompt } = request.body
+
+        const response = await fastify.ai.warpStream(request, prompt)
+        // eslint-disable-next-line @typescript-eslint/no-floating-promises
+        reply.header('content-type', 'text/event-stream')
+
+        return response
+      } catch (exception) {
+        if (exception instanceof Object && isAFastifyError(exception)) {
+          return exception
+        } else {
+          const err = new InternalServerError()
+          err.cause = exception
+          throw err
+        }
+      }
+    }
+  })
 }
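A hypothetical client for the new endpoint (not part of the commit; the host and port assume a default local Platformatic instance):

const res = await fetch('http://127.0.0.1:3042/api/v1/stream', {
  method: 'POST',
  headers: { 'content-type': 'application/json' },
  body: JSON.stringify({ prompt: 'Hello!' })
})

// In Node, the fetch body is a web ReadableStream and is async-iterable.
const decoder = new TextDecoder()
for await (const chunk of res.body as unknown as AsyncIterable<Uint8Array>) {
  process.stdout.write(decoder.decode(chunk)) // raw "event: ...\ndata: ...\n\n" frames
}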
22 changes: 21 additions & 1 deletion plugins/warp.ts
@@ -3,7 +3,7 @@
 import fastifyPlugin from 'fastify-plugin'
 import { OpenAiProvider } from '../ai-providers/open-ai'
 import { MistralProvider } from '../ai-providers/mistral.js'
-import { AiProvider } from '../ai-providers/provider'
+import { AiProvider, StreamChunkCallback } from '../ai-providers/provider'
 import { AiWarpConfig } from '../config'
 import createError from '@fastify/error'

@@ -40,6 +40,26 @@ export default fastifyPlugin(async (fastify) => {

      return response
    },
    warpStream: async (request, prompt) => {
      let decoratedPrompt = prompt
      if (config.promptDecorators !== undefined) {
        const { prefix, suffix } = config.promptDecorators
        decoratedPrompt = (prefix ?? '') + decoratedPrompt + (suffix ?? '')
      }

      let chunkCallback: StreamChunkCallback | undefined
      if (fastify.ai.preResponseChunkCallback !== undefined) {
        chunkCallback = async (response) => {
          if (fastify.ai.preResponseChunkCallback === undefined) {
            return response
          }
          return await fastify.ai.preResponseChunkCallback(request, response)
        }
      }

      const response = await provider.askStream(decoratedPrompt, chunkCallback)
      return response
    },
    rateLimiting: {}
  })
})
