Add Ollama provider
flakey5 committed Apr 11, 2024
1 parent b9cddee commit 79fb668
Showing 13 changed files with 208 additions and 5 deletions.
71 changes: 71 additions & 0 deletions ai-providers/ollama.ts
@@ -0,0 +1,71 @@
import { ReadableStream, UnderlyingByteSource, ReadableByteStreamController } from 'stream/web'
import { Ollama, ChatResponse } from 'ollama'
import { AiProvider, StreamChunkCallback } from './provider'
import { AiStreamEvent, encodeEvent } from './event'

type OllamaStreamResponse = AsyncGenerator<ChatResponse>

// Adapts Ollama's streaming chat response into a byte source that emits
// one encoded AiStreamEvent per chat chunk.
class OllamaByteSource implements UnderlyingByteSource {
  type: 'bytes' = 'bytes'
  response: OllamaStreamResponse
  chunkCallback?: StreamChunkCallback

  constructor (response: OllamaStreamResponse, chunkCallback?: StreamChunkCallback) {
    this.response = response
    this.chunkCallback = chunkCallback
  }

  async pull (controller: ReadableByteStreamController): Promise<void> {
    const { done, value } = await this.response.next()
    if (done !== undefined && done) {
      controller.close()
      return
    }

    let response = value.message.content
    if (this.chunkCallback !== undefined) {
      response = await this.chunkCallback(response)
    }

    const eventData: AiStreamEvent = {
      event: 'content',
      data: {
        response
      }
    }
    controller.enqueue(encodeEvent(eventData))
  }
}

export class OllamaProvider implements AiProvider {
  model: string
  client: Ollama

  constructor (host: string, model: string) {
    this.model = model
    this.client = new Ollama({ host })
  }

  async ask (prompt: string): Promise<string> {
    const response = await this.client.chat({
      model: this.model,
      messages: [
        { role: 'user', content: prompt }
      ]
    })

    return response.message.content
  }

  async askStream (prompt: string, chunkCallback?: StreamChunkCallback | undefined): Promise<ReadableStream> {
    const response = await this.client.chat({
      model: this.model,
      messages: [
        { role: 'user', content: prompt }
      ],
      stream: true
    })

    return new ReadableStream(new OllamaByteSource(response, chunkCallback))
  }
}
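
For orientation, a minimal usage sketch of the new provider follows; the host, the model name, and iterating the returned stream with for await are assumptions based on the code above, not part of this commit.

import { OllamaProvider } from './ai-providers/ollama'

// Hypothetical host and model; any reachable Ollama instance with a pulled model works.
const provider = new OllamaProvider('http://127.0.0.1:11434', 'llama2')

// One-shot completion.
const answer = await provider.ask('Why is the sky blue?')
console.log(answer)

// Streaming: the ReadableStream yields encoded 'content' events, one per chat chunk.
const stream = await provider.askStream('Why is the sky blue?')
const decoder = new TextDecoder()
for await (const chunk of stream) {
  process.stdout.write(decoder.decode(chunk))
}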
6 changes: 6 additions & 0 deletions config.d.ts
@@ -256,6 +256,12 @@ export interface AiWarpConfig {
              | "mistral-large-latest";
            apiKey: string;
          };
        }
      | {
          ollama: {
            host: string;
            model: string;
          };
        };
    promptDecorators?: {
      prefix?: string;
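
As a sketch, selecting the new provider through the generated AiWarpConfig type might look like this; 'llama2' is an assumed model name, and the host mirrors the default the generator emits below.

import { AiWarpConfig } from './config'

// Assumed values: any Ollama host/model pair accepted by the schema works.
const aiProvider: AiWarpConfig['aiProvider'] = {
  ollama: {
    host: 'http://127.0.0.1:11434',
    model: 'llama2'
  }
}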
8 changes: 8 additions & 0 deletions lib/generator.ts
@@ -78,6 +78,14 @@ class AiWarpGenerator extends ServiceGenerator {
          }
        }
        break
      case 'ollama':
        config.aiProvider = {
          ollama: {
            // 11434 is the default port a local Ollama server listens on.
            host: 'http://127.0.0.1:11434',
            model: this.config.aiModel
          }
        }
        break
      default:
        config.aiProvider = {
          openai: {
15 changes: 15 additions & 0 deletions lib/schema.ts
@@ -68,6 +68,21 @@ const aiWarpSchema = {
        },
        required: ['mistral'],
        additionalProperties: false
      },
      {
        properties: {
          ollama: {
            type: 'object',
            properties: {
              host: { type: 'string' },
              model: { type: 'string' }
            },
            required: ['host', 'model'],
            additionalProperties: false
          }
        },
        required: ['ollama'],
        additionalProperties: false
      }
    ]
  },
14 changes: 14 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions package.json
@@ -36,6 +36,7 @@
    "fast-json-stringify": "^5.13.0",
    "fastify-user": "^0.3.3",
    "json-schema-to-typescript": "^13.0.0",
    "ollama": "^0.5.0",
    "openai": "^4.28.4",
    "snazzy": "^9.0.0",
    "ts-standard": "^12.0.2",
4 changes: 4 additions & 0 deletions plugins/warp.ts
@@ -6,6 +6,7 @@ import { MistralProvider } from '../ai-providers/mistral.js'
import { AiProvider, StreamChunkCallback } from '../ai-providers/provider'
import { AiWarpConfig } from '../config'
import createError from '@fastify/error'
import { OllamaProvider } from '../ai-providers/ollama'

const UnknownAiProviderError = createError('UNKNOWN_AI_PROVIDER', 'Unknown AI Provider')

@@ -16,6 +17,9 @@ function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider {
  } else if ('mistral' in aiProvider) {
    const { model, apiKey } = aiProvider.mistral
    return new MistralProvider(model, apiKey)
  } else if ('ollama' in aiProvider) {
    const { host, model } = aiProvider.ollama
    return new OllamaProvider(host, model)
  } else {
    throw new UnknownAiProviderError()
  }
11 changes: 10 additions & 1 deletion tests/e2e/api.test.ts
@@ -3,7 +3,7 @@ import { before, after, describe, it } from 'node:test'
import assert from 'node:assert'
import { FastifyInstance } from 'fastify'
import fastifyPlugin from 'fastify-plugin'
import { MOCK_CONTENT_RESPONSE, buildExpectedStreamBodyString } from '../utils/mocks'
import { MOCK_CONTENT_RESPONSE, OLLAMA_MOCK_HOST, buildExpectedStreamBodyString } from '../utils/mocks'
import { AiWarpConfig } from '../../config'
import { buildAiWarpApp } from '../utils/stackable'

@@ -24,6 +24,15 @@ const providers: Provider[] = [
      }
    }
  },
  {
    name: 'Ollama',
    config: {
      ollama: {
        host: OLLAMA_MOCK_HOST,
        model: 'some-model'
      }
    }
  },
  {
    name: 'Mistral',
    config: {
3 changes: 2 additions & 1 deletion tests/e2e/index.ts
@@ -1,7 +1,8 @@
import './api.test'
import './rate-limiting.test'
import './auth.test'
import { mockMistralApi, mockOpenAiApi } from '../utils/mocks'
import { mockMistralApi, mockOllama, mockOpenAiApi } from '../utils/mocks'

mockOpenAiApi()
mockMistralApi()
mockOllama()
9 changes: 9 additions & 0 deletions tests/types/schema.test-d.ts
@@ -40,6 +40,15 @@ expectAssignable<AiWarpConfig>({
  }
})

expectAssignable<AiWarpConfig>({
  aiProvider: {
    ollama: {
      host: '',
      model: 'some-model'
    }
  }
})

expectAssignable<AiWarpConfig>({
  $schema: './stackable.schema.json',
  service: {
6 changes: 4 additions & 2 deletions tests/unit/ai-providers.test.ts
@@ -4,13 +4,15 @@ import assert from 'node:assert'
import { MistralProvider } from '../../ai-providers/mistral'
import { OpenAiProvider } from '../../ai-providers/open-ai'
import { AiProvider } from '../../ai-providers/provider'
import { MOCK_CONTENT_RESPONSE, buildExpectedStreamBodyString } from '../utils/mocks'
import { MOCK_CONTENT_RESPONSE, OLLAMA_MOCK_HOST, buildExpectedStreamBodyString } from '../utils/mocks'
import { OllamaProvider } from '../../ai-providers/ollama'

const expectedStreamBody = buildExpectedStreamBodyString()

const providers: AiProvider[] = [
  new OpenAiProvider('gpt-3.5-turbo', ''),
  new MistralProvider('open-mistral-7b', '')
  new MistralProvider('open-mistral-7b', ''),
  new OllamaProvider(OLLAMA_MOCK_HOST, 'some-model')
]

for (const provider of providers) {
3 changes: 2 additions & 1 deletion tests/unit/index.ts
@@ -1,6 +1,7 @@
import './generator.test'
import './ai-providers.test'
import { mockMistralApi, mockOpenAiApi } from '../utils/mocks'
import { mockMistralApi, mockOllama, mockOpenAiApi } from '../utils/mocks'

mockOpenAiApi()
mockMistralApi()
mockOllama()
62 changes: 62 additions & 0 deletions tests/utils/mocks.ts
@@ -188,3 +188,65 @@ export function mockMistralApi (): void {
    }
  }).persist()
}

export const OLLAMA_MOCK_HOST = 'http://127.0.0.1:41434'
let isOllamaMocked = false

/**
 * @see https://github.com/ollama/ollama/blob/9446b795b58e32c8b248a76707780f4f96b6434f/docs/api.md
 */
export function mockOllama (): void {
  if (isOllamaMocked) {
    return
  }

  isOllamaMocked = true

  establishMockAgent()

  const pool = mockAgent.get(OLLAMA_MOCK_HOST)
  pool.intercept({
    path: '/api/chat',
    method: 'POST'
  }).reply(200, (opts) => {
    if (typeof opts.body !== 'string') {
      throw new Error(`body is not a string (${typeof opts.body})`)
    }

    const body = JSON.parse(opts.body)

    let response = ''
    if (body.stream === true) {
      // Streaming responses are newline-delimited JSON, one chunk per line;
      // only the final chunk carries done: true.
      for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) {
        response += JSON.stringify({
          model: 'llama2',
          created_at: '2023-08-04T08:52:19.385406455-07:00',
          message: {
            role: 'assistant',
            content: MOCK_STREAMING_CONTENT_CHUNKS[i],
            images: null
          },
          done: i === MOCK_STREAMING_CONTENT_CHUNKS.length - 1
        })
        response += '\n'
      }
    } else {
      response += JSON.stringify({
        model: 'llama2',
        created_at: '2023-08-04T19:22:45.499127Z',
        message: {
          role: 'assistant',
          content: MOCK_CONTENT_RESPONSE,
          images: null
        },
        done: true
      })
    }

    return response
  }, {
    headers: {
      'content-type': 'application/json'
    }
  }).persist()
}
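
The streaming branch above emits newline-delimited JSON, matching how the real /api/chat endpoint streams; here is a small sketch of splitting such a body back into chunks (the ChatChunk shape is a local assumption, not an export of the ollama package).

// Assumed local shape of each NDJSON line produced by the mock above.
interface ChatChunk {
  model: string
  created_at: string
  message: { role: string, content: string, images: string[] | null }
  done: boolean
}

// Split a mocked streaming body into its individual chat chunks.
function parseNdjsonBody (body: string): ChatChunk[] {
  return body
    .split('\n')
    .filter(line => line.length > 0)
    .map(line => JSON.parse(line) as ChatChunk)
}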
