diff --git a/ai-providers/llama2.ts b/ai-providers/llama2.ts
index 369cabd..907b178 100644
--- a/ai-providers/llama2.ts
+++ b/ai-providers/llama2.ts
@@ -142,32 +142,23 @@ interface Llama2ProviderCtorOptions {
 }
 
 export class Llama2Provider implements AiProvider {
-  modelPath: string
-  session?: LlamaChatSession
+  context: LlamaContext
 
   constructor ({ modelPath }: Llama2ProviderCtorOptions) {
-    this.modelPath = modelPath
+    const model = new LlamaModel({ modelPath })
+    this.context = new LlamaContext({ model })
   }
 
   async ask (prompt: string): Promise<string> {
-    if (this.session === undefined) {
-      const model = new LlamaModel({ modelPath: this.modelPath })
-      const context = new LlamaContext({ model })
-      this.session = new LlamaChatSession({ context })
-    }
-
-    const response = await this.session.prompt(prompt)
+    const session = new LlamaChatSession({ context: this.context })
+    const response = await session.prompt(prompt)
 
     return response
   }
 
   async askStream (prompt: string, chunkCallback?: StreamChunkCallback): Promise<ReadableStream> {
-    if (this.session === undefined) {
-      const model = new LlamaModel({ modelPath: this.modelPath })
-      const context = new LlamaContext({ model })
-      this.session = new LlamaChatSession({ context })
-    }
+    const session = new LlamaChatSession({ context: this.context })
 
-    return new ReadableStream(new Llama2ByteSource(this.session, prompt, chunkCallback))
+    return new ReadableStream(new Llama2ByteSource(session, prompt, chunkCallback))
   }
 }
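
For context, a minimal usage sketch of the refactored provider (not part of this diff). It assumes the `LlamaModel`/`LlamaContext`/`LlamaChatSession` calls visible above behave as in node-llama-cpp, and the import path and model path below are hypothetical placeholders:

```ts
// Sketch only, under the assumptions stated above.
import { Llama2Provider } from './ai-providers/llama2'

// With this change the model and context are built once, here in the
// constructor, instead of lazily on the first ask(). Placeholder path:
const provider = new Llama2Provider({
  modelPath: './models/llama-2-7b-chat.gguf'
})

// Each call now builds its own LlamaChatSession over the shared context,
// so requests no longer accumulate chat history on one long-lived session.
const answer = await provider.ask('What does this provider do?')
console.log(answer)
```

The trade-off shown in the diff: the expensive model load is paid once up front, while `ask` and `askStream` stay stateless by creating a fresh session per call.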