Skip to content

Commit

Permalink
Mistral embedding engine support (#2667)
Browse files Browse the repository at this point in the history
* add mistral embedding engine support

* remove console log + fix data handling onboarding

* update data handling description

---------

Co-authored-by: Timothy Carambat <[email protected]>
  • Loading branch information
shatfield4 and timothycarambat authored Nov 21, 2024
1 parent 246152c commit 9f38b93
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
export default function MistralAiOptions({ settings }) {
return (
<div className="w-full flex flex-col gap-y-4">
<div className="w-full flex items-center gap-[36px] mt-1.5">
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
API Key
</label>
<input
type="password"
name="MistralAiApiKey"
className="bg-theme-settings-input-bg text-white placeholder:text-theme-settings-input-placeholder text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
placeholder="Mistral AI API Key"
defaultValue={settings?.MistralApiKey ? "*".repeat(20) : ""}
required={true}
autoComplete="off"
spellCheck={false}
/>
</div>
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Model Preference
</label>
<select
name="EmbeddingModelPref"
required={true}
defaultValue={settings?.EmbeddingModelPref}
className="bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<optgroup label="Available embedding models">
{[
"mistral-embed",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
</select>
</div>
</div>
</div>
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import CohereLogo from "@/media/llmprovider/cohere.png";
import VoyageAiLogo from "@/media/embeddingprovider/voyageai.png";
import LiteLLMLogo from "@/media/llmprovider/litellm.png";
import GenericOpenAiLogo from "@/media/llmprovider/generic-openai.png";
import MistralAiLogo from "@/media/llmprovider/mistral.jpeg";

import PreLoader from "@/components/Preloader";
import ChangeWarningModal from "@/components/ChangeWarning";
Expand All @@ -33,6 +34,7 @@ import { useModal } from "@/hooks/useModal";
import ModalWrapper from "@/components/ModalWrapper";
import CTAButton from "@/components/lib/CTAButton";
import { useTranslation } from "react-i18next";
import MistralAiOptions from "@/components/EmbeddingSelection/MistralAiOptions";

const EMBEDDERS = [
{
Expand Down Expand Up @@ -100,6 +102,13 @@ const EMBEDDERS = [
options: (settings) => <LiteLLMOptions settings={settings} />,
description: "Run powerful embedding models from LiteLLM.",
},
{
name: "Mistral AI",
value: "mistral",
logo: MistralAiLogo,
options: (settings) => <MistralAiOptions settings={settings} />,
description: "Run powerful embedding models from Mistral AI.",
},
{
name: "Generic OpenAI",
value: "generic-openai",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,13 @@ export const EMBEDDING_ENGINE_PRIVACY = {
],
logo: VoyageAiLogo,
},
mistral: {
name: "Mistral AI",
description: [
"Data sent to Mistral AI's servers is shared according to the terms of service of https://mistral.ai.",
],
logo: MistralLogo,
},
litellm: {
name: "LiteLLM",
description: [
Expand Down
43 changes: 43 additions & 0 deletions server/utils/EmbeddingEngines/mistral/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
class MistralEmbedder {
constructor() {
if (!process.env.MISTRAL_API_KEY)
throw new Error("No Mistral API key was set.");

const { OpenAI: OpenAIApi } = require("openai");
this.openai = new OpenAIApi({
baseURL: "https://api.mistral.ai/v1",
apiKey: process.env.MISTRAL_API_KEY ?? null,
});
this.model = process.env.EMBEDDING_MODEL_PREF || "mistral-embed";
}

async embedTextInput(textInput) {
try {
const response = await this.openai.embeddings.create({
model: this.model,
input: textInput,
});
return response?.data[0]?.embedding || [];
} catch (error) {
console.error("Failed to get embedding from Mistral.", error.message);
return [];
}
}

async embedChunks(textChunks = []) {
try {
const response = await this.openai.embeddings.create({
model: this.model,
input: textChunks,
});
return response?.data?.map((emb) => emb.embedding) || [];
} catch (error) {
console.error("Failed to get embeddings from Mistral.", error.message);
return new Array(textChunks.length).fill([]);
}
}
}

module.exports = {
MistralEmbedder,
};
3 changes: 3 additions & 0 deletions server/utils/helpers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ function getEmbeddingEngineSelection() {
case "litellm":
const { LiteLLMEmbedder } = require("../EmbeddingEngines/liteLLM");
return new LiteLLMEmbedder();
case "mistral":
const { MistralEmbedder } = require("../EmbeddingEngines/mistral");
return new MistralEmbedder();
case "generic-openai":
const {
GenericOpenAiEmbedder,
Expand Down
1 change: 1 addition & 0 deletions server/utils/helpers/updateENV.js
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,7 @@ function supportedEmbeddingModel(input = "") {
"voyageai",
"litellm",
"generic-openai",
"mistral",
];
return supported.includes(input)
? null
Expand Down

0 comments on commit 9f38b93

Please sign in to comment.