From 2c789d67ab1f45a860b1efcdcc49e81e204f6adf Mon Sep 17 00:00:00 2001 From: jalr4ever Date: Wed, 31 Jan 2024 13:23:30 +0800 Subject: [PATCH] feat: Support a way to define default model from CUSTOM_MODELS env. --- app/components/chat.tsx | 15 +++++++++++---- app/store/access.ts | 9 +++++++++ app/utils/model.ts | 17 +++++++++++++++-- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 39abdd97b24..3bac382bb35 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -432,10 +432,17 @@ export function ChatActions(props: { // switch model const currentModel = chatStore.currentSession().mask.modelConfig.model; const allModels = useAllModels(); - const models = useMemo( - () => allModels.filter((m) => m.available), - [allModels], - ); + const models = useMemo(() => { + const filteredModels = allModels.filter((m) => m.available); + const defaultModel = filteredModels.find((m) => m.isDefault); + if (defaultModel) { + const arr = [defaultModel, ...filteredModels.filter((m) => m !== defaultModel)]; + return arr; + } else { + return filteredModels; + } + }, [allModels]); + const [showModelSelector, setShowModelSelector] = useState(false); useEffect(() => { diff --git a/app/store/access.ts b/app/store/access.ts index 9e8024a6aa8..0ac89b1d664 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -8,6 +8,7 @@ import { getHeaders } from "../client/api"; import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; import { ensure } from "../utils/clone"; +import { DEFAULT_CONFIG } from "./config"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done @@ -88,6 +89,14 @@ export const useAccessStore = createPersistStore( }, }) .then((res) => res.json()) + .then((res) => { + // Set default model from env request + let custom_models = res.customModels ?? ""; + const models = custom_models.split(","); + const model_default = models.find((model: string) => model.startsWith("+*"))?.substring(2) || "gpt-3.5-turbo"; + DEFAULT_CONFIG.modelConfig.model = model_default; + return res + }) .then((res: DangerConfig) => { console.log("[Config] got config from server", res); set(() => ({ ...res })); diff --git a/app/utils/model.ts b/app/utils/model.ts index b2a42ef022a..4695744e1b0 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -10,6 +10,7 @@ export function collectModelTable( available: boolean; name: string; displayName: string; + isDefault?: boolean; provider?: LLMModel["provider"]; // Marked as optional } > = {}; @@ -22,6 +23,8 @@ export function collectModelTable( }; }); + + // server custom models customModels .split(",") @@ -32,8 +35,17 @@ export function collectModelTable( m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m; const [name, displayName] = nameConfig.split("="); - // enable or disable all models - if (name === "all") { + if (name.startsWith("*")) { // Check if name starts with wildcard + let defaultName: string | null = null; // Add variable to store wildcard value + defaultName = displayName || name.slice(1); // Store wildcard value + modelTable[defaultName] = { + name : defaultName, + displayName: defaultName, + available, + isDefault : true, + provider: modelTable[defaultName]?.provider, // Use optional chaining + }; + } else if (name === "all") { Object.values(modelTable).forEach((model) => (model.available = available)); } else { modelTable[name] = { @@ -44,6 +56,7 @@ export function collectModelTable( }; } }); + return modelTable; }