From 0688ccda08c0f189b7d7bffb12f862e2140bf9bd Mon Sep 17 00:00:00 2001 From: flakey5 <73616808+flakey5@users.noreply.github.com> Date: Sat, 13 Apr 2024 15:10:42 -0700 Subject: [PATCH] Cleanup provider constructors Makes constructors for the different ai providers take in objects instead of a bunch of properties so it's clearer when initializing --- ai-providers/mistral.ts | 7 ++++++- ai-providers/ollama.ts | 7 ++++++- ai-providers/open-ai.ts | 7 ++++++- plugins/warp.ts | 9 +++------ tests/unit/ai-providers.test.ts | 6 +++--- 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/ai-providers/mistral.ts b/ai-providers/mistral.ts index 9ec567a..31eef71 100644 --- a/ai-providers/mistral.ts +++ b/ai-providers/mistral.ts @@ -52,12 +52,17 @@ class MistralByteSource implements UnderlyingByteSource { } } +interface MistralProviderCtorOptions { + model: string + apiKey: string +} + export class MistralProvider implements AiProvider { model: string apiKey: string client?: import('@mistralai/mistralai').default = undefined - constructor (model: string, apiKey: string) { + constructor ({ model, apiKey }: MistralProviderCtorOptions) { this.model = model this.apiKey = apiKey } diff --git a/ai-providers/ollama.ts b/ai-providers/ollama.ts index c83355c..d641c25 100644 --- a/ai-providers/ollama.ts +++ b/ai-providers/ollama.ts @@ -37,11 +37,16 @@ class OllamaByteSource implements UnderlyingByteSource { } } +interface OllamaProviderCtorOptions { + host: string + model: string +} + export class OllamaProvider implements AiProvider { model: string client: Ollama - constructor (host: string, model: string) { + constructor ({ host, model }: OllamaProviderCtorOptions) { this.model = model this.client = new Ollama({ host }) } diff --git a/ai-providers/open-ai.ts b/ai-providers/open-ai.ts index 3c16bfe..2732450 100644 --- a/ai-providers/open-ai.ts +++ b/ai-providers/open-ai.ts @@ -81,11 +81,16 @@ class OpenAiByteSource implements UnderlyingByteSource { } } +interface OpenAiProviderCtorOptions { + model: string + apiKey: string +} + export class OpenAiProvider implements AiProvider { model: string client: OpenAI - constructor (model: string, apiKey: string) { + constructor ({ model, apiKey }: OpenAiProviderCtorOptions) { this.model = model // @ts-expect-error this.client = new OpenAI({ apiKey, fetch }) diff --git a/plugins/warp.ts b/plugins/warp.ts index 5331f4b..42f4ed5 100644 --- a/plugins/warp.ts +++ b/plugins/warp.ts @@ -12,14 +12,11 @@ const UnknownAiProviderError = createError('UNKNOWN_AI_PROVIDER', 'Unknown AI Pr function build (aiProvider: AiWarpConfig['aiProvider']): AiProvider { if ('openai' in aiProvider) { - const { model, apiKey } = aiProvider.openai - return new OpenAiProvider(model, apiKey) + return new OpenAiProvider(aiProvider.openai) } else if ('mistral' in aiProvider) { - const { model, apiKey } = aiProvider.mistral - return new MistralProvider(model, apiKey) + return new MistralProvider(aiProvider.mistral) } else if ('ollama' in aiProvider) { - const { host, model } = aiProvider.ollama - return new OllamaProvider(host, model) + return new OllamaProvider(aiProvider.ollama) } else { throw new UnknownAiProviderError() } diff --git a/tests/unit/ai-providers.test.ts b/tests/unit/ai-providers.test.ts index b2608c6..bec5880 100644 --- a/tests/unit/ai-providers.test.ts +++ b/tests/unit/ai-providers.test.ts @@ -10,9 +10,9 @@ import { OllamaProvider } from '../../ai-providers/ollama' const expectedStreamBody = buildExpectedStreamBodyString() const providers: AiProvider[] = [ - new OpenAiProvider('gpt-3.5-turbo', ''), - new MistralProvider('open-mistral-7b', ''), - new OllamaProvider(OLLAMA_MOCK_HOST, 'some-model') + new OpenAiProvider({ model: 'gpt-3.5-turbo', apiKey: '' }), + new MistralProvider({ model: 'open-mistral-7b', apiKey: '' }), + new OllamaProvider({ host: OLLAMA_MOCK_HOST, model: 'some-model' }) ] for (const provider of providers) {