From f00272c0f6b69bd3fb12cb5ec88988906641ca30 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Thu, 5 Sep 2024 13:32:01 +0100 Subject: [PATCH] fix embeddings --- src/common/types.ts | 5 +++-- src/extension/embeddings.ts | 4 ++-- src/extension/provider-manager.ts | 21 +++++++++++++++++++-- src/webview/providers.tsx | 5 +++-- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/common/types.ts b/src/common/types.ts index d3d52316..aadd473f 100644 --- a/src/common/types.ts +++ b/src/common/types.ts @@ -12,7 +12,8 @@ export interface RequestOptionsOllama extends RequestBodyBase { model: string keep_alive?: string | number messages?: Message[] | Message - prompt: string + prompt?: string + input?: string options: Record } @@ -267,7 +268,7 @@ export interface ChunkOptions { } export type Embedding = { - embedding: number[] + embeddings: number[] } export type EmbeddedDocument = { diff --git a/src/extension/embeddings.ts b/src/extension/embeddings.ts index f9a859c9..dee278e1 100644 --- a/src/extension/embeddings.ts +++ b/src/extension/embeddings.ts @@ -60,7 +60,7 @@ export class EmbeddingDatabase { const requestBody: RequestOptionsOllama = { model: this._embeddingModel, - prompt: content, + input: content, stream: false, options: {} } @@ -82,7 +82,7 @@ export class EmbeddingDatabase { body: requestBody, options: requestOptions, onData: (response) => { - resolve((response as Embedding).embedding) + resolve((response as Embedding).embeddings) } }) }) diff --git a/src/extension/provider-manager.ts b/src/extension/provider-manager.ts index 1c21d90d..f1b79bce 100644 --- a/src/extension/provider-manager.ts +++ b/src/extension/provider-manager.ts @@ -86,7 +86,7 @@ export class ProviderManager { getDefaultChatProvider() { return { apiHostname: '0.0.0.0', - apiPath: '/v1/chat/completions', + apiPath: '/api/chat', apiPort: 11434, apiProtocol: 'http', id: uuidv4(), @@ -100,7 +100,7 @@ export class ProviderManager { getDefaultEmbeddingsProvider() { return { apiHostname: '0.0.0.0', - apiPath: '/v1/embeddings', + apiPath: '/api/embed', apiPort: 11434, apiProtocol: 'http', id: uuidv4(), @@ -148,13 +148,27 @@ export class ProviderManager { return provider } + fixLegacyDefaultEmbeddingPath() { + const provider = this._context.globalState.get( + ACTIVE_EMBEDDINGS_PROVIDER_STORAGE_KEY + ) + if (provider && provider.apiPath === '/v1/embeddings') { + this.updateProvider({ + ...provider, + apiPath: '/api/embed', + }) + } + } + addDefaultEmbeddingsProvider(): TwinnyProvider { const provider = this.getDefaultEmbeddingsProvider() + if ( !this._context.globalState.get(ACTIVE_EMBEDDINGS_PROVIDER_STORAGE_KEY) ) { this.addDefaultProvider(provider) } + this.fixLegacyDefaultEmbeddingPath() return provider } @@ -283,6 +297,7 @@ export class ProviderManager { const providers = this.getProviders() || {} const activeFimProvider = this.getActiveFimProvider() const activeChatProvider = this.getActiveChatProvider() + const activeEmbeddingsProvider = this.getActiveEmbeddingsProvider() if (!provider) return providers[provider.id] = provider this._context.globalState.update(INFERENCE_PROVIDERS_STORAGE_KEY, providers) @@ -290,6 +305,8 @@ export class ProviderManager { this.setActiveFimProvider(provider) if (provider.id === activeChatProvider?.id) this.setActiveChatProvider(provider) + if (provider.id === activeEmbeddingsProvider?.id) + this.setActiveEmbeddingsProvider(provider) this.getAllProviders() } diff --git a/src/webview/providers.tsx b/src/webview/providers.tsx index dbff5bf2..46ed2c3b 100644 --- a/src/webview/providers.tsx +++ b/src/webview/providers.tsx @@ -23,10 +23,11 @@ export const Providers = () => { const [provider, setProvider] = React.useState() const { models } = useOllamaModels() const hasOllamaModels = !!models?.length - const { updateProvider } = useProviders() - const { providers, removeProvider, copyProvider, resetProviders } = + const { updateProvider, providers, removeProvider, copyProvider, resetProviders } = useProviders() + console.log(providers) + const handleClose = () => { setShowForm(false) }