mistral[minor]: Start Mistral 1.0.0 migration #6514

Open · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion libs/langchain-mistralai/package.json
@@ -36,7 +36,7 @@
"license": "MIT",
"dependencies": {
"@langchain/core": ">=0.2.21 <0.3.0",
"@mistralai/mistralai": "^0.4.0",
"@mistralai/mistralai": "^1.0.2",
"uuid": "^10.0.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.4"
153 changes: 76 additions & 77 deletions libs/langchain-mistralai/src/chat_models.ts
@@ -1,15 +1,27 @@
import { v4 as uuidv4 } from "uuid";
import {
- ChatCompletionResponse,
- Function as MistralAIFunction,
- ToolCalls as MistralAIToolCalls,
- ResponseFormat,
- ChatCompletionResponseChunk,
- ChatRequest,
- Tool as MistralAITool,
- Message as MistralAIMessage,
- TokenUsage as MistralAITokenUsage,
+ Mistral as MistralClient,
+ // MistralChatCompletionResponse,
+ // Function as MistralAIFunction,
+ // ToolCalls as MistralAIToolCalls,
+ // ResponseFormat,
+ // MistralChatCompletionResponseChunk,
+ // ChatRequest,
+ // Tool as MistralAITool,
+ // Message as MistralAIMessage,
+ // TokenUsage as MistralAITokenUsage,
} from "@mistralai/mistralai";
+ import {
+ ChatCompletionRequest as MistralAIChatCompletionRequest,
+ ToolChoice as MistralAIToolChoice,
+ Messages as MistralAIMessage,
+ } from "@mistralai/mistralai/models/components/chatcompletionrequest.js";
+ import { Tool as MistralAITool } from "@mistralai/mistralai/models/components/tool.js";
+ import { ToolCall as MistralAIToolCall } from "@mistralai/mistralai/models/components/toolcall.js";
+ import { ChatCompletionStreamRequest as MistralChatCompletionStreamRequest } from "@mistralai/mistralai/models/components/chatcompletionstreamrequest.js";
+ import { UsageInfo as MistralAITokenUsage } from "@mistralai/mistralai/models/components/usageinfo.js";
+ import { CompletionEvent as MistralAIChatCompletionEvent } from "@mistralai/mistralai/models/components/completionevent.js";
+ import { ChatCompletionResponse as MistralChatCompletionResponse } from "@mistralai/mistralai/models/components/chatcompletionresponse.js";
import {
MessageType,
type BaseMessage,
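For orientation, a minimal sketch of the v1 surface this import hunk migrates to, assuming @mistralai/mistralai@1.0.2 as pinned in the lockfile below: the client is now the named `Mistral` class (v0.4 used a default export), and request/response types move to deep `models/components/*.js` subpaths. The model name is illustrative.

```ts
import { Mistral } from "@mistralai/mistralai";
import type { ChatCompletionRequest } from "@mistralai/mistralai/models/components/chatcompletionrequest.js";

// The v0.4 default-export client is gone; v1 exposes a named `Mistral` class.
const client = new Mistral({ apiKey: process.env.MISTRAL_API_KEY });

// Request types now come from subpath modules rather than the package root.
const request: ChatCompletionRequest = {
  model: "mistral-small-latest", // example model name
  messages: [{ role: "user", content: "Hello" }],
};
```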
@@ -73,14 +85,7 @@ interface TokenUsage {
totalTokens?: number;
}

- export type MistralAIToolChoice = "auto" | "any" | "none";
-
- type MistralAIToolInput = { type: string; function: MistralAIFunction };
-
- type ChatMistralAIToolType =
- | MistralAIToolInput
- | MistralAITool
- | BindToolsInput;
+ type ChatMistralAIToolType = MistralAIToolCall | MistralAITool | BindToolsInput;

export interface ChatMistralAICallOptions
extends Omit<BaseLanguageModelCallOptions, "stop"> {
@@ -200,14 +205,14 @@ function convertMessagesToMistralMessages(
);
};

- const getTools = (message: BaseMessage): MistralAIToolCalls[] | undefined => {
+ const getTools = (message: BaseMessage): MistralAIToolCall[] | undefined => {
if (isAIMessage(message) && !!message.tool_calls?.length) {
return message.tool_calls
.map((toolCall) => ({
...toolCall,
id: _convertToolCallIdToMistralCompatible(toolCall.id ?? ""),
}))
- .map(convertLangChainToolCallToOpenAI) as MistralAIToolCalls[];
+ .map(convertLangChainToolCallToOpenAI) as MistralAIToolCall[];
}
if (!message.additional_kwargs.tool_calls?.length) {
return undefined;
@@ -244,16 +249,19 @@ function convertMessagesToMistralMessages(
}

function mistralAIResponseToChatMessage(
- choice: ChatCompletionResponse["choices"][0],
+ choice: NonNullable<MistralChatCompletionResponse["choices"]>[0],
usage?: MistralAITokenUsage
): BaseMessage {
const { message } = choice;
+ if (message === undefined) {
+ throw new Error("No message found in response");
+ }
// MistralAI SDK does not include tool_calls in the non
// streaming return type, so we need to extract it like this
// to satisfy typescript.
- let rawToolCalls: MistralAIToolCalls[] = [];
+ let rawToolCalls: MistralAIToolCall[] = [];
if ("tool_calls" in message && Array.isArray(message.tool_calls)) {
- rawToolCalls = message.tool_calls as MistralAIToolCalls[];
+ rawToolCalls = message.tool_calls as MistralAIToolCall[];
}
switch (message.role) {
case "assistant": {
@@ -275,19 +283,12 @@ function mistralAIResponseToChatMessage(
content: message.content ?? "",
tool_calls: toolCalls,
invalid_tool_calls: invalidToolCalls,
- additional_kwargs: {
- tool_calls: rawToolCalls.length
- ? rawToolCalls.map((toolCall) => ({
- ...toolCall,
- type: "function",
- }))
- : undefined,
- },
+ additional_kwargs: {},
usage_metadata: usage
? {
- input_tokens: usage.prompt_tokens,
- output_tokens: usage.completion_tokens,
- total_tokens: usage.total_tokens,
+ input_tokens: usage.promptTokens,
+ output_tokens: usage.completionTokens,
+ total_tokens: usage.totalTokens,
}
: undefined,
});
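The camelCase rename above recurs throughout this file: v1's `UsageInfo` exposes `promptTokens`/`completionTokens`/`totalTokens` where v0.4's `TokenUsage` used snake_case. A hedged sketch of the mapping onto LangChain's `usage_metadata`, using the type imported in this PR:

```ts
import type { UsageInfo } from "@mistralai/mistralai/models/components/usageinfo.js";

// Maps v1 camelCase usage fields onto LangChain's snake_case usage_metadata.
function toUsageMetadata(usage: UsageInfo) {
  return {
    input_tokens: usage.promptTokens,
    output_tokens: usage.completionTokens,
    total_tokens: usage.totalTokens,
  };
}
```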
@@ -301,7 +302,7 @@ function _convertDeltaToMessageChunk(
delta: {
role?: string | undefined;
content?: string | undefined;
- tool_calls?: MistralAIToolCalls[] | undefined;
+ tool_calls?: MistralAIToolCall[] | undefined;
},
usage?: MistralAITokenUsage | null
) {
@@ -311,9 +312,9 @@ function _convertDeltaToMessageChunk(
content: "",
usage_metadata: usage
? {
- input_tokens: usage.prompt_tokens,
- output_tokens: usage.completion_tokens,
- total_tokens: usage.total_tokens,
+ input_tokens: usage.promptTokens,
+ output_tokens: usage.completionTokens,
+ total_tokens: usage.totalTokens,
}
: undefined,
});
@@ -325,7 +326,7 @@
// need to insert it here.
const rawToolCallChunksWithIndex = delta.tool_calls?.length
? delta.tool_calls?.map(
- (toolCall, index): OpenAIToolCall => ({
+ (toolCall, index): MistralAIToolCall & { index: number } => ({
...toolCall,
index,
id: toolCall.id ?? uuidv4().replace(/-/g, ""),
@@ -342,13 +343,15 @@
let additional_kwargs;
const toolCallChunks: ToolCallChunk[] = [];
if (rawToolCallChunksWithIndex !== undefined) {
- additional_kwargs = {
- tool_calls: rawToolCallChunksWithIndex,
- };
for (const rawToolCallChunk of rawToolCallChunksWithIndex) {
+ const rawArgs = rawToolCallChunk.function?.arguments;
+ const args =
+ rawArgs === undefined || typeof rawArgs === "string"
+ ? rawArgs
+ : JSON.stringify(rawArgs);
toolCallChunks.push({
name: rawToolCallChunk.function?.name,
- args: rawToolCallChunk.function?.arguments,
+ args,
id: rawToolCallChunk.id,
index: rawToolCallChunk.index,
type: "tool_call_chunk",
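The `rawArgs` handling added above is needed because, under this migration's assumptions, v1 can deliver a streamed tool call's `function.arguments` as either a JSON string or an already-parsed object, while LangChain's `ToolCallChunk.args` must be a string. The same logic as a standalone sketch:

```ts
// Normalizes v1 tool-call arguments (string | object | undefined) to the
// string form that ToolCallChunk.args expects.
function normalizeToolCallArgs(
  rawArgs: string | Record<string, unknown> | undefined
): string | undefined {
  return rawArgs === undefined || typeof rawArgs === "string"
    ? rawArgs
    : JSON.stringify(rawArgs);
}
```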
@@ -367,9 +370,9 @@
additional_kwargs,
usage_metadata: usage
? {
- input_tokens: usage.prompt_tokens,
- output_tokens: usage.completion_tokens,
- total_tokens: usage.total_tokens,
+ input_tokens: usage.promptTokens,
+ output_tokens: usage.completionTokens,
+ total_tokens: usage.totalTokens,
}
: undefined,
});
@@ -791,7 +794,6 @@ export class ChatMistralAI<
this.temperature = fields?.temperature ?? this.temperature;
this.topP = fields?.topP ?? this.topP;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
- this.safeMode = fields?.safeMode ?? this.safeMode;
this.safePrompt = fields?.safePrompt ?? this.safePrompt;
this.randomSeed = fields?.seed ?? fields?.randomSeed ?? this.seed;
this.seed = this.randomSeed;
@@ -820,22 +822,21 @@
*/
invocationParams(
options?: this["ParsedCallOptions"]
- ): Omit<ChatRequest, "messages"> {
+ ): Omit<MistralAIChatCompletionRequest, "messages"> {
const { response_format, tools, tool_choice } = options ?? {};
const mistralAITools: Array<MistralAITool> | undefined = tools?.length
? _convertToolToMistralTool(tools)
: undefined;
- const params: Omit<ChatRequest, "messages"> = {
+ const params: Omit<MistralAIChatCompletionRequest, "messages"> = {
model: this.model,
tools: mistralAITools,
temperature: this.temperature,
maxTokens: this.maxTokens,
topP: this.topP,
randomSeed: this.seed,
- safeMode: this.safeMode,
safePrompt: this.safePrompt,
toolChoice: tool_choice,
- responseFormat: response_format as ResponseFormat,
+ responseFormat: response_format,
};
return params;
}
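For reference, a hedged sketch of the shape `invocationParams` now returns (`Omit<ChatCompletionRequest, "messages">` from the v1 SDK). The camelCase keys mirror the SDK types imported above; `safeMode` no longer exists in v1, and the concrete values here are illustrative only:

```ts
import type { ChatCompletionRequest } from "@mistralai/mistralai/models/components/chatcompletionrequest.js";

const exampleParams: Omit<ChatCompletionRequest, "messages"> = {
  model: "mistral-small-latest", // example model name
  temperature: 0.7,
  maxTokens: 1024,
  topP: 1,
  randomSeed: 42,
  safePrompt: false, // safeMode was dropped in v1; only safePrompt remains
  toolChoice: "auto", // assumed still a plain "auto" | "any" | "none" union in 1.0.2
};
```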
@@ -856,38 +857,45 @@ export class ChatMistralAI<
* @returns {Promise<MistralAIChatCompletionResult | AsyncGenerator<MistralAIChatCompletionResult>>} The response from the MistralAI API.
*/
async completionWithRetry(
- input: ChatRequest,
+ input: MistralChatCompletionStreamRequest,
streaming: true
- ): Promise<AsyncGenerator<ChatCompletionResponseChunk>>;
+ ): Promise<AsyncIterable<MistralAIChatCompletionEvent>>;

async completionWithRetry(
- input: ChatRequest,
+ input: MistralAIChatCompletionRequest,
streaming: false
- ): Promise<ChatCompletionResponse>;
+ ): Promise<MistralChatCompletionResponse>;

async completionWithRetry(
- input: ChatRequest,
+ input: MistralAIChatCompletionRequest | MistralChatCompletionStreamRequest,
streaming: boolean
): Promise<
- ChatCompletionResponse | AsyncGenerator<ChatCompletionResponseChunk>
+ MistralChatCompletionResponse | AsyncIterable<MistralAIChatCompletionEvent>
> {
- const { MistralClient } = await this.imports();
- const client = new MistralClient(this.apiKey, this.endpoint);
+ const client = new MistralClient({
+ apiKey: this.apiKey,
+ serverURL: this.endpoint,
+ });

return this.caller.call(async () => {
try {
let res:
- | ChatCompletionResponse
- | AsyncGenerator<ChatCompletionResponseChunk>;
+ | MistralChatCompletionResponse
+ | AsyncIterable<MistralAIChatCompletionEvent>;
if (streaming) {
- res = client.chatStream(input);
+ res = await client.chat.stream(input);
} else {
- res = await client.chat(input);
+ res = await client.chat.complete(input);
}
return res;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
- if (e.message?.includes("status: 400")) {
+ console.log(e, e.status, e.code, e.statusCode, e.message);
+ if (
+ e.message?.includes("status: 400") ||
+ e.message?.toLowerCase().includes("status 400") ||
+ e.message?.includes("validation failed")
+ ) {
e.status = 400;
}
throw e;
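The central API change in `completionWithRetry`, restated as a self-contained sketch assuming @mistralai/mistralai@1.0.2: the client takes an options object, and completions move from the top-level `chat`/`chatStream` methods to the `client.chat` namespace. The model name is illustrative.

```ts
import { Mistral } from "@mistralai/mistralai";

const client = new Mistral({
  apiKey: process.env.MISTRAL_API_KEY,
  serverURL: "https://api.mistral.ai", // optional override, shown for illustration
});

// v0.4: client.chat(request)        -> v1: client.chat.complete(request)
// v0.4: client.chatStream(request)  -> v1: client.chat.stream(request)
const response = await client.chat.complete({
  model: "mistral-small-latest", // example model name
  messages: [{ role: "user", content: "Hi" }],
});
```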
@@ -936,11 +944,8 @@ export class ChatMistralAI<
// Not streaming, so we can just call the API once.
const response = await this.completionWithRetry(input, false);

- const {
- completion_tokens: completionTokens,
- prompt_tokens: promptTokens,
- total_tokens: totalTokens,
- } = response?.usage ?? {};
+ const { completionTokens, promptTokens, totalTokens } =
+ response?.usage ?? {};

if (completionTokens) {
tokenUsage.completionTokens =
@@ -968,8 +973,8 @@
text,
message: mistralAIResponseToChatMessage(part, response?.usage),
};
- if (part.finish_reason) {
- generation.generationInfo = { finish_reason: part.finish_reason };
+ if (part.finishReason) {
+ generation.generationInfo = { finishReason: part.finishReason };
}
generations.push(generation);
}
@@ -992,7 +997,7 @@
};

const streamIterable = await this.completionWithRetry(input, true);
- for await (const data of streamIterable) {
+ for await (const { data } of streamIterable) {
if (options.signal?.aborted) {
throw new Error("AbortError");
}
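The destructuring change above reflects v1's new streaming contract: `chat.stream()` yields `CompletionEvent` wrappers, so each chunk's payload sits under a `data` property rather than being the yielded value itself. A hedged sketch:

```ts
import { Mistral } from "@mistralai/mistralai";

const client = new Mistral({ apiKey: process.env.MISTRAL_API_KEY });
const stream = await client.chat.stream({
  model: "mistral-small-latest", // example model name
  messages: [{ role: "user", content: "Hi" }],
});

for await (const event of stream) {
  // event.data carries the chunk; in v0.4 the chunk was yielded directly.
  const content = event.data.choices?.[0]?.delta?.content;
  if (typeof content === "string") process.stdout.write(content);
}
```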
@@ -1191,12 +1196,6 @@ export class ChatMistralAI<
parsedWithFallback,
]);
}
-
- /** @ignore */
- private async imports() {
- const { default: MistralClient } = await import("@mistralai/mistralai");
- return { MistralClient };
- }
}

function isZodSchema<
11 changes: 10 additions & 1 deletion yarn.lock
@@ -12078,7 +12078,7 @@ __metadata:
"@langchain/core": ">=0.2.21 <0.3.0"
"@langchain/scripts": ~0.0.20
"@langchain/standard-tests": 0.0.0
"@mistralai/mistralai": ^0.4.0
"@mistralai/mistralai": ^1.0.2
"@swc/core": ^1.3.90
"@swc/jest": ^0.2.29
"@tsconfig/recommended": ^1.0.3
@@ -12733,6 +12733,15 @@
languageName: node
linkType: hard

"@mistralai/mistralai@npm:^1.0.2":
version: 1.0.2
resolution: "@mistralai/mistralai@npm:1.0.2"
peerDependencies:
zod: ">= 3"
checksum: 9d2ed8d96d20791571cf9ad2a47c4fb2a8a991543ce95a7daeba8f146cca73cac8e05b9da1e96b16d84a08c3c9bcea90c3e4909ce55dd366ea4931b2e18a28bb
languageName: node
linkType: hard

"@mixedbread-ai/sdk@npm:^2.2.3":
version: 2.2.6
resolution: "@mixedbread-ai/sdk@npm:2.2.6"