diff --git a/ai-providers/ollama.ts b/ai-providers/ollama.ts
new file mode 100644
index 0000000..c83355c
--- /dev/null
+++ b/ai-providers/ollama.ts
@@ -0,0 +1,71 @@
+import { ReadableStream, UnderlyingByteSource, ReadableByteStreamController } from 'stream/web'
+import { Ollama, ChatResponse } from 'ollama'
+import { AiProvider, StreamChunkCallback } from './provider'
+import { AiStreamEvent, encodeEvent } from './event'
+
+type OllamaStreamResponse = AsyncGenerator<ChatResponse>
+
+class OllamaByteSource implements UnderlyingByteSource {
+  type: 'bytes' = 'bytes'
+  response: OllamaStreamResponse
+  chunkCallback?: StreamChunkCallback
+
+  constructor (response: OllamaStreamResponse, chunkCallback?: StreamChunkCallback) {
+    this.response = response
+    this.chunkCallback = chunkCallback
+  }
+
+  async pull (controller: ReadableByteStreamController): Promise<void> {
+    const { done, value } = await this.response.next()
+    if (done !== undefined && done) {
+      controller.close()
+      return
+    }
+
+    let response = value.message.content
+    if (this.chunkCallback !== undefined) {
+      response = await this.chunkCallback(response)
+    }
+
+    const eventData: AiStreamEvent = {
+      event: 'content',
+      data: {
+        response
+      }
+    }
+    controller.enqueue(encodeEvent(eventData))
+  }
+}
+
+export class OllamaProvider implements AiProvider {
+  model: string
+  client: Ollama
+
+  constructor (host: string, model: string) {
+    this.model = model
+    this.client = new Ollama({ host })
+  }
+
+  async ask (prompt: string): Promise<string> {
+    const response = await this.client.chat({
+      model: this.model,
+      messages: [
+        { role: 'user', content: prompt }
+      ]
+    })
+
+    return response.message.content
+  }
+
+  async askStream (prompt: string, chunkCallback?: StreamChunkCallback | undefined): Promise<ReadableStream> {
+    const response = await this.client.chat({
+      model: this.model,
+      messages: [
+        { role: 'user', content: prompt }
+      ],
+      stream: true
+    })
+
+    return new ReadableStream(new OllamaByteSource(response, chunkCallback))
+  }
+}
diff --git a/config.d.ts b/config.d.ts
index 54da8d0..460162c 100644
--- a/config.d.ts
+++ b/config.d.ts
@@ -256,6 +256,12 @@ export interface AiWarpConfig {
             | "mistral-large-latest";
           apiKey: string;
         };
+      }
+    | {
+        ollama: {
+          host: string;
+          model: string;
+        };
       };
   promptDecorators?: {
     prefix?: string;
diff --git a/lib/generator.ts b/lib/generator.ts
index ff3c0fd..91afab8 100644
--- a/lib/generator.ts
+++ b/lib/generator.ts
@@ -78,6 +78,14 @@ class AiWarpGenerator extends ServiceGenerator {
           }
         }
         break
+      case 'ollama':
+        config.aiProvider = {
+          ollama: {
+            host: 'http://127.0.0.1:11434',
+            model: this.config.aiModel
+          }
+        }
+        break
       default:
         config.aiProvider = {
           openai: {
diff --git a/lib/schema.ts b/lib/schema.ts
index 0cab8a9..8bf3c7b 100644
--- a/lib/schema.ts
+++ b/lib/schema.ts
@@ -68,6 +68,21 @@ const aiWarpSchema = {
          },
          required: ['mistral'],
          additionalProperties: false
+        },
+        {
+          properties: {
+            ollama: {
+              type: 'object',
+              properties: {
+                host: { type: 'string' },
+                model: { type: 'string' }
+              },
+              required: ['host', 'model'],
+              additionalProperties: false
+            }
+          },
+          required: ['ollama'],
+          additionalProperties: false
        }
      ]
    },
diff --git a/package-lock.json b/package-lock.json
index ff0e581..e831116 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -19,6 +19,7 @@
         "fast-json-stringify": "^5.13.0",
         "fastify-user": "^0.3.3",
         "json-schema-to-typescript": "^13.0.0",
+        "ollama": "^0.5.0",
         "openai": "^4.28.4",
         "snazzy": "^9.0.0",
         "ts-standard": "^12.0.2",
@@ -5975,6 +5976,14 @@
      "resolved": "https://registry.npmjs.org/obliterator/-/obliterator-2.0.4.tgz",
      "integrity": "sha512-lgHwxlxV1qIg1Eap7LgIeoBWIMFibOjbrYPIPJZcI1mmGAI2m3lNYpK12Y+GBdPQ0U1hRwSord7GIaawz962qQ=="
    },
+    "node_modules/ollama": {
+      "version": "0.5.0",
+      "resolved": "https://registry.npmjs.org/ollama/-/ollama-0.5.0.tgz",
+      "integrity": "sha512-CRtRzsho210EGdK52GrUMohA2pU+7NbgEaBG3DcYeRmvQthDO7E2LHOkLlUUeaYUlNmEd8icbjC02ug9meSYnw==",
+      "dependencies": {
+        "whatwg-fetch": "^3.6.20"
+      }
+    },
    "node_modules/on-exit-leak-free": {
      "version": "2.1.2",
      "resolved": "https://registry.npmjs.org/on-exit-leak-free/-/on-exit-leak-free-2.1.2.tgz",
@@ -8216,6 +8225,11 @@
      "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz",
      "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ=="
    },
+    "node_modules/whatwg-fetch": {
+      "version": "3.6.20",
+      "resolved": "https://registry.npmjs.org/whatwg-fetch/-/whatwg-fetch-3.6.20.tgz",
+      "integrity": "sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg=="
+    },
    "node_modules/whatwg-url": {
      "version": "5.0.0",
      "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz",
diff --git a/package.json b/package.json
index a4c5c63..1f2cfd9 100644
--- a/package.json
+++ b/package.json
@@ -36,6 +36,7 @@
     "fast-json-stringify": "^5.13.0",
     "fastify-user": "^0.3.3",
     "json-schema-to-typescript": "^13.0.0",
+    "ollama": "^0.5.0",
     "openai": "^4.28.4",
     "snazzy": "^9.0.0",
     "ts-standard": "^12.0.2",
diff --git a/plugins/warp.ts b/plugins/warp.ts
index 59d72ce..5331f4b 100644
--- a/plugins/warp.ts
+++ b/plugins/warp.ts
@@ -6,6 +6,7 @@ import { MistralProvider } from '../ai-providers/mistral.js'
 import { AiProvider, StreamChunkCallback } from '../ai-providers/provider'
 import { AiWarpConfig } from '../config'
 import createError from '@fastify/error'
+import { OllamaProvider } from '../ai-providers/ollama'

 const UnknownAiProviderError = createError('UNKNOWN_AI_PROVIDER', 'Unknown AI Provider')

@@ -16,6 +17,9 @@ function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider {
   } else if ('mistral' in aiProvider) {
     const { model, apiKey } = aiProvider.mistral
     return new MistralProvider(model, apiKey)
+  } else if ('ollama' in aiProvider) {
+    const { host, model } = aiProvider.ollama
+    return new OllamaProvider(host, model)
   } else {
     throw new UnknownAiProviderError()
   }
diff --git a/tests/e2e/api.test.ts b/tests/e2e/api.test.ts
index 4fe33a0..6961c38 100644
--- a/tests/e2e/api.test.ts
+++ b/tests/e2e/api.test.ts
@@ -3,7 +3,7 @@ import { before, after, describe, it } from 'node:test'
 import assert from 'node:assert'
 import { FastifyInstance } from 'fastify'
 import fastifyPlugin from 'fastify-plugin'
-import { MOCK_CONTENT_RESPONSE, buildExpectedStreamBodyString } from '../utils/mocks'
+import { MOCK_CONTENT_RESPONSE, OLLAMA_MOCK_HOST, buildExpectedStreamBodyString } from '../utils/mocks'
 import { AiWarpConfig } from '../../config'
 import { buildAiWarpApp } from '../utils/stackable'

@@ -24,6 +24,15 @@ const providers: Provider[] = [
       }
     }
   },
+  {
+    name: 'Ollama',
+    config: {
+      ollama: {
+        host: OLLAMA_MOCK_HOST,
+        model: 'some-model'
+      }
+    }
+  },
   {
     name: 'Mistral',
     config: {
diff --git a/tests/e2e/index.ts b/tests/e2e/index.ts
index f344059..d1fac9a 100644
--- a/tests/e2e/index.ts
+++ b/tests/e2e/index.ts
@@ -1,7 +1,8 @@
 import './api.test'
 import './rate-limiting.test'
 import './auth.test'
-import { mockMistralApi, mockOpenAiApi } from '../utils/mocks'
+import { mockMistralApi, mockOllama, mockOpenAiApi } from '../utils/mocks'

 mockOpenAiApi()
 mockMistralApi()
+mockOllama()
diff --git a/tests/types/schema.test-d.ts b/tests/types/schema.test-d.ts
index 12dec6a..7be8845 100644
--- a/tests/types/schema.test-d.ts
+++ b/tests/types/schema.test-d.ts
@@ -40,6 +40,15 @@ expectAssignable<AiWarpConfig>({
   }
 })

+expectAssignable<AiWarpConfig>({
+  aiProvider: {
+    ollama: {
+      host: '',
+      model: 'some-model'
+    }
+  }
+})
+
 expectAssignable<AiWarpConfig>({
   $schema: './stackable.schema.json',
   service: {
diff --git a/tests/unit/ai-providers.test.ts b/tests/unit/ai-providers.test.ts
index 8027e88..b2608c6 100644
--- a/tests/unit/ai-providers.test.ts
+++ b/tests/unit/ai-providers.test.ts
@@ -4,13 +4,15 @@ import assert from 'node:assert'
 import { MistralProvider } from '../../ai-providers/mistral'
 import { OpenAiProvider } from '../../ai-providers/open-ai'
 import { AiProvider } from '../../ai-providers/provider'
-import { MOCK_CONTENT_RESPONSE, buildExpectedStreamBodyString } from '../utils/mocks'
+import { MOCK_CONTENT_RESPONSE, OLLAMA_MOCK_HOST, buildExpectedStreamBodyString } from '../utils/mocks'
+import { OllamaProvider } from '../../ai-providers/ollama'

 const expectedStreamBody = buildExpectedStreamBodyString()

 const providers: AiProvider[] = [
   new OpenAiProvider('gpt-3.5-turbo', ''),
-  new MistralProvider('open-mistral-7b', '')
+  new MistralProvider('open-mistral-7b', ''),
+  new OllamaProvider(OLLAMA_MOCK_HOST, 'some-model')
 ]

 for (const provider of providers) {
diff --git a/tests/unit/index.ts b/tests/unit/index.ts
index 56ca4ab..a116bb3 100644
--- a/tests/unit/index.ts
+++ b/tests/unit/index.ts
@@ -1,6 +1,7 @@
 import './generator.test'
 import './ai-providers.test'
-import { mockMistralApi, mockOpenAiApi } from '../utils/mocks'
+import { mockMistralApi, mockOllama, mockOpenAiApi } from '../utils/mocks'

 mockOpenAiApi()
 mockMistralApi()
+mockOllama()
diff --git a/tests/utils/mocks.ts b/tests/utils/mocks.ts
index 02b711d..624ed33 100644
--- a/tests/utils/mocks.ts
+++ b/tests/utils/mocks.ts
@@ -188,3 +188,65 @@ export function mockMistralApi (): void {
     }
   }).persist()
 }
+
+export const OLLAMA_MOCK_HOST = 'http://127.0.0.1:41434'
+let isOllamaMocked = false
+
+/**
+ * @see https://github.com/ollama/ollama/blob/9446b795b58e32c8b248a76707780f4f96b6434f/docs/api.md
+ */
+export function mockOllama (): void {
+  if (isOllamaMocked) {
+    return
+  }
+
+  isOllamaMocked = true
+
+  establishMockAgent()
+
+  const pool = mockAgent.get(OLLAMA_MOCK_HOST)
+  pool.intercept({
+    path: '/api/chat',
+    method: 'POST'
+  }).reply(200, (opts) => {
+    if (typeof opts.body !== 'string') {
+      throw new Error(`body is not a string (${typeof opts.body})`)
+    }
+
+    const body = JSON.parse(opts.body)
+
+    let response = ''
+    if (body.stream === true) {
+      for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) {
+        response += JSON.stringify({
+          model: 'llama2',
+          created_at: '2023-08-04T08:52:19.385406455-07:00',
+          message: {
+            role: 'assistant',
+            content: MOCK_STREAMING_CONTENT_CHUNKS[i],
+            images: null
+          },
+          done: i === MOCK_STREAMING_CONTENT_CHUNKS.length - 1
+        })
+        response += '\n'
+      }
+    } else {
+      response += JSON.stringify({
+        model: 'llama2',
+        created_at: '2023-08-04T19:22:45.499127Z',
+        message: {
+          role: 'assistant',
+          content: MOCK_CONTENT_RESPONSE,
+          images: null
+        },
+        done: true
+      })
+    }
+
+    return response
+  }, {
+    headers: {
+      'content-type': 'application/json'
+    }
+  }).persist()
+}
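Note: a minimal usage sketch of the provider this patch adds (the host, model name, and prompt below are illustrative; it assumes an Ollama server is reachable at the given host, matching the default used in lib/generator.ts):

import { OllamaProvider } from './ai-providers/ollama'

const provider = new OllamaProvider('http://127.0.0.1:11434', 'llama2')

// ask() resolves to the full completion text (response.message.content).
const answer = await provider.ask('Why is the sky blue?')

// askStream() resolves to a ReadableStream of encoded 'content' events,
// passing each chunk through the optional StreamChunkCallback first.
const stream = await provider.askStream('Why is the sky blue?')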