Skip to content

Commit

Permalink
Temporary fix:for o1-xx model need to covert systemMessage to aiMessa…
Browse files Browse the repository at this point in the history
…ge. (#850)
  • Loading branch information
Emt-lin authored Dec 2, 2024
1 parent 961a35d commit a75ec96
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 25 deletions.
17 changes: 10 additions & 7 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"@langchain/core": "^0.3.3",
"@langchain/google-genai": "^0.1.2",
"@langchain/groq": "^0.1.2",
"@langchain/openai": "^0.3.14",
"@orama/orama": "^3.0.0-rc-2",
"@radix-ui/react-dropdown-menu": "^2.1.2",
"@radix-ui/react-tooltip": "^1.1.3",
Expand Down
20 changes: 17 additions & 3 deletions src/LLMProviders/chainManager.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { VAULT_VECTOR_STORE_STRATEGY } from "@/constants";
import { CustomModel, SetChainOptions, setChainType } from "@/aiParams";
import ChainFactory, { ChainType, Document } from "@/chainFactory";
import { BUILTIN_CHAT_MODELS, USER_SENDER } from "@/constants";
import { AI_SENDER, BUILTIN_CHAT_MODELS, USER_SENDER } from "@/constants";
import {
ChainRunner,
CopilotPlusChainRunner,
Expand Down Expand Up @@ -261,12 +261,26 @@ export default class ChainManager {
this.validateChatModel();
this.validateChainInitialization();

const chatModel = this.chatModelManager.getChatModel();
const modelName = (chatModel as any).modelName || (chatModel as any).model || "";
const isO1Model = modelName.startsWith("o1");

// Handle ignoreSystemMessage
if (ignoreSystemMessage) {
const effectivePrompt = ChatPromptTemplate.fromMessages([
if (ignoreSystemMessage || isO1Model) {
let effectivePrompt = ChatPromptTemplate.fromMessages([
new MessagesPlaceholder("history"),
HumanMessagePromptTemplate.fromTemplate("{input}"),
]);

// TODO: hack for o1 models, to be removed when they support system prompt
if (isO1Model) {
// Temporary fix:for o1-xx model need to covert systemMessage to aiMessage
effectivePrompt = ChatPromptTemplate.fromMessages([
[AI_SENDER, getSystemPrompt() || ""],
effectivePrompt,
]);
}

this.setChain(getChainType(), {
prompt: effectivePrompt,
});
Expand Down
46 changes: 31 additions & 15 deletions src/LLMProviders/chatModelManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@ export default class ChatModelManager {

private getModelConfig(customModel: CustomModel): ModelConfig {
const settings = getSettings();

// Check if the model starts with "o1"
const modelName = customModel.name;
const isO1Model = modelName.startsWith("o1");
const baseConfig: ModelConfig = {
modelName: customModel.name,
modelName: modelName,
temperature: settings.temperature,
streaming: true,
maxRetries: 3,
Expand All @@ -82,19 +86,19 @@ export default class ChatModelManager {
[K in keyof ChatProviderConstructMap]: ConstructorParameters<ChatProviderConstructMap[K]>[0];
} = {
[ChatModelProviders.OPENAI]: {
modelName: customModel.name,
modelName: modelName,
openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openAIApiKey),
// @ts-ignore
openAIOrgId: getDecryptedKey(settings.openAIOrgId),
maxTokens: settings.maxTokens,
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
},
// @ts-ignore
openAIOrgId: getDecryptedKey(settings.openAIOrgId),
...this.handleOpenAIExtraArgs(isO1Model, settings.maxTokens, settings.temperature),
},
[ChatModelProviders.ANTHROPIC]: {
anthropicApiKey: getDecryptedKey(customModel.apiKey || settings.anthropicApiKey),
modelName: customModel.name,
modelName: modelName,
anthropicApiUrl: customModel.baseUrl,
clientOptions: {
// Required to bypass CORS restrictions
Expand All @@ -103,7 +107,6 @@ export default class ChatModelManager {
},
},
[ChatModelProviders.AZURE_OPENAI]: {
maxTokens: settings.maxTokens,
azureOpenAIApiKey: getDecryptedKey(customModel.apiKey || settings.azureOpenAIApiKey),
azureOpenAIApiInstanceName: settings.azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName: settings.azureOpenAIApiDeploymentName,
Expand All @@ -112,14 +115,15 @@ export default class ChatModelManager {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
},
...this.handleOpenAIExtraArgs(isO1Model, settings.maxTokens, settings.temperature),
},
[ChatModelProviders.COHEREAI]: {
apiKey: getDecryptedKey(customModel.apiKey || settings.cohereApiKey),
model: customModel.name,
model: modelName,
},
[ChatModelProviders.GOOGLE]: {
apiKey: getDecryptedKey(customModel.apiKey || settings.googleApiKey),
model: customModel.name,
modelName: modelName,
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
Expand All @@ -141,7 +145,7 @@ export default class ChatModelManager {
baseUrl: customModel.baseUrl,
},
[ChatModelProviders.OPENROUTERAI]: {
modelName: customModel.name,
modelName: modelName,
openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openRouterAiApiKey),
configuration: {
baseURL: customModel.baseUrl || "https://openrouter.ai/api/v1",
Expand All @@ -150,33 +154,33 @@ export default class ChatModelManager {
},
[ChatModelProviders.GROQ]: {
apiKey: getDecryptedKey(customModel.apiKey || settings.groqApiKey),
modelName: customModel.name,
modelName: modelName,
},
[ChatModelProviders.OLLAMA]: {
// ChatOllama has `model` instead of `modelName`!!
model: customModel.name,
model: modelName,
// @ts-ignore
apiKey: customModel.apiKey || "default-key",
// MUST NOT use /v1 in the baseUrl for ollama
baseUrl: customModel.baseUrl || "http://localhost:11434",
},
[ChatModelProviders.LM_STUDIO]: {
modelName: customModel.name,
modelName: modelName,
openAIApiKey: customModel.apiKey || "default-key",
configuration: {
baseURL: customModel.baseUrl || "http://localhost:1234/v1",
fetch: customModel.enableCors ? safeFetch : undefined,
},
},
[ChatModelProviders.OPENAI_FORMAT]: {
modelName: customModel.name,
modelName: modelName,
openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openAIApiKey),
maxTokens: settings.maxTokens,
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
dangerouslyAllowBrowser: true,
},
...this.handleOpenAIExtraArgs(isO1Model, settings.maxTokens, settings.temperature),
},
};

Expand All @@ -186,6 +190,18 @@ export default class ChatModelManager {
return { ...baseConfig, ...selectedProviderConfig };
}

private handleOpenAIExtraArgs(isO1Model: boolean, maxTokens: number, temperature: number) {
return isO1Model
? {
maxCompletionTokens: maxTokens,
temperature: 1,
}
: {
maxTokens: maxTokens,
temperature: temperature,
};
}

// Build a map of modelKey to model config
public buildModelMap() {
const activeModels = getSettings().activeModels;
Expand Down

0 comments on commit a75ec96

Please sign in to comment.