Skip to content

Commit

Permalink
fix embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmacarthy committed Sep 5, 2024
1 parent c845522 commit f00272c
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
5 changes: 3 additions & 2 deletions src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export interface RequestOptionsOllama extends RequestBodyBase {
model: string
keep_alive?: string | number
messages?: Message[] | Message
prompt: string
prompt?: string
input?: string
options: Record<string, unknown>
}

Expand Down Expand Up @@ -267,7 +268,7 @@ export interface ChunkOptions {
}

export type Embedding = {
embedding: number[]
embeddings: number[]
}

export type EmbeddedDocument = {
Expand Down
4 changes: 2 additions & 2 deletions src/extension/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export class EmbeddingDatabase {

const requestBody: RequestOptionsOllama = {
model: this._embeddingModel,
prompt: content,
input: content,
stream: false,
options: {}
}
Expand All @@ -82,7 +82,7 @@ export class EmbeddingDatabase {
body: requestBody,
options: requestOptions,
onData: (response) => {
resolve((response as Embedding).embedding)
resolve((response as Embedding).embeddings)
}
})
})
Expand Down
21 changes: 19 additions & 2 deletions src/extension/provider-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ export class ProviderManager {
getDefaultChatProvider() {
return {
apiHostname: '0.0.0.0',
apiPath: '/v1/chat/completions',
apiPath: '/api/chat',
apiPort: 11434,
apiProtocol: 'http',
id: uuidv4(),
Expand All @@ -100,7 +100,7 @@ export class ProviderManager {
getDefaultEmbeddingsProvider() {
return {
apiHostname: '0.0.0.0',
apiPath: '/v1/embeddings',
apiPath: '/api/embed',
apiPort: 11434,
apiProtocol: 'http',
id: uuidv4(),
Expand Down Expand Up @@ -148,13 +148,27 @@ export class ProviderManager {
return provider
}

fixLegacyDefaultEmbeddingPath() {
const provider = this._context.globalState.get<TwinnyProvider>(
ACTIVE_EMBEDDINGS_PROVIDER_STORAGE_KEY
)
if (provider && provider.apiPath === '/v1/embeddings') {
this.updateProvider({
...provider,
apiPath: '/api/embed',
})
}
}

addDefaultEmbeddingsProvider(): TwinnyProvider {
const provider = this.getDefaultEmbeddingsProvider()

if (
!this._context.globalState.get(ACTIVE_EMBEDDINGS_PROVIDER_STORAGE_KEY)
) {
this.addDefaultProvider(provider)
}
this.fixLegacyDefaultEmbeddingPath()
return provider
}

Expand Down Expand Up @@ -283,13 +297,16 @@ export class ProviderManager {
const providers = this.getProviders() || {}
const activeFimProvider = this.getActiveFimProvider()
const activeChatProvider = this.getActiveChatProvider()
const activeEmbeddingsProvider = this.getActiveEmbeddingsProvider()
if (!provider) return
providers[provider.id] = provider
this._context.globalState.update(INFERENCE_PROVIDERS_STORAGE_KEY, providers)
if (provider.id === activeFimProvider?.id)
this.setActiveFimProvider(provider)
if (provider.id === activeChatProvider?.id)
this.setActiveChatProvider(provider)
if (provider.id === activeEmbeddingsProvider?.id)
this.setActiveEmbeddingsProvider(provider)
this.getAllProviders()
}

Expand Down
5 changes: 3 additions & 2 deletions src/webview/providers.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ export const Providers = () => {
const [provider, setProvider] = React.useState<TwinnyProvider | undefined>()
const { models } = useOllamaModels()
const hasOllamaModels = !!models?.length
const { updateProvider } = useProviders()
const { providers, removeProvider, copyProvider, resetProviders } =
const { updateProvider, providers, removeProvider, copyProvider, resetProviders } =
useProviders()

console.log(providers)

const handleClose = () => {
setShowForm(false)
}
Expand Down

0 comments on commit f00272c

Please sign in to comment.