fix: Temporary fix: o1-xx models need systemMessage converted to aiMessage.
Emt-lin committed Dec 2, 2024
1 parent 3fab1a9 commit 8324d1d
Showing 2 changed files with 58 additions and 18 deletions.
20 changes: 17 additions & 3 deletions src/LLMProviders/chainManager.ts
@@ -1,7 +1,7 @@
 import { VAULT_VECTOR_STORE_STRATEGY } from "@/constants";
 import { CustomModel, SetChainOptions, setChainType } from "@/aiParams";
 import ChainFactory, { ChainType, Document } from "@/chainFactory";
-import { BUILTIN_CHAT_MODELS, USER_SENDER } from "@/constants";
+import { AI_SENDER, BUILTIN_CHAT_MODELS, USER_SENDER } from "@/constants";
 import {
   ChainRunner,
   CopilotPlusChainRunner,
@@ -261,12 +261,26 @@ export default class ChainManager {
     this.validateChatModel();
     this.validateChainInitialization();
 
+    const chatModel = this.chatModelManager.getChatModel();
+    const modelName = (chatModel as any).modelName || (chatModel as any).model || "";
+    const isO1Model = modelName.startsWith("o1");
+
     // Handle ignoreSystemMessage
-    if (ignoreSystemMessage) {
-      const effectivePrompt = ChatPromptTemplate.fromMessages([
+    if (ignoreSystemMessage || isO1Model) {
+      let effectivePrompt = ChatPromptTemplate.fromMessages([
         new MessagesPlaceholder("history"),
         HumanMessagePromptTemplate.fromTemplate("{input}"),
       ]);
+
+      // TODO: hack for o1 models, to be removed when they support system prompts
+      if (isO1Model) {
+        // Temporary fix: o1-xx models need the systemMessage converted to an aiMessage
+        effectivePrompt = ChatPromptTemplate.fromMessages([
+          [AI_SENDER, this.getLangChainParams().systemMessage || ""],
+          effectivePrompt,
+        ]);
+      }
+
       this.setChain(getChainType(), {
         prompt: effectivePrompt,
       });
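This works around the o1 models' lack of system-prompt support at the time: the configured system prompt is prepended to the conversation as an assistant (AI) message instead of a system message. Below is a minimal standalone sketch of the same trick using @langchain/core; the AI_SENDER constant and buildO1Prompt helper are illustrative stand-ins, not the plugin's actual API.

import {
  ChatPromptTemplate,
  HumanMessagePromptTemplate,
  MessagesPlaceholder,
} from "@langchain/core/prompts";

// Assumption: mirrors the plugin's AI_SENDER constant (the "ai" role).
const AI_SENDER = "ai";

// Hypothetical helper: wrap a base prompt so the system text arrives
// as an assistant turn ahead of the chat history.
function buildO1Prompt(systemMessage: string): ChatPromptTemplate {
  const basePrompt = ChatPromptTemplate.fromMessages([
    new MessagesPlaceholder("history"),
    HumanMessagePromptTemplate.fromTemplate("{input}"),
  ]);
  // fromMessages flattens a nested ChatPromptTemplate, which is what the
  // commit relies on when it wraps effectivePrompt.
  return ChatPromptTemplate.fromMessages([
    [AI_SENDER, systemMessage],
    basePrompt,
  ]);
}

// Usage: format with an empty history to inspect the resulting roles.
buildO1Prompt("You are a helpful assistant.")
  .formatMessages({ history: [], input: "Hello" })
  .then((messages) => {
    for (const m of messages) {
      // Expected roles: "ai" then "human"; no system message is emitted.
      console.log(m._getType(), String(m.content));
    }
  });

One caveat of the tuple form: the string is parsed as a template, so literal curly braces in the system prompt would need escaping.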
56 changes: 41 additions & 15 deletions src/LLMProviders/chatModelManager.ts
@@ -68,9 +68,16 @@ export default class ChatModelManager {
   }
 
   private getModelConfig(customModel: CustomModel): ModelConfig {
     const decrypt = (key: string) => this.encryptionService.getDecryptedKey(key);
     const params = this.getLangChainParams();
 
     const settings = getSettings();
+
+    // Check if the model starts with "o1"
+    const modelName = customModel.name;
+    const isO1Model = modelName.startsWith("o1");
     const baseConfig: ModelConfig = {
-      modelName: customModel.name,
+      modelName: modelName,
       temperature: settings.temperature,
       streaming: true,
       maxRetries: 3,
@@ -82,19 +89,19 @@ export default class ChatModelManager {
       [K in keyof ChatProviderConstructMap]: ConstructorParameters<ChatProviderConstructMap[K]>[0];
     } = {
       [ChatModelProviders.OPENAI]: {
-        modelName: customModel.name,
+        modelName: modelName,
         openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openAIApiKey),
-        // @ts-ignore
-        openAIOrgId: getDecryptedKey(settings.openAIOrgId),
-        maxTokens: settings.maxTokens,
         configuration: {
           baseURL: customModel.baseUrl,
           fetch: customModel.enableCors ? safeFetch : undefined,
         },
+        // @ts-ignore
+        openAIOrgId: getDecryptedKey(settings.openAIOrgId),
+        ...this.handleOpenAIExtraArgs(isO1Model, settings.maxTokens, settings.temperature, true),
       },
       [ChatModelProviders.ANTHROPIC]: {
         anthropicApiKey: getDecryptedKey(customModel.apiKey || settings.anthropicApiKey),
-        modelName: customModel.name,
+        modelName: modelName,
         anthropicApiUrl: customModel.baseUrl,
         clientOptions: {
           // Required to bypass CORS restrictions
@@ -103,7 +110,6 @@ export default class ChatModelManager {
         },
       },
       [ChatModelProviders.AZURE_OPENAI]: {
-        maxTokens: settings.maxTokens,
         azureOpenAIApiKey: getDecryptedKey(customModel.apiKey || settings.azureOpenAIApiKey),
         azureOpenAIApiInstanceName: settings.azureOpenAIApiInstanceName,
         azureOpenAIApiDeploymentName: settings.azureOpenAIApiDeploymentName,
@@ -112,14 +118,15 @@ export default class ChatModelManager {
           baseURL: customModel.baseUrl,
           fetch: customModel.enableCors ? safeFetch : undefined,
         },
+        ...this.handleOpenAIExtraArgs(isO1Model, settings.maxTokens, settings.temperature, true),
       },
       [ChatModelProviders.COHEREAI]: {
         apiKey: getDecryptedKey(customModel.apiKey || settings.cohereApiKey),
-        model: customModel.name,
+        model: modelName,
       },
       [ChatModelProviders.GOOGLE]: {
         apiKey: getDecryptedKey(customModel.apiKey || settings.googleApiKey),
-        model: customModel.name,
+        modelName: modelName,
         safetySettings: [
           {
             category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
@@ -141,7 +148,7 @@ export default class ChatModelManager {
         baseUrl: customModel.baseUrl,
       },
       [ChatModelProviders.OPENROUTERAI]: {
-        modelName: customModel.name,
+        modelName: modelName,
         openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openRouterAiApiKey),
         configuration: {
           baseURL: customModel.baseUrl || "https://openrouter.ai/api/v1",
@@ -150,33 +157,33 @@ export default class ChatModelManager {
         },
       },
       [ChatModelProviders.GROQ]: {
         apiKey: getDecryptedKey(customModel.apiKey || settings.groqApiKey),
-        modelName: customModel.name,
+        modelName: modelName,
       },
       [ChatModelProviders.OLLAMA]: {
         // ChatOllama has `model` instead of `modelName`!!
-        model: customModel.name,
+        model: modelName,
         // @ts-ignore
         apiKey: customModel.apiKey || "default-key",
         // MUST NOT use /v1 in the baseUrl for ollama
         baseUrl: customModel.baseUrl || "http://localhost:11434",
       },
       [ChatModelProviders.LM_STUDIO]: {
-        modelName: customModel.name,
+        modelName: modelName,
         openAIApiKey: customModel.apiKey || "default-key",
         configuration: {
           baseURL: customModel.baseUrl || "http://localhost:1234/v1",
           fetch: customModel.enableCors ? safeFetch : undefined,
         },
       },
       [ChatModelProviders.OPENAI_FORMAT]: {
-        modelName: customModel.name,
+        modelName: modelName,
         openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openAIApiKey),
-        maxTokens: settings.maxTokens,
         configuration: {
           baseURL: customModel.baseUrl,
           fetch: customModel.enableCors ? safeFetch : undefined,
           dangerouslyAllowBrowser: true,
         },
+        ...this.handleOpenAIExtraArgs(isO1Model, settings.maxTokens, settings.temperature, true),
       },
     };

@@ -186,6 +193,25 @@ export default class ChatModelManager {
     return { ...baseConfig, ...selectedProviderConfig };
   }
 
+  private handleOpenAIExtraArgs(
+    isO1Model: boolean,
+    maxTokens: number,
+    temperature: number,
+    streaming: boolean
+  ) {
+    return isO1Model
+      ? {
+          maxCompletionTokens: maxTokens,
+          temperature: 1,
+          streaming: false,
+        }
+      : {
+          maxTokens: maxTokens,
+          temperature: temperature,
+          streaming: streaming,
+        };
+  }
+
   // Build a map of modelKey to model config
   public buildModelMap() {
     const activeModels = getSettings().activeModels;
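handleOpenAIExtraArgs exists because, as of this commit, the o1 endpoints accepted max_completion_tokens rather than max_tokens, only the default temperature of 1, and no streaming. Since the spread of this helper comes last in each provider config, its keys override the defaults declared earlier in the object literal. A small self-contained sketch of that override behavior; the openAIExtraArgs function and the config shape are illustrative, not the plugin's types:

// Illustrative stand-in for the private method above.
function openAIExtraArgs(
  isO1Model: boolean,
  maxTokens: number,
  temperature: number,
  streaming: boolean
) {
  return isO1Model
    ? { maxCompletionTokens: maxTokens, temperature: 1, streaming: false }
    : { maxTokens, temperature, streaming };
}

// Later spreads override earlier keys, so o1-safe values replace the
// defaults declared first in the object literal.
const modelName = "o1-mini";
const config = {
  modelName,
  temperature: 0.7,
  streaming: true,
  ...openAIExtraArgs(modelName.startsWith("o1"), 4096, 0.7, true),
};

console.log(config);
// { modelName: "o1-mini", temperature: 1, streaming: false, maxCompletionTokens: 4096 }

If the o1 models later gain system-prompt or streaming support, only this helper and the prompt hack in chainManager.ts need to change.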