From 8a4b8a84d67bb7431c5ce88046d94963dceebad7 Mon Sep 17 00:00:00 2001 From: frostime Date: Sat, 3 Aug 2024 17:16:05 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E2=9C=A8=20feat:=20=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=88=97=E8=A1=A8=EF=BC=8C=E5=B0=86=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=E6=A8=A1=E5=9E=8B=E6=94=BE=E5=9C=A8=E5=89=8D?= =?UTF-8?q?=E9=9D=A2=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/utils/model.ts | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/app/utils/model.ts b/app/utils/model.ts index 4de0eb8d96a..6b1485e32ad 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -22,15 +22,6 @@ export function collectModelTable( } > = {}; - // default models - models.forEach((m) => { - // using @ as fullName - modelTable[`${m.name}@${m?.provider?.id}`] = { - ...m, - displayName: m.name, // 'provider' is copied over if it exists - }; - }); - // server custom models customModels .split(",") @@ -89,6 +80,15 @@ export function collectModelTable( } }); + // default models + models.forEach((m) => { + // using @ as fullName + modelTable[`${m.name}@${m?.provider?.id}`] = { + ...m, + displayName: m.name, // 'provider' is copied over if it exists + }; + }); + return modelTable; } @@ -99,13 +99,16 @@ export function collectModelTableWithDefaultModel( ) { let modelTable = collectModelTable(models, customModels); if (defaultModel && defaultModel !== "") { - if (defaultModel.includes('@')) { + if (defaultModel.includes("@")) { if (defaultModel in modelTable) { modelTable[defaultModel].isDefault = true; } } else { for (const key of Object.keys(modelTable)) { - if (modelTable[key].available && key.split('@').shift() == defaultModel) { + if ( + modelTable[key].available && + key.split("@").shift() == defaultModel + ) { modelTable[key].isDefault = true; break; } From b023a00445682fcb336fe231ffe7c667632c0d15 Mon Sep 17 00:00:00 2001 From: frostime Date: Mon, 5 Aug 2024 16:37:22 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=F0=9F=94=A8=20refactor(model):=20=E6=9B=B4?= =?UTF-8?q?=E6=94=B9=E5=8E=9F=E5=85=88=E7=9A=84=E5=AE=9E=E7=8E=B0=E6=96=B9?= =?UTF-8?q?=E6=B3=95=EF=BC=8C=E5=9C=A8=20collect=20table=20=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E5=90=8E=E9=9D=A2=E5=A2=9E=E5=8A=A0=E9=A2=9D=E5=A4=96?= =?UTF-8?q?=E7=9A=84=20sort=20=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/utils/model.ts | 50 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/app/utils/model.ts b/app/utils/model.ts index 6b1485e32ad..b117b5eb64a 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -7,6 +7,29 @@ const customProvider = (providerName: string) => ({ providerType: "custom", }); +const sortModelTable = ( + models: ReturnType, + rule: "custom-first" | "default-first", +) => + models.sort((a, b) => { + if (a.provider === undefined && b.provider === undefined) { + return 0; + } + + let aIsCustom = a.provider?.providerType === "custom"; + let bIsCustom = b.provider?.providerType === "custom"; + + if (aIsCustom === bIsCustom) { + return 0; + } + + if (aIsCustom) { + return rule === "custom-first" ? -1 : 1; + } else { + return rule === "custom-first" ? 1 : -1; + } + }); + export function collectModelTable( models: readonly LLMModel[], customModels: string, @@ -22,6 +45,15 @@ export function collectModelTable( } > = {}; + // default models + models.forEach((m) => { + // using @ as fullName + modelTable[`${m.name}@${m?.provider?.id}`] = { + ...m, + displayName: m.name, // 'provider' is copied over if it exists + }; + }); + // server custom models customModels .split(",") @@ -80,15 +112,6 @@ export function collectModelTable( } }); - // default models - models.forEach((m) => { - // using @ as fullName - modelTable[`${m.name}@${m?.provider?.id}`] = { - ...m, - displayName: m.name, // 'provider' is copied over if it exists - }; - }); - return modelTable; } @@ -126,7 +149,9 @@ export function collectModels( customModels: string, ) { const modelTable = collectModelTable(models, customModels); - const allModels = Object.values(modelTable); + let allModels = Object.values(modelTable); + + allModels = sortModelTable(allModels, "custom-first"); return allModels; } @@ -141,7 +166,10 @@ export function collectModelsWithDefaultModel( customModels, defaultModel, ); - const allModels = Object.values(modelTable); + let allModels = Object.values(modelTable); + + allModels = sortModelTable(allModels, "custom-first"); + return allModels; } From 150fc84b9b55fe07da2fefa73b2cbee255d9de14 Mon Sep 17 00:00:00 2001 From: frostime Date: Mon, 5 Aug 2024 19:43:32 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E2=9C=A8=20feat(model):=20=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=20sorted=20=E5=AD=97=E6=AE=B5=EF=BC=8C=E5=B9=B6?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E8=AF=A5=E5=AD=97=E6=AE=B5=E5=AF=B9=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=88=97=E8=A1=A8=E8=BF=9B=E8=A1=8C=E6=8E=92=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 在 Model 和 Provider 类型中增加 sorted 字段(api.ts) 2. 默认模型在初始化的时候,自动设置默认 sorted 字段,从 1000 开始自增长(constant.ts) 3. 自定义模型更新的时候,自动分配 sorted 字段(model.ts) --- app/client/api.ts | 2 ++ app/constant.ts | 19 ++++++++++++++++++ app/utils/model.ts | 49 +++++++++++++++++++++++++++------------------- 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/app/client/api.ts b/app/client/api.ts index f10e4761887..b13e0f8a4c0 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -64,12 +64,14 @@ export interface LLMModel { displayName?: string; available: boolean; provider: LLMModelProvider; + sorted: number; } export interface LLMModelProvider { id: string; providerName: string; providerType: string; + sorted: number; } export abstract class LLMApi { diff --git a/app/constant.ts b/app/constant.ts index 5251b5b4fc9..8ca17c4b359 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -320,86 +320,105 @@ const tencentModels = [ const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]; +let seq = 1000; // 内置的模型序号生成器从1000开始 export const DEFAULT_MODELS = [ ...openaiModels.map((name) => ({ name, available: true, + sorted: seq++, // Global sequence sort(index) provider: { id: "openai", providerName: "OpenAI", providerType: "openai", + sorted: 1, // 这里是固定的,确保顺序与之前内置的版本一致 }, })), ...openaiModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "azure", providerName: "Azure", providerType: "azure", + sorted: 2, }, })), ...googleModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "google", providerName: "Google", providerType: "google", + sorted: 3, }, })), ...anthropicModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "anthropic", providerName: "Anthropic", providerType: "anthropic", + sorted: 4, }, })), ...baiduModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "baidu", providerName: "Baidu", providerType: "baidu", + sorted: 5, }, })), ...bytedanceModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "bytedance", providerName: "ByteDance", providerType: "bytedance", + sorted: 6, }, })), ...alibabaModes.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "alibaba", providerName: "Alibaba", providerType: "alibaba", + sorted: 7, }, })), ...tencentModels.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "tencent", providerName: "Tencent", providerType: "tencent", + sorted: 8, }, })), ...moonshotModes.map((name) => ({ name, available: true, + sorted: seq++, provider: { id: "moonshot", providerName: "Moonshot", providerType: "moonshot", + sorted: 9, }, })), ] as const; diff --git a/app/utils/model.ts b/app/utils/model.ts index b117b5eb64a..0b62b53be09 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -1,32 +1,39 @@ import { DEFAULT_MODELS } from "../constant"; import { LLMModel } from "../client/api"; +const CustomSeq = { + val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts + cache: new Map(), + next: (id: string) => { + if (CustomSeq.cache.has(id)) { + return CustomSeq.cache.get(id) as number; + } else { + let seq = CustomSeq.val++; + CustomSeq.cache.set(id, seq); + return seq; + } + }, +}; + const customProvider = (providerName: string) => ({ id: providerName.toLowerCase(), providerName: providerName, providerType: "custom", + sorted: CustomSeq.next(providerName), }); -const sortModelTable = ( - models: ReturnType, - rule: "custom-first" | "default-first", -) => +/** + * Sorts an array of models based on specified rules. + * + * First, sorted by provider; if the same, sorted by model + */ +const sortModelTable = (models: ReturnType) => models.sort((a, b) => { - if (a.provider === undefined && b.provider === undefined) { - return 0; - } - - let aIsCustom = a.provider?.providerType === "custom"; - let bIsCustom = b.provider?.providerType === "custom"; - - if (aIsCustom === bIsCustom) { - return 0; - } - - if (aIsCustom) { - return rule === "custom-first" ? -1 : 1; + if (a.provider && b.provider) { + let cmp = a.provider.sorted - b.provider.sorted; + return cmp === 0 ? a.sorted - b.sorted : cmp; } else { - return rule === "custom-first" ? 1 : -1; + return a.sorted - b.sorted; } }); @@ -40,6 +47,7 @@ export function collectModelTable( available: boolean; name: string; displayName: string; + sorted: number; provider?: LLMModel["provider"]; // Marked as optional isDefault?: boolean; } @@ -107,6 +115,7 @@ export function collectModelTable( displayName: displayName || customModelName, available, provider, // Use optional chaining + sorted: CustomSeq.next(`${customModelName}@${provider?.id}`), }; } } @@ -151,7 +160,7 @@ export function collectModels( const modelTable = collectModelTable(models, customModels); let allModels = Object.values(modelTable); - allModels = sortModelTable(allModels, "custom-first"); + allModels = sortModelTable(allModels); return allModels; } @@ -168,7 +177,7 @@ export function collectModelsWithDefaultModel( ); let allModels = Object.values(modelTable); - allModels = sortModelTable(allModels, "custom-first"); + allModels = sortModelTable(allModels); return allModels; } From 3486954e073665b4bcaa4d41096b1341e4c497ff Mon Sep 17 00:00:00 2001 From: frostime Date: Mon, 5 Aug 2024 20:26:48 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=90=9B=20fix(openai):=20=E4=B8=8A?= =?UTF-8?q?=E6=AC=A1=20commit=20=E5=90=8E=20openai.ts=20=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E4=B8=AD=E5=87=BA=E7=8E=B0=E7=B1=BB=E5=9E=8B=E4=B8=8D=E5=8C=B9?= =?UTF-8?q?=E9=85=8D=E7=9A=84=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/client/platforms/openai.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 680125fe6c4..d95aebe87b2 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -411,13 +411,17 @@ export class ChatGPTApi implements LLMApi { return []; } + //由于目前 OpenAI 的 disableListModels 默认为 true,所以当前实际不会运行到这场 + let seq = 1000; //同 Constant.ts 中的排序保持一致 return chatModels.map((m) => ({ name: m.id, available: true, + sorted: seq++, provider: { id: "openai", providerName: "OpenAI", providerType: "openai", + sorted: 1, }, })); }