Add Llama2 provider
Closes #20
flakey5 committed Apr 28, 2024
1 parent b7fb651 commit e0d2507
Showing 10 changed files with 1,668 additions and 1 deletion.
173 changes: 173 additions & 0 deletions ai-providers/llama2.ts
@@ -0,0 +1,173 @@
import { ReadableByteStreamController, ReadableStream, UnderlyingByteSource } from 'stream/web'
import {
LLamaChatPromptOptions,
LlamaChatSession,
LlamaContext,
LlamaModel
} from 'node-llama-cpp'
import { AiProvider, StreamChunkCallback } from './provider.js'
import { AiStreamEvent, encodeEvent } from './event.js'

interface ChunkQueueNode {
chunk: number[]
next?: ChunkQueueNode
}

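// FIFO queue of token chunks, backed by a singly linked list so pushes and pops are O(1);
// used to buffer tokens that arrive before the ReadableStream has started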
class ChunkQueue {
private size: number = 0
private head?: ChunkQueueNode
private tail?: ChunkQueueNode

getSize (): number {
return this.size
}

push (chunk: number[]): void {
this.size++

const node: ChunkQueueNode = { chunk }
if (this.head === undefined || this.tail === undefined) {
this.head = node
this.tail = node
} else {
this.tail.next = node
this.tail = node
}
}

pop (): number[] | undefined {
if (this.head === undefined) {
return undefined
}

this.size--

const chunk = this.head.chunk
this.head = this.head.next

if (this.size === 0) {
this.tail = undefined
}

return chunk
}
}

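// Underlying byte source that bridges node-llama-cpp's onToken callback to a
// ReadableStream, queueing any tokens emitted before start() supplies a controller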
class Llama2ByteSource implements UnderlyingByteSource {
type: 'bytes' = 'bytes'
session: LlamaChatSession
chunkCallback?: StreamChunkCallback
backloggedChunks: ChunkQueue = new ChunkQueue()
finished: boolean = false
controller?: ReadableByteStreamController

constructor (session: LlamaChatSession, prompt: string, chunkCallback?: StreamChunkCallback) {
this.session = session
this.chunkCallback = chunkCallback

session.prompt(prompt, {
onToken: this.onToken
}).then(() => {
this.finished = true
// Don't close the stream if we still have chunks to send
if (this.backloggedChunks.getSize() === 0 && this.controller !== undefined) {
this.controller.close()
}
}).catch((err: any) => {
this.finished = true
if (this.controller !== undefined) {
this.controller.close()
}
throw err
})
}

onToken: LLamaChatPromptOptions['onToken'] = async (chunk) => {
if (this.controller === undefined) {
// Stream hasn't started yet; add the chunk to the backlog queue
this.backloggedChunks.push(chunk)
return
}

await this.clearBacklog()
await this.enqueueChunk(chunk)
}

private async enqueueChunk (chunk: number[]): Promise<void> {
if (this.controller === undefined) {
throw new Error('tried enqueueing chunk before stream started')
}

let response = this.session.context.decode(chunk)
if (this.chunkCallback !== undefined) {
response = await this.chunkCallback(response)
}

const eventData: AiStreamEvent = {
event: 'content',
data: {
response
}
}
this.controller.enqueue(encodeEvent(eventData))

if (this.backloggedChunks.getSize() === 0 && this.finished) {
this.controller.close()
}
}

async clearBacklog (): Promise<void> {
if (this.backloggedChunks.getSize() === 0) {
return
}

let backloggedChunk = this.backloggedChunks.pop()
while (backloggedChunk !== undefined) {
// Chunks must be sent in order, so enqueue them one at a time rather than concurrently
await this.enqueueChunk(backloggedChunk)
backloggedChunk = this.backloggedChunks.pop()
}
}

start (controller: ReadableByteStreamController): void {
this.controller = controller
this.clearBacklog().catch(err => {
throw err
})
}
}

interface Llama2ProviderCtorOptions {
modelPath: string
}

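// AiProvider implementation backed by a local Llama 2 model through node-llama-cpp;
// the model, context and chat session are created lazily on the first request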
export class Llama2Provider implements AiProvider {
modelPath: string
session?: LlamaChatSession

constructor ({ modelPath }: Llama2ProviderCtorOptions) {
this.modelPath = modelPath
}

async ask (prompt: string): Promise<string> {
if (this.session === undefined) {
const model = new LlamaModel({ modelPath: this.modelPath })
const context = new LlamaContext({ model })
this.session = new LlamaChatSession({ context })
}

const response = await this.session.prompt(prompt)

return response
}

async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise<ReadableStream> {
if (this.session === undefined) {
const model = new LlamaModel({ modelPath: this.modelPath })
const context = new LlamaContext({ model })
this.session = new LlamaChatSession({ context })
}

return new ReadableStream(new Llama2ByteSource(this.session, prompt, chunkCallback))
}
}
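
For reference, a minimal sketch of how the new provider could be exercised in isolation (not part of this commit; the import path, prompts, and model path are placeholders):

import { Llama2Provider } from './ai-providers/llama2.js'

async function main (): Promise<void> {
  const provider = new Llama2Provider({ modelPath: '/path/to/model' })

  // One-shot completion
  const answer = await provider.ask('Hello, world')
  console.log(answer)

  // Streaming completion: each chunk is an encoded `content` event
  const stream = await provider.askStream('Hello again')
  for await (const chunk of stream) {
    process.stdout.write(Buffer.from(chunk).toString('utf8'))
  }
}

main().catch(err => {
  console.error(err)
  process.exit(1)
})
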
5 changes: 5 additions & 0 deletions config.d.ts
@@ -270,6 +270,11 @@ export interface AiWarpConfig {
deploymentName: string;
allowInsecureConnections?: boolean;
};
}
| {
llama2: {
modelPath: string;
};
};
promptDecorators?: {
prefix?: string;
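
With this union member in place, a service config can select a local model. A short sketch of the corresponding value, assuming the union above is the type of the aiProvider property as the generator below suggests (the import path is an assumption and the model path a placeholder):

import { AiWarpConfig } from './config.js'

// Only the model path is required for the llama2 branch of the union
const aiProvider: AiWarpConfig['aiProvider'] = {
  llama2: {
    modelPath: '/path/to/model'
  }
}
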
12 changes: 12 additions & 0 deletions lib/generator.ts
@@ -9,6 +9,7 @@ import { generatePlugins } from '@platformatic/generators/lib/create-plugin.js'
interface PackageJson {
name: string
version: string
devDependencies: Record<string, string>
}

class AiWarpGenerator extends ServiceGenerator {
@@ -95,6 +96,13 @@ class AiWarpGenerator extends ServiceGenerator {
}
}
break
case 'llama2':
config.aiProvider = {
llama2: {
modelPath: '/path/to/model'
}
}
break
default:
config.aiProvider = {
openai: {
@@ -130,6 +138,10 @@
this.config.dependencies = {
[packageJson.name]: `^${packageJson.version}`
}

if (this.config.aiProvider === 'llama2') {
this.config.dependencies['node-llama-cpp'] = packageJson.devDependencies['node-llama-cpp']
}
}

async _afterPrepare (): Promise<void> {
14 changes: 14 additions & 0 deletions lib/schema.ts
@@ -103,6 +103,20 @@ const aiWarpSchema = {
},
required: ['azure'],
additionalProperties: false
},
{
properties: {
llama2: {
type: 'object',
properties: {
modelPath: { type: 'string' }
},
required: ['modelPath'],
additionalProperties: false
}
},
required: ['llama2'],
additionalProperties: false
}
]
},