From 375531e663f13f07ee26db79a8d6e9ea7d047407 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Tue, 9 Apr 2024 21:06:07 +0100 Subject: [PATCH] restore support for model select dropdown from ollama api in provider --- package.json | 24 +++++++- src/extension/ollama-service.ts | 38 ++++-------- src/webview/hooks.ts | 21 +++++++ src/webview/model-select.tsx | 35 +++++++++++ src/webview/providers.tsx | 106 ++++++++++++++++++-------------- src/webview/utils.ts | 7 +++ 6 files changed, 158 insertions(+), 73 deletions(-) create mode 100644 src/webview/model-select.tsx diff --git a/package.json b/package.json index 5ef849c0..d79832c3 100644 --- a/package.json +++ b/package.json @@ -318,14 +318,34 @@ "default": true, "description": "Enable this setting to allow twinny to keep making subsequent completion requests to the API after the last completion request was accepted." }, - "twinny.keepAlive": { + "twinny.ollamaHostname": { "order": 13, "type": "string", + "default": "0.0.0.0", + "description": "Hostname for Ollama API.", + "required": true + }, + "twinny.ollamaApiPort": { + "order": 14, + "type": "number", + "default": 11434, + "description": "The API port usually `11434`", + "required": false + }, + "twinny.keepAlive": { + "order": 15, + "type": "string", "default": "5m", "description": "Keep models in memory by making requests with keep_alive=-1. Applicable only for Ollama API." }, + "twinny.ollamaUseTls": { + "order": 25, + "type": "boolean", + "default": false, + "description": "Enables TLS encryption Ollama API connections." + }, "twinny.enableLogging": { - "order": 14, + "order": 26, "type": "boolean", "default": true, "description": "Enable twinny debug mode" diff --git a/src/extension/ollama-service.ts b/src/extension/ollama-service.ts index 107bb6a5..f009679d 100644 --- a/src/extension/ollama-service.ts +++ b/src/extension/ollama-service.ts @@ -4,37 +4,25 @@ import { Logger } from '../common/logger' export class OllamaService { private logger: Logger private _config = workspace.getConfiguration('twinny') - private _apiHostname = this._config.get('apiHostname') as string - private _chatApiPort = this._config.get('chatApiPort') as string - private _fimApiPort = this._config.get('fimApiPort') as string - private _useTls = this._config.get('useTls') as boolean - private _baseUrlChat: string - private _baseUrlFim: string + private _baseUrl: string constructor() { this.logger = new Logger() - const useTls = this._useTls - const protocol = useTls ? 'https' : 'http' - this._baseUrlChat = `${protocol}://${this._apiHostname}:${this._chatApiPort}` - this._baseUrlFim = `${protocol}://${this._apiHostname}:${this._fimApiPort}` + const protocol = (this._config.get('ollamaUseTls') as boolean) + ? 'https' + : 'http' + const hostname = this._config.get('ollamaHostname') as string + const port = this._config.get('ollamaApiPort') as string + this._baseUrl = `${protocol}://${hostname}:${port}` } public fetchModels = async (resource = '/api/tags') => { - const chatModelsRes = (await fetch(this._baseUrlChat + resource)) || [] - const fimModelsRes = await fetch(this._baseUrlFim + resource) - const { models: chatModels } = await chatModelsRes.json() - const { models: fimModels } = await fimModelsRes.json() - const models = new Set() - if (Array.isArray(chatModels)) { - for (const model of chatModels) { - models.add(model) - } + try { + const response = await fetch(`${this._baseUrl}${resource}`) + const { models } = await response.json() + return Array.isArray(models) ? [...new Set(models)] : [] + } catch (err) { + return [] } - if (Array.isArray(fimModels)) { - for (const model of fimModels) { - models.add(model) - } - } - return Array.from(models) } } diff --git a/src/webview/hooks.ts b/src/webview/hooks.ts index 9d9533e8..f30c0493 100644 --- a/src/webview/hooks.ts +++ b/src/webview/hooks.ts @@ -2,6 +2,7 @@ import { useEffect, useState } from 'react' import { MESSAGE_KEY, MESSAGE_NAME } from '../common/constants' import { + ApiModel, ClientMessage, LanguageType, ServerMessage, @@ -281,3 +282,23 @@ export const useConfigurationSetting = (key: string) => { return { configurationSetting } } + +export const useOllamaModels = () => { + const [models, setModels] = useState([]) + const handler = (event: MessageEvent) => { + const message: ServerMessage = event.data + if (message?.type === MESSAGE_NAME.twinnyFetchOllamaModels) { + setModels(message?.value.data) + } + return () => window.removeEventListener('message', handler) + } + + useEffect(() => { + global.vscode.postMessage({ + type: MESSAGE_NAME.twinnyFetchOllamaModels + }) + window.addEventListener('message', handler) + }, []) + + return { models } +} diff --git a/src/webview/model-select.tsx b/src/webview/model-select.tsx new file mode 100644 index 00000000..e4d2547e --- /dev/null +++ b/src/webview/model-select.tsx @@ -0,0 +1,35 @@ +import { VSCodeDropdown } from '@vscode/webview-ui-toolkit/react' + +import { getModelShortName } from './utils' +import { ApiModel } from '../common/types' + +interface Props { + model: string | undefined + setModel: (model: string) => void + models: ApiModel[] | undefined +} + +export const ModelSelect = ({ model, models, setModel }: Props) => { + const handleOnChange = (e: unknown): void => { + const event = e as React.ChangeEvent + const selectedValue = event?.target.value || '' + setModel(selectedValue) + } + + return ( +
+
+ +
+ + {models?.map((model, index) => { + return ( + + ) + })} + +
+ ) +} diff --git a/src/webview/providers.tsx b/src/webview/providers.tsx index 2640bcc1..7c42da90 100644 --- a/src/webview/providers.tsx +++ b/src/webview/providers.tsx @@ -1,5 +1,5 @@ import React from 'react' -import { useProviders } from './hooks' +import { useOllamaModels, useProviders } from './hooks' import { VSCodeButton, VSCodeDivider, @@ -16,6 +16,7 @@ import { DEFAULT_PROVIDER_FORM_VALUES, FIM_TEMPLATE_FORMAT, } from '../common/constants' +import { ModelSelect } from './model-select' export const Providers = () => { const [showForm, setShowForm] = React.useState(false) @@ -60,8 +61,8 @@ export const Providers = () => {
Add Provider - - + + Reset Providers
@@ -71,28 +72,28 @@ export const Providers = () => {

{provider.label}

handleEdit(provider)} > - + handleCopy(provider)} > - + handleDelete(provider)} > - +
@@ -147,6 +148,7 @@ interface ProviderFormProps { function ProviderForm({ onClose, provider }: ProviderFormProps) { const isEditing = provider !== undefined + const { models } = useOllamaModels() const { saveProvider, updateProvider } = useProviders() const [formState, setFormState] = React.useState( provider || DEFAULT_PROVIDER_FORM_VALUES @@ -185,11 +187,11 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) {
- +
- +
@@ -216,10 +218,10 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) { {formState.type === 'fim' && (
- +
@@ -234,10 +236,10 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) {
- +
@@ -251,10 +253,10 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) {
- +
@@ -266,27 +268,39 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) {
-
+ {formState.provider === ApiProviders.Ollama && models?.length && ( + { + setFormState({ ...formState, modelName: model }) + }} + /> + )} + + {formState.provider !== ApiProviders.Ollama && (
- +
+ +
+
- -
+ )}
- +
@@ -294,12 +308,12 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) {
- +
@@ -307,12 +321,12 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) {
- +
@@ -320,21 +334,21 @@ function ProviderForm({ onClose, provider }: ProviderFormProps) {
- +
- + Save - + Cancel
diff --git a/src/webview/utils.ts b/src/webview/utils.ts index efbf04d2..df7056b9 100644 --- a/src/webview/utils.ts +++ b/src/webview/utils.ts @@ -55,3 +55,10 @@ export const kebabToSentence = (kebabStr: string) => { } export const getLineBreakCount = (str: string) => str.split('\n').length + +export const getModelShortName = (name: string) => { + if (name.length > 32) { + return `${name.substring(0, 15)}...${name.substring(name.length - 16)}` + } + return name +}