Skip to content

Commit

Permalink
Fix embedding api keys
Browse files Browse the repository at this point in the history
  • Loading branch information
logancyang committed Feb 22, 2024
1 parent 6085d9d commit ee33cae
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/LLMProviders/chainManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export default class ChainManager {
private vectorStore: MemoryVectorStore;
private promptManager: PromptManager;
private embeddingsManager: EmbeddingsManager;
private encryptionService: EncryptionService;
public chatModelManager: ChatModelManager;
public langChainParams: LangChainParams;
public memoryManager: MemoryManager;
Expand All @@ -55,6 +56,7 @@ export default class ChainManager {
// Instantiate singletons
this.langChainParams = langChainParams;
this.memoryManager = MemoryManager.getInstance(this.langChainParams);
this.encryptionService = encryptionService;
this.chatModelManager = ChatModelManager.getInstance(this.langChainParams, encryptionService);
this.promptManager = PromptManager.getInstance(this.langChainParams);
this.createChainWithNewModel(this.langChainParams.modelDisplayName);
Expand Down Expand Up @@ -137,7 +139,7 @@ export default class ChainManager {
this.validateChainType(chainType);
// MUST set embeddingsManager when switching to QA mode
if (chainType === ChainType.RETRIEVAL_QA_CHAIN) {
this.embeddingsManager = EmbeddingsManager.getInstance(this.langChainParams);
this.embeddingsManager = EmbeddingsManager.getInstance(this.langChainParams, this.encryptionService);
}

// Get chatModel, memory, prompt, and embeddingAPI from respective managers
Expand Down
20 changes: 12 additions & 8 deletions src/LLMProviders/embeddingManager.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { LangChainParams } from '@/aiParams';
import { ModelProviders } from '@/constants';
import EncryptionService from '@/encryptionService';
import { ProxyOpenAIEmbeddings } from '@/langchainWrappers';
import { CohereEmbeddings } from "@langchain/cohere";
import { Embeddings } from "langchain/embeddings/base";
Expand All @@ -9,19 +10,22 @@ import { OpenAIEmbeddings } from "langchain/embeddings/openai";
export default class EmbeddingManager {
private static instance: EmbeddingManager;
private constructor(
private langChainParams: LangChainParams
private langChainParams: LangChainParams,
private encryptionService: EncryptionService,
) {}

static getInstance(
langChainParams: LangChainParams
langChainParams: LangChainParams,
encryptionService: EncryptionService,
): EmbeddingManager {
if (!EmbeddingManager.instance) {
EmbeddingManager.instance = new EmbeddingManager(langChainParams);
EmbeddingManager.instance = new EmbeddingManager(langChainParams, encryptionService);
}
return EmbeddingManager.instance;
}

getEmbeddingsAPI(): Embeddings | undefined {
const decrypt = (key: string) => this.encryptionService.getDecryptedKey(key);
const {
openAIApiKey,
azureOpenAIApiKey,
Expand All @@ -36,15 +40,15 @@ export default class EmbeddingManager {
openAIEmbeddingProxyBaseUrl ?
new ProxyOpenAIEmbeddings({
modelName: openAIEmbeddingProxyModelName || this.langChainParams.embeddingModel,
openAIApiKey,
openAIApiKey: decrypt(openAIApiKey),
maxRetries: 3,
maxConcurrency: 3,
timeout: 10000,
openAIEmbeddingProxyBaseUrl,
}) :
new OpenAIEmbeddings({
modelName: openAIEmbeddingProxyModelName || this.langChainParams.embeddingModel,
openAIApiKey,
openAIApiKey: decrypt(openAIApiKey),
maxRetries: 3,
maxConcurrency: 3,
timeout: 10000,
Expand All @@ -60,20 +64,20 @@ export default class EmbeddingManager {
break;
case ModelProviders.HUGGINGFACE:
return new HuggingFaceInferenceEmbeddings({
apiKey: this.langChainParams.huggingfaceApiKey,
apiKey: decrypt(this.langChainParams.huggingfaceApiKey),
maxRetries: 3,
maxConcurrency: 3,
});
case ModelProviders.COHEREAI:
return new CohereEmbeddings({
apiKey: this.langChainParams.cohereApiKey,
apiKey: decrypt(this.langChainParams.cohereApiKey),
maxRetries: 3,
maxConcurrency: 3,
});
case ModelProviders.AZURE_OPENAI:
if (azureOpenAIApiKey) {
return new OpenAIEmbeddings({
azureOpenAIApiKey,
azureOpenAIApiKey: decrypt(azureOpenAIApiKey),
azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName: azureOpenAIApiEmbeddingDeploymentName,
azureOpenAIApiVersion,
Expand Down

0 comments on commit ee33cae

Please sign in to comment.