diff --git a/ai-providers/azure.ts b/ai-providers/azure.ts new file mode 100644 index 0000000..5cb2051 --- /dev/null +++ b/ai-providers/azure.ts @@ -0,0 +1,141 @@ +import { ReadableStream, ReadableByteStreamController, UnderlyingByteSource } from 'stream/web' +import { AiProvider, NoContentError, StreamChunkCallback } from './provider' +import { AiStreamEvent, encodeEvent } from './event' + +// @ts-expect-error +type AzureEventStream = import('@azure/openai').EventStream +// @ts-expect-error +type AzureChatCompletions = import('@azure/openai').ChatCompletions + +type AzureStreamResponse = AzureEventStream + +class AzureByteSource implements UnderlyingByteSource { + type: 'bytes' = 'bytes' + response: AzureStreamResponse + reader?: ReadableStreamDefaultReader + chunkCallback?: StreamChunkCallback + + constructor (response: AzureStreamResponse, chunkCallback?: StreamChunkCallback) { + this.response = response + this.chunkCallback = chunkCallback + } + + start (): void { + this.reader = this.response.getReader() + } + + async pull (controller: ReadableByteStreamController): Promise { + // start() defines this.reader and is called before this + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const { done, value } = await this.reader!.read() + + if (done !== undefined && done) { + controller.close() + return + } + + if (value.choices.length === 0) { + const error = new NoContentError('Azure OpenAI') + + const eventData: AiStreamEvent = { + event: 'error', + data: error + } + controller.enqueue(encodeEvent(eventData)) + controller.close() + + return + } + + const { delta } = value.choices[0] + if (delta === undefined || delta.content === null) { + const error = new NoContentError('Azure OpenAI') + + const eventData: AiStreamEvent = { + event: 'error', + data: error + } + controller.enqueue(encodeEvent(eventData)) + controller.close() + + return + } + + let response = delta.content + if (this.chunkCallback !== undefined) { + response = await this.chunkCallback(response) + } + + const eventData: AiStreamEvent = { + event: 'content', + data: { + response + } + } + controller.enqueue(encodeEvent(eventData)) + } +} + +export class AzureProvider implements AiProvider { + endpoint: string + deploymentName: string + apiKey: string + // @ts-expect-error typescript doesn't like this type import even though + // it's fine in the Mistral client? + client?: import('@azure/openai').OpenAIClient = undefined + allowInsecureConnections: boolean + + constructor (endpoint: string, apiKey: string, deploymentName: string, allowInsecureConnections: boolean = false) { + this.endpoint = endpoint + this.apiKey = apiKey + this.deploymentName = deploymentName + this.allowInsecureConnections = allowInsecureConnections + } + + async ask (prompt: string): Promise { + if (this.client === undefined) { + const { OpenAIClient, AzureKeyCredential } = await import('@azure/openai') + this.client = new OpenAIClient( + this.endpoint, + new AzureKeyCredential(this.apiKey), + { + allowInsecureConnection: this.allowInsecureConnections + } + ) + } + + const { choices } = await this.client.getChatCompletions(this.deploymentName, [ + { role: 'user', content: prompt } + ]) + + if (choices.length === 0) { + throw new NoContentError('Azure OpenAI') + } + + const { message } = choices[0] + if (message === undefined || message.content === null) { + throw new NoContentError('Azure OpenAI') + } + + return message.content + } + + async askStream (prompt: string, chunkCallback?: StreamChunkCallback | undefined): Promise { + if (this.client === undefined) { + const { OpenAIClient, AzureKeyCredential } = await import('@azure/openai') + this.client = new OpenAIClient( + this.endpoint, + new AzureKeyCredential(this.apiKey), + { + allowInsecureConnection: this.allowInsecureConnections + } + ) + } + + const response = await this.client.streamChatCompletions(this.deploymentName, [ + { role: 'user', content: prompt } + ]) + + return new ReadableStream(new AzureByteSource(response, chunkCallback)) + } +} diff --git a/config.d.ts b/config.d.ts index 460162c..af3aa18 100644 --- a/config.d.ts +++ b/config.d.ts @@ -262,6 +262,14 @@ export interface AiWarpConfig { host: string; model: string; }; + } + | { + azure: { + endpoint: string; + apiKey: string; + deploymentName: string; + allowInsecureConnections?: boolean; + }; }; promptDecorators?: { prefix?: string; diff --git a/lib/generator.ts b/lib/generator.ts index 91afab8..1082543 100644 --- a/lib/generator.ts +++ b/lib/generator.ts @@ -86,6 +86,15 @@ class AiWarpGenerator extends ServiceGenerator { } } break + case 'azure': + config.aiProvider = { + azure: { + endpoint: 'https://myaccount.openai.azure.com/', + apiKey: `{${this.getEnvVarName('PLT_AZURE_API_KEY')}}`, + deploymentName: this.config.aiModel + } + } + break default: config.aiProvider = { openai: { diff --git a/lib/schema.ts b/lib/schema.ts index 8bf3c7b..e0e2824 100644 --- a/lib/schema.ts +++ b/lib/schema.ts @@ -83,6 +83,26 @@ const aiWarpSchema = { }, required: ['ollama'], additionalProperties: false + }, + { + properties: { + azure: { + type: 'object', + properties: { + endpoint: { type: 'string' }, + apiKey: { type: 'string' }, + deploymentName: { type: 'string' }, + allowInsecureConnections: { + type: 'boolean', + default: false + } + }, + required: ['endpoint', 'apiKey', 'deploymentName'], + additionalProperties: false + } + }, + required: ['azure'], + additionalProperties: false } ] }, diff --git a/package-lock.json b/package-lock.json index 56f381f..ebfc861 100644 --- a/package-lock.json +++ b/package-lock.json @@ -8,6 +8,7 @@ "name": "@platformatic/ai-warp", "version": "0.0.1", "dependencies": { + "@azure/openai": "^1.0.0-beta.12", "@fastify/error": "^3.4.1", "@fastify/rate-limit": "^9.1.0", "@fastify/type-provider-typebox": "^4.0.0", @@ -95,6 +96,126 @@ "url": "https://github.com/sponsors/philsturgeon" } }, + "node_modules/@azure-rest/core-client": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@azure-rest/core-client/-/core-client-1.4.0.tgz", + "integrity": "sha512-ozTDPBVUDR5eOnMIwhggbnVmOrka4fXCs8n8mvUo4WLLc38kki6bAOByDoVZZPz/pZy2jMt2kwfpvy/UjALj6w==", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "@azure/core-auth": "^1.3.0", + "@azure/core-rest-pipeline": "^1.5.0", + "@azure/core-tracing": "^1.0.1", + "@azure/core-util": "^1.0.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/abort-controller": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@azure/abort-controller/-/abort-controller-2.1.2.tgz", + "integrity": "sha512-nBrLsEWm4J2u5LpAPjxADTlq3trDgVZZXHNKabeXZtpq3d3AbN/KGO82R87rdDz5/lYB024rtEf10/q0urNgsA==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-auth": { + "version": "1.7.2", + "resolved": "https://registry.npmjs.org/@azure/core-auth/-/core-auth-1.7.2.tgz", + "integrity": "sha512-Igm/S3fDYmnMq1uKS38Ae1/m37B3zigdlZw+kocwEhh5GjyKjPrXKO2J6rzpC1wAxrNil/jX9BJRqBshyjnF3g==", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "@azure/core-util": "^1.1.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-rest-pipeline": { + "version": "1.15.2", + "resolved": "https://registry.npmjs.org/@azure/core-rest-pipeline/-/core-rest-pipeline-1.15.2.tgz", + "integrity": "sha512-BmWfpjc/QXc2ipHOh6LbUzp3ONCaa6xzIssTU0DwH9bbYNXJlGUL6tujx5TrbVd/QQknmS+vlQJGrCq2oL1gZA==", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "@azure/core-auth": "^1.4.0", + "@azure/core-tracing": "^1.0.1", + "@azure/core-util": "^1.3.0", + "@azure/logger": "^1.0.0", + "http-proxy-agent": "^7.0.0", + "https-proxy-agent": "^7.0.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-sse": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@azure/core-sse/-/core-sse-2.1.2.tgz", + "integrity": "sha512-yf+pFIu8yCzXu9RbH2+8kp9vITIKJLHgkLgFNA6hxiDHK3fxeP596cHUj4c8Cm8JlooaUnYdHmF84KCZt3jbmw==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-tracing": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@azure/core-tracing/-/core-tracing-1.1.2.tgz", + "integrity": "sha512-dawW9ifvWAWmUm9/h+/UQ2jrdvjCJ7VJEuCJ6XVNudzcOwm53BFZH4Q845vjfgoUAM8ZxokvVNxNxAITc502YA==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/core-util": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/@azure/core-util/-/core-util-1.9.0.tgz", + "integrity": "sha512-AfalUQ1ZppaKuxPPMsFEUdX6GZPB3d9paR9d/TTL7Ow2De8cJaC7ibi7kWVlFAVPCYo31OcnGymc0R89DX8Oaw==", + "dependencies": { + "@azure/abort-controller": "^2.0.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/logger": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@azure/logger/-/logger-1.1.2.tgz", + "integrity": "sha512-l170uE7bsKpIU6B/giRc9i4NI0Mj+tANMMMxf7Zi/5cKzEqPayP7+X1WPrG7e+91JgY8N+7K7nF2WOi7iVhXvg==", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@azure/openai": { + "version": "1.0.0-beta.12", + "resolved": "https://registry.npmjs.org/@azure/openai/-/openai-1.0.0-beta.12.tgz", + "integrity": "sha512-qKblxr6oVa8GsyNzY+/Ub9VmEsPYKhBrUrPaNEQiM+qrxnBPVm9kaeqGFFb/U78Q2zOabmhF9ctYt3xBW0nWnQ==", + "dependencies": { + "@azure-rest/core-client": "^1.1.7", + "@azure/core-auth": "^1.4.0", + "@azure/core-rest-pipeline": "^1.13.0", + "@azure/core-sse": "^2.0.0", + "@azure/core-util": "^1.4.0", + "@azure/logger": "^1.0.3", + "tslib": "^2.4.0" + }, + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/@babel/code-frame": { "version": "7.24.2", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.24.2.tgz", @@ -1816,6 +1937,17 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, + "node_modules/agent-base": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-7.1.1.tgz", + "integrity": "sha512-H0TSyFNDMomMNJQBn8wFV5YC/2eJ+VXECwOadZJT554xP6cODZHPX3H9QMQECxvrgiSOP1pHjy1sMWQVYJOUOA==", + "dependencies": { + "debug": "^4.3.4" + }, + "engines": { + "node": ">= 14" + } + }, "node_modules/agentkeepalive": { "version": "4.5.0", "resolved": "https://registry.npmjs.org/agentkeepalive/-/agentkeepalive-4.5.0.tgz", @@ -4791,6 +4923,30 @@ "node": ">= 0.8" } }, + "node_modules/http-proxy-agent": { + "version": "7.0.2", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-7.0.2.tgz", + "integrity": "sha512-T1gkAiYYDWYx3V5Bmyu7HcfcvL7mUrTWiM6yOfa3PIphViJ/gFPbvidQ+veqSOHci/PxBcDabeUNCzpOODJZig==", + "dependencies": { + "agent-base": "^7.1.0", + "debug": "^4.3.4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/https-proxy-agent": { + "version": "7.0.4", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.4.tgz", + "integrity": "sha512-wlwpilI7YdjSkWaQ/7omYBMTliDcmCN8OLihO6I9B86g06lMyAoqgoDpV0XqoaPOKj+0DIdAvnsWfyAAhmimcg==", + "dependencies": { + "agent-base": "^7.0.2", + "debug": "4" + }, + "engines": { + "node": ">= 14" + } + }, "node_modules/human-signals": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-5.0.0.tgz", diff --git a/package.json b/package.json index 704b823..7c71b4c 100644 --- a/package.json +++ b/package.json @@ -30,6 +30,7 @@ "typescript": "^5.3.3" }, "dependencies": { + "@azure/openai": "^1.0.0-beta.12", "@fastify/error": "^3.4.1", "@fastify/rate-limit": "^9.1.0", "@fastify/type-provider-typebox": "^4.0.0", diff --git a/plugins/warp.ts b/plugins/warp.ts index 5331f4b..b42a464 100644 --- a/plugins/warp.ts +++ b/plugins/warp.ts @@ -7,6 +7,7 @@ import { AiProvider, StreamChunkCallback } from '../ai-providers/provider' import { AiWarpConfig } from '../config' import createError from '@fastify/error' import { OllamaProvider } from '../ai-providers/ollama' +import { AzureProvider } from '../ai-providers/azure' const UnknownAiProviderError = createError('UNKNOWN_AI_PROVIDER', 'Unknown AI Provider') @@ -20,6 +21,9 @@ function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider { } else if ('ollama' in aiProvider) { const { host, model } = aiProvider.ollama return new OllamaProvider(host, model) + } else if ('azure' in aiProvider) { + const { endpoint, apiKey, deploymentName, allowInsecureConnections } = aiProvider.azure + return new AzureProvider(endpoint, apiKey, deploymentName, allowInsecureConnections) } else { throw new UnknownAiProviderError() } diff --git a/tests/e2e/api.test.ts b/tests/e2e/api.test.ts index 6961c38..cc7065e 100644 --- a/tests/e2e/api.test.ts +++ b/tests/e2e/api.test.ts @@ -3,9 +3,11 @@ import { before, after, describe, it } from 'node:test' import assert from 'node:assert' import { FastifyInstance } from 'fastify' import fastifyPlugin from 'fastify-plugin' -import { MOCK_CONTENT_RESPONSE, OLLAMA_MOCK_HOST, buildExpectedStreamBodyString } from '../utils/mocks' import { AiWarpConfig } from '../../config' import { buildAiWarpApp } from '../utils/stackable' +import { AZURE_DEPLOYMENT_NAME, AZURE_MOCK_HOST } from '../utils/mocks/azure' +import { MOCK_CONTENT_RESPONSE, buildExpectedStreamBodyString } from '../utils/mocks/base' +import { OLLAMA_MOCK_HOST } from '../utils/mocks/ollama' const expectedStreamBody = buildExpectedStreamBodyString() @@ -33,6 +35,17 @@ const providers: Provider[] = [ } } }, + { + name: 'Azure', + config: { + azure: { + endpoint: AZURE_MOCK_HOST, + apiKey: 'asd', + deploymentName: AZURE_DEPLOYMENT_NAME, + allowInsecureConnections: true + } + } + }, { name: 'Mistral', config: { diff --git a/tests/e2e/index.ts b/tests/e2e/index.ts index d1fac9a..8b9e1b7 100644 --- a/tests/e2e/index.ts +++ b/tests/e2e/index.ts @@ -1,8 +1,6 @@ import './api.test' import './rate-limiting.test' import './auth.test' -import { mockMistralApi, mockOllama, mockOpenAiApi } from '../utils/mocks' +import { mockAllProviders } from '../utils/mocks' -mockOpenAiApi() -mockMistralApi() -mockOllama() +mockAllProviders() diff --git a/tests/unit/ai-providers.test.ts b/tests/unit/ai-providers.test.ts index b2608c6..47ba4d0 100644 --- a/tests/unit/ai-providers.test.ts +++ b/tests/unit/ai-providers.test.ts @@ -4,15 +4,19 @@ import assert from 'node:assert' import { MistralProvider } from '../../ai-providers/mistral' import { OpenAiProvider } from '../../ai-providers/open-ai' import { AiProvider } from '../../ai-providers/provider' -import { MOCK_CONTENT_RESPONSE, OLLAMA_MOCK_HOST, buildExpectedStreamBodyString } from '../utils/mocks' import { OllamaProvider } from '../../ai-providers/ollama' +import { AzureProvider } from '../../ai-providers/azure' +import { MOCK_CONTENT_RESPONSE, buildExpectedStreamBodyString } from '../utils/mocks/base' +import { OLLAMA_MOCK_HOST } from '../utils/mocks/ollama' +import { AZURE_DEPLOYMENT_NAME, AZURE_MOCK_HOST } from '../utils/mocks/azure' const expectedStreamBody = buildExpectedStreamBodyString() const providers: AiProvider[] = [ new OpenAiProvider('gpt-3.5-turbo', ''), new MistralProvider('open-mistral-7b', ''), - new OllamaProvider(OLLAMA_MOCK_HOST, 'some-model') + new OllamaProvider(OLLAMA_MOCK_HOST, 'some-model'), + new AzureProvider(AZURE_MOCK_HOST, 'abc', AZURE_DEPLOYMENT_NAME, true) ] for (const provider of providers) { diff --git a/tests/unit/index.ts b/tests/unit/index.ts index a116bb3..7ecbea7 100644 --- a/tests/unit/index.ts +++ b/tests/unit/index.ts @@ -1,7 +1,5 @@ import './generator.test' import './ai-providers.test' -import { mockMistralApi, mockOllama, mockOpenAiApi } from '../utils/mocks' +import { mockAllProviders } from '../utils/mocks' -mockOpenAiApi() -mockMistralApi() -mockOllama() +mockAllProviders() diff --git a/tests/utils/mocks.ts b/tests/utils/mocks.ts deleted file mode 100644 index 624ed33..0000000 --- a/tests/utils/mocks.ts +++ /dev/null @@ -1,252 +0,0 @@ -import { MockAgent, setGlobalDispatcher } from 'undici' - -export const MOCK_CONTENT_RESPONSE = 'asd123' - -export const MOCK_STREAMING_CONTENT_CHUNKS = [ - 'chunk1', - 'chunk2', - 'chunk3' -] - -/** - * @returns The full body that should be returned from the stream endpoint - */ -export function buildExpectedStreamBodyString (): string { - let body = '' - for (const chunk of MOCK_STREAMING_CONTENT_CHUNKS) { - body += `event: content\ndata: {"response":"${chunk}"}\n\n` - } - return body -} - -const mockAgent = new MockAgent() -let isMockAgentEstablished = false -function establishMockAgent (): void { - if (isMockAgentEstablished) { - return - } - setGlobalDispatcher(mockAgent) - isMockAgentEstablished = true -} - -let isOpenAiMocked = false - -/** - * Mock OpenAI's rest api - * @see https://platform.openai.com/docs/api-reference/chat - */ -export function mockOpenAiApi (): void { - if (isOpenAiMocked) { - return - } - - isOpenAiMocked = true - - establishMockAgent() - - const pool = mockAgent.get('https://api.openai.com') - pool.intercept({ - path: '/v1/chat/completions', - method: 'POST' - }).reply(200, (opts) => { - if (typeof opts.body !== 'string') { - throw new Error(`body is not a string (${typeof opts.body})`) - } - - const body = JSON.parse(opts.body) - - let response = '' - if (body.stream === true) { - for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) { - response += 'data: ' - response += JSON.stringify({ - id: 'chatcmpl-123', - object: 'chat.completion.chunk', - created: 1694268190, - model: 'gpt-3.5-turbo-0125', - system_fingerprint: 'fp_44709d6fcb', - choices: [{ - index: 0, - delta: { - role: 'assistant', - content: MOCK_STREAMING_CONTENT_CHUNKS[i] - }, - logprobs: null, - finish_reason: i === MOCK_STREAMING_CONTENT_CHUNKS.length ? 'stop' : null - }] - }) - response += '\n\n' - } - response += 'data: [DONE]\n\n' - } else { - response += JSON.stringify({ - id: 'chatcmpl-123', - object: 'chat.completion', - created: new Date().getTime() / 1000, - model: 'gpt-3.5-turbo-0125', - system_fingerprint: 'fp_fp_44709d6fcb', - choices: [{ - index: 0, - message: { - role: 'assistant', - content: MOCK_CONTENT_RESPONSE - }, - logprobs: null, - finish_reason: 'stop' - }], - usage: { - prompt_tokens: 1, - completion_tokens: 1, - total_tokens: 2 - } - }) - } - - return response - }, { - headers: { - 'content-type': 'application/json' - } - }).persist() -} - -let isMistralMocked = false - -/** - * Mock Mistral's rest api - * @see https://docs.mistral.ai/api/#operation/createChatCompletion - */ -export function mockMistralApi (): void { - if (isMistralMocked) { - return - } - - isMistralMocked = true - - establishMockAgent() - - const pool = mockAgent.get('https://api.mistral.ai') - pool.intercept({ - path: '/v1/chat/completions', - method: 'POST' - }).reply(200, (opts) => { - if (typeof opts.body !== 'string') { - throw new Error(`body is not a string (${typeof opts.body})`) - } - - const body = JSON.parse(opts.body) - - let response = '' - if (body.stream === true) { - for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) { - response += 'data: ' - response += JSON.stringify({ - id: 'cmpl-e5cc70bb28c444948073e77776eb30ef', - object: 'chat.completion.chunk', - created: 1694268190, - model: 'mistral-small-latest', - choices: [{ - index: 0, - delta: { - role: 'assistant', - content: MOCK_STREAMING_CONTENT_CHUNKS[i] - }, - logprobs: null, - finish_reason: i === MOCK_STREAMING_CONTENT_CHUNKS.length ? 'stop' : null - }] - }) - response += '\n\n' - } - response += 'data: [DONE]\n\n' - } else { - response += JSON.stringify({ - id: 'cmpl-e5cc70bb28c444948073e77776eb30ef', - object: 'chat.completion', - created: new Date().getTime() / 1000, - model: 'mistral-small-latest', - choices: [{ - index: 0, - message: { - role: 'assistant', - content: MOCK_CONTENT_RESPONSE - }, - logprobs: null, - finish_reason: 'stop' - }], - usage: { - prompt_tokens: 1, - completion_tokens: 1, - total_tokens: 2 - } - }) - } - - return response - }, { - headers: { - 'content-type': 'application/json' - } - }).persist() -} - -export const OLLAMA_MOCK_HOST = 'http://127.0.0.1:41434' -let isOllamaMocked = false - -/** - * @see https://github.com/ollama/ollama/blob/9446b795b58e32c8b248a76707780f4f96b6434f/docs/api.md - */ -export function mockOllama (): void { - if (isOllamaMocked) { - return - } - - isOllamaMocked = true - - establishMockAgent() - - const pool = mockAgent.get(OLLAMA_MOCK_HOST) - pool.intercept({ - path: '/api/chat', - method: 'POST' - }).reply(200, (opts) => { - if (typeof opts.body !== 'string') { - throw new Error(`body is not a string (${typeof opts.body})`) - } - - const body = JSON.parse(opts.body) - - let response = '' - if (body.stream === true) { - for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) { - response += JSON.stringify({ - model: 'llama2', - created_at: '2023-08-04T08:52:19.385406455-07:00', - message: { - role: 'assistant', - content: MOCK_STREAMING_CONTENT_CHUNKS[i], - images: null - }, - done: i === MOCK_STREAMING_CONTENT_CHUNKS.length - 1 - }) - response += '\n' - } - } else { - response += JSON.stringify({ - model: 'llama2', - created_at: '2023-08-04T19:22:45.499127Z', - message: { - role: 'assistant', - content: MOCK_CONTENT_RESPONSE, - images: null - }, - done: true - }) - } - - return response - }, { - headers: { - 'content-type': 'application/json' - } - }).persist() -} diff --git a/tests/utils/mocks/azure.ts b/tests/utils/mocks/azure.ts new file mode 100644 index 0000000..230d7f8 --- /dev/null +++ b/tests/utils/mocks/azure.ts @@ -0,0 +1,88 @@ +import { Server, createServer } from 'node:http' +import { MOCK_CONTENT_RESPONSE, MOCK_STREAMING_CONTENT_CHUNKS } from './base' + +export const AZURE_MOCK_HOST = 'http://127.0.0.1:41435' + +export const AZURE_DEPLOYMENT_NAME = 'some-deployment' + +/** + * @see https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions + */ +export function mockAzure (): Server { + // The Azure client doesn't use undici's fetch and there's no option to pass + // it in like the other providers' clients unfortunately, so let's create an + // actual server + const server = createServer((req, res) => { + if (req.url !== '/openai/deployments/some-deployment/chat/completions?api-version=2024-03-01-preview') { + res.end() + throw new Error(`unsupported url or api version: ${req.url ?? ''}`) + } + + let bodyString = '' + req.on('data', (chunk: string) => { + bodyString += chunk + }) + req.on('end', () => { + const body: { stream: boolean } = JSON.parse(bodyString) + + if (body.stream) { + res.setHeader('content-type', 'text/event-stream') + + for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) { + res.write('data: ') + res.write(JSON.stringify({ + id: 'chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9', + object: 'chat.completion', + created: 1679072642, + model: 'gpt-35-turbo', + usage: { + prompt_tokens: 58, + completion_tokens: 68, + total_tokens: 126 + }, + choices: [ + { + delta: { + role: 'assistant', + content: MOCK_STREAMING_CONTENT_CHUNKS[i] + }, + finish_reason: i === MOCK_STREAMING_CONTENT_CHUNKS.length ? 'stop' : null, + index: 0 + } + ] + })) + res.write('\n\n') + } + res.write('data: [DONE]\n\n') + } else { + res.setHeader('content-type', 'application/json') + res.write(JSON.stringify({ + id: 'chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9', + object: 'chat.completion', + created: 1679072642, + model: 'gpt-35-turbo', + usage: { + prompt_tokens: 58, + completion_tokens: 68, + total_tokens: 126 + }, + choices: [ + { + message: { + role: 'assistant', + content: MOCK_CONTENT_RESPONSE + }, + finish_reason: 'stop', + index: 0 + } + ] + })) + } + + res.end() + }) + }) + server.listen(41435) + + return server +} diff --git a/tests/utils/mocks/base.ts b/tests/utils/mocks/base.ts new file mode 100644 index 0000000..400d036 --- /dev/null +++ b/tests/utils/mocks/base.ts @@ -0,0 +1,31 @@ +import { MockAgent, setGlobalDispatcher } from 'undici' + +export const MOCK_CONTENT_RESPONSE = 'asd123' + +export const MOCK_STREAMING_CONTENT_CHUNKS = [ + 'chunk1', + 'chunk2', + 'chunk3' +] + +/** + * @returns The full body that should be returned from the stream endpoint + */ +export function buildExpectedStreamBodyString (): string { + let body = '' + for (const chunk of MOCK_STREAMING_CONTENT_CHUNKS) { + body += `event: content\ndata: {"response":"${chunk}"}\n\n` + } + return body +} + +export const MOCK_AGENT = new MockAgent() + +let isMockAgentEstablished = false +export function establishMockAgent (): void { + if (isMockAgentEstablished) { + return + } + setGlobalDispatcher(MOCK_AGENT) + isMockAgentEstablished = true +} diff --git a/tests/utils/mocks/index.ts b/tests/utils/mocks/index.ts new file mode 100644 index 0000000..7b047e1 --- /dev/null +++ b/tests/utils/mocks/index.ts @@ -0,0 +1,16 @@ +import { after } from 'node:test' +import { mockAzure } from './azure' +import { mockMistralApi } from './mistral' +import { mockOllama } from './ollama' +import { mockOpenAiApi } from './open-ai' + +export function mockAllProviders (): void { + mockOpenAiApi() + mockMistralApi() + mockOllama() + + const azureMock = mockAzure() + after(() => { + azureMock.close() + }) +} diff --git a/tests/utils/mocks/mistral.ts b/tests/utils/mocks/mistral.ts new file mode 100644 index 0000000..c8001a9 --- /dev/null +++ b/tests/utils/mocks/mistral.ts @@ -0,0 +1,80 @@ +import { MOCK_AGENT, MOCK_CONTENT_RESPONSE, MOCK_STREAMING_CONTENT_CHUNKS, establishMockAgent } from './base' + +let isMistralMocked = false + +/** + * Mock Mistral's rest api + * @see https://docs.mistral.ai/api/#operation/createChatCompletion + */ +export function mockMistralApi (): void { + if (isMistralMocked) { + return + } + + isMistralMocked = true + + establishMockAgent() + + const pool = MOCK_AGENT.get('https://api.mistral.ai') + pool.intercept({ + path: '/v1/chat/completions', + method: 'POST' + }).reply(200, (opts) => { + if (typeof opts.body !== 'string') { + throw new Error(`body is not a string (${typeof opts.body})`) + } + + const body = JSON.parse(opts.body) + + let response = '' + if (body.stream === true) { + for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) { + response += 'data: ' + response += JSON.stringify({ + id: 'cmpl-e5cc70bb28c444948073e77776eb30ef', + object: 'chat.completion.chunk', + created: 1694268190, + model: 'mistral-small-latest', + choices: [{ + index: 0, + delta: { + role: 'assistant', + content: MOCK_STREAMING_CONTENT_CHUNKS[i] + }, + logprobs: null, + finish_reason: i === MOCK_STREAMING_CONTENT_CHUNKS.length ? 'stop' : null + }] + }) + response += '\n\n' + } + response += 'data: [DONE]\n\n' + } else { + response += JSON.stringify({ + id: 'cmpl-e5cc70bb28c444948073e77776eb30ef', + object: 'chat.completion', + created: new Date().getTime() / 1000, + model: 'mistral-small-latest', + choices: [{ + index: 0, + message: { + role: 'assistant', + content: MOCK_CONTENT_RESPONSE + }, + logprobs: null, + finish_reason: 'stop' + }], + usage: { + prompt_tokens: 1, + completion_tokens: 1, + total_tokens: 2 + } + }) + } + + return response + }, { + headers: { + 'content-type': 'application/json' + } + }).persist() +} diff --git a/tests/utils/mocks/ollama.ts b/tests/utils/mocks/ollama.ts new file mode 100644 index 0000000..ee54b1d --- /dev/null +++ b/tests/utils/mocks/ollama.ts @@ -0,0 +1,63 @@ +import { MOCK_AGENT, MOCK_CONTENT_RESPONSE, MOCK_STREAMING_CONTENT_CHUNKS, establishMockAgent } from './base' + +export const OLLAMA_MOCK_HOST = 'http://127.0.0.1:41434' +let isOllamaMocked = false + +/** + * @see https://github.com/ollama/ollama/blob/9446b795b58e32c8b248a76707780f4f96b6434f/docs/api.md + */ +export function mockOllama (): void { + if (isOllamaMocked) { + return + } + + isOllamaMocked = true + + establishMockAgent() + + const pool = MOCK_AGENT.get(OLLAMA_MOCK_HOST) + pool.intercept({ + path: '/api/chat', + method: 'POST' + }).reply(200, (opts) => { + if (typeof opts.body !== 'string') { + throw new Error(`body is not a string (${typeof opts.body})`) + } + + const body = JSON.parse(opts.body) + + let response = '' + if (body.stream === true) { + for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) { + response += JSON.stringify({ + model: 'llama2', + created_at: '2023-08-04T08:52:19.385406455-07:00', + message: { + role: 'assistant', + content: MOCK_STREAMING_CONTENT_CHUNKS[i], + images: null + }, + done: i === MOCK_STREAMING_CONTENT_CHUNKS.length - 1 + }) + response += '\n' + } + } else { + response += JSON.stringify({ + model: 'llama2', + created_at: '2023-08-04T19:22:45.499127Z', + message: { + role: 'assistant', + content: MOCK_CONTENT_RESPONSE, + images: null + }, + done: true + }) + } + + return response + }, { + headers: { + 'content-type': 'application/json' + } + }).persist() +} diff --git a/tests/utils/mocks/open-ai.ts b/tests/utils/mocks/open-ai.ts new file mode 100644 index 0000000..efdf087 --- /dev/null +++ b/tests/utils/mocks/open-ai.ts @@ -0,0 +1,82 @@ +import { MOCK_AGENT, MOCK_CONTENT_RESPONSE, MOCK_STREAMING_CONTENT_CHUNKS, establishMockAgent } from './base' + +let isOpenAiMocked = false + +/** + * Mock OpenAI's rest api + * @see https://platform.openai.com/docs/api-reference/chat + */ +export function mockOpenAiApi (): void { + if (isOpenAiMocked) { + return + } + + isOpenAiMocked = true + + establishMockAgent() + + const pool = MOCK_AGENT.get('https://api.openai.com') + pool.intercept({ + path: '/v1/chat/completions', + method: 'POST' + }).reply(200, (opts) => { + if (typeof opts.body !== 'string') { + throw new Error(`body is not a string (${typeof opts.body})`) + } + + const body = JSON.parse(opts.body) + + let response = '' + if (body.stream === true) { + for (let i = 0; i < MOCK_STREAMING_CONTENT_CHUNKS.length; i++) { + response += 'data: ' + response += JSON.stringify({ + id: 'chatcmpl-123', + object: 'chat.completion.chunk', + created: 1694268190, + model: 'gpt-3.5-turbo-0125', + system_fingerprint: 'fp_44709d6fcb', + choices: [{ + index: 0, + delta: { + role: 'assistant', + content: MOCK_STREAMING_CONTENT_CHUNKS[i] + }, + logprobs: null, + finish_reason: i === MOCK_STREAMING_CONTENT_CHUNKS.length ? 'stop' : null + }] + }) + response += '\n\n' + } + response += 'data: [DONE]\n\n' + } else { + response += JSON.stringify({ + id: 'chatcmpl-123', + object: 'chat.completion', + created: new Date().getTime() / 1000, + model: 'gpt-3.5-turbo-0125', + system_fingerprint: 'fp_fp_44709d6fcb', + choices: [{ + index: 0, + message: { + role: 'assistant', + content: MOCK_CONTENT_RESPONSE + }, + logprobs: null, + finish_reason: 'stop' + }], + usage: { + prompt_tokens: 1, + completion_tokens: 1, + total_tokens: 2 + } + }) + } + + return response + }, { + headers: { + 'content-type': 'application/json' + } + }).persist() +}