diff --git a/package-lock.json b/package-lock.json index 72a9c632..5e3d6d6a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -27,6 +27,7 @@ "crypto-js": "^4.1.1", "esbuild-plugin-svg": "^0.1.0", "eventsource-parser": "^1.0.0", + "jotai": "^2.10.3", "koa": "^2.14.2", "koa-proxies": "^0.12.3", "langchain": "^0.3.2", @@ -11340,6 +11341,26 @@ "url": "https://github.com/chalk/supports-color?sponsor=1" } }, + "node_modules/jotai": { + "version": "2.10.3", + "resolved": "https://registry.npmjs.org/jotai/-/jotai-2.10.3.tgz", + "integrity": "sha512-Nnf4IwrLhNfuz2JOQLI0V/AgwcpxvVy8Ec8PidIIDeRi4KCFpwTFIpHAAcU+yCgnw/oASYElq9UY0YdUUegsSA==", + "engines": { + "node": ">=12.20.0" + }, + "peerDependencies": { + "@types/react": ">=17.0.0", + "react": ">=17.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "react": { + "optional": true + } + } + }, "node_modules/js-base64": { "version": "3.7.2", "resolved": "https://registry.npmjs.org/js-base64/-/js-base64-3.7.2.tgz", diff --git a/package.json b/package.json index 0669a092..25b5f8a9 100644 --- a/package.json +++ b/package.json @@ -79,6 +79,7 @@ "crypto-js": "^4.1.1", "esbuild-plugin-svg": "^0.1.0", "eventsource-parser": "^1.0.0", + "jotai": "^2.10.3", "koa": "^2.14.2", "koa-proxies": "^0.12.3", "langchain": "^0.3.2", diff --git a/src/LLMProviders/brevilabsClient.ts b/src/LLMProviders/brevilabsClient.ts index 081c9a48..e7baab2d 100644 --- a/src/LLMProviders/brevilabsClient.ts +++ b/src/LLMProviders/brevilabsClient.ts @@ -1,6 +1,7 @@ import { BREVILABS_API_BASE_URL } from "@/constants"; import { Notice } from "obsidian"; - +import { getSettings } from "@/settings/model"; +import { getDecryptedKey } from "@/encryptionService"; export interface BrocaResponse { response: { tool_calls: Array<{ @@ -59,23 +60,16 @@ export interface Youtube4llmResponse { export class BrevilabsClient { private static instance: BrevilabsClient; - private licenseKey: string; - private options: any; - - private constructor(licenseKey: string, options?: { debug?: boolean }) { - this.licenseKey = licenseKey; - this.options = options; - } - static getInstance(licenseKey: string, options?: { debug?: boolean }): BrevilabsClient { + static getInstance(): BrevilabsClient { if (!BrevilabsClient.instance) { - BrevilabsClient.instance = new BrevilabsClient(licenseKey, options); + BrevilabsClient.instance = new BrevilabsClient(); } return BrevilabsClient.instance; } private checkLicenseKey() { - if (!this.licenseKey) { + if (!getSettings().plusLicenseKey) { new Notice( "Copilot Plus license key not found. Please enter your license key in the settings." ); @@ -98,13 +92,13 @@ export class BrevilabsClient { method, headers: { "Content-Type": "application/json", - Authorization: `Bearer ${this.licenseKey}`, + Authorization: `Bearer ${getDecryptedKey(getSettings().plusLicenseKey)}`, }, ...(method === "POST" && { body: JSON.stringify(body) }), }); const data = await response.json(); - if (this.options?.debug) { + if (getSettings().debug) { console.log(`==== ${endpoint} request ====:`, data); } diff --git a/src/LLMProviders/chainManager.ts b/src/LLMProviders/chainManager.ts index 2f806df8..65cd8478 100644 --- a/src/LLMProviders/chainManager.ts +++ b/src/LLMProviders/chainManager.ts @@ -1,7 +1,7 @@ -import { CustomModel, LangChainParams, SetChainOptions } from "@/aiParams"; +import { VAULT_VECTOR_STORE_STRATEGY } from "@/constants"; +import { CustomModel, SetChainOptions, setChainType } from "@/aiParams"; import ChainFactory, { ChainType, Document } from "@/chainFactory"; import { BUILTIN_CHAT_MODELS, USER_SENDER } from "@/constants"; -import EncryptionService from "@/encryptionService"; import { ChainRunner, CopilotPlusChainRunner, @@ -9,7 +9,6 @@ import { VaultQAChainRunner, } from "@/LLMProviders/chainRunner"; import { HybridRetriever } from "@/search/hybridRetriever"; -import { CopilotSettings } from "@/settings/SettingsPage"; import { ChatMessage } from "@/sharedState"; import { isSupportedChain } from "@/utils"; import VectorStoreManager from "@/VectorStoreManager"; @@ -25,15 +24,18 @@ import ChatModelManager from "./chatModelManager"; import EmbeddingsManager from "./embeddingManager"; import MemoryManager from "./memoryManager"; import PromptManager from "./promptManager"; +import { + getModelKey, + getChainType, + subscribeToModelKeyChange, + subscribeToChainTypeChange, +} from "@/aiParams"; +import { getSettings, subscribeToSettingsChange } from "@/settings/model"; export default class ChainManager { private static chain: RunnableSequence; private static retrievalChain: RunnableSequence; - private settings: CopilotSettings; - private encryptionService: EncryptionService; - private langChainParams: LangChainParams; - public app: App; public vectorStoreManager: VectorStoreManager; public chatModelManager: ChatModelManager; @@ -43,41 +45,26 @@ export default class ChainManager { public brevilabsClient: BrevilabsClient; public static retrievedDocuments: Document[] = []; - constructor( - app: App, - getLangChainParams: () => LangChainParams, - encryptionService: EncryptionService, - settings: CopilotSettings, - vectorStoreManager: VectorStoreManager, - brevilabsClient: BrevilabsClient - ) { + constructor(app: App, vectorStoreManager: VectorStoreManager, brevilabsClient: BrevilabsClient) { // Instantiate singletons this.app = app; - this.langChainParams = getLangChainParams(); - this.settings = settings; this.vectorStoreManager = vectorStoreManager; - this.memoryManager = MemoryManager.getInstance(this.getLangChainParams(), settings.debug); - this.encryptionService = encryptionService; - this.chatModelManager = ChatModelManager.getInstance( - () => this.getLangChainParams(), - encryptionService, - this.settings.activeModels - ); + this.memoryManager = MemoryManager.getInstance(); + this.chatModelManager = ChatModelManager.getInstance(); this.embeddingsManager = this.vectorStoreManager.getEmbeddingsManager(); - this.promptManager = PromptManager.getInstance(this.getLangChainParams()); + this.promptManager = PromptManager.getInstance(); this.brevilabsClient = brevilabsClient; - this.createChainWithNewModel(this.getLangChainParams().modelKey); - } - - public getLangChainParams(): LangChainParams { - return this.langChainParams; - } - - public setLangChainParam( - key: K, - value: LangChainParams[K] - ): void { - this.langChainParams[key] = value; + this.createChainWithNewModel(); + subscribeToModelKeyChange(() => this.createChainWithNewModel()); + subscribeToChainTypeChange(() => + this.setChain(getChainType(), { + refreshIndex: + getSettings().indexVaultToVectorStore === VAULT_VECTOR_STORE_STRATEGY.ON_MODE_SWITCH && + (getChainType() === ChainType.VAULT_QA_CHAIN || + getChainType() === ChainType.COPILOT_PLUS_CHAIN), + }) + ); + subscribeToSettingsChange(() => this.createChainWithNewModel()); } static getChain(): RunnableSequence { @@ -103,17 +90,14 @@ export default class ChainManager { private validateChainInitialization() { if (!ChainManager.chain || !isSupportedChain(ChainManager.chain)) { - console.error( - "Chain is not initialized properly, re-initializing chain: ", - this.getLangChainParams().chainType - ); - this.setChain(this.getLangChainParams().chainType, this.getLangChainParams().options); + console.error("Chain is not initialized properly, re-initializing chain: ", getChainType()); + this.setChain(getChainType()); } } private findCustomModel(modelKey: string): CustomModel | undefined { const [name, provider] = modelKey.split("|"); - return this.settings.activeModels.find( + return getSettings().activeModels.find( (model) => model.name === name && model.provider === provider ); } @@ -123,13 +107,11 @@ export default class ChainManager { } /** - * Update the active model and create a new chain - * with the specified model name. - * - * @param {string} newModel - the name of the new model in the dropdown - * @return {void} + * Update the active model and create a new chain with the specified model + * name. */ - createChainWithNewModel(newModelKey: string): void { + createChainWithNewModel(): void { + let newModelKey = getModelKey(); try { let customModel = this.findCustomModel(newModelKey); if (!customModel) { @@ -138,30 +120,15 @@ export default class ChainManager { customModel = BUILTIN_CHAT_MODELS[0]; newModelKey = customModel.name + "|" + customModel.provider; } - this.setLangChainParam("modelKey", newModelKey); this.chatModelManager.setChatModel(customModel); // Must update the chatModel for chain because ChainFactory always // retrieves the old chain without the chatModel change if it exists! // Create a new chain with the new chatModel - this.createChain(this.getLangChainParams().chainType, { - ...this.getLangChainParams().options, - forceNewCreation: true, - }); + this.setChain(getChainType()); console.log(`Setting model to ${newModelKey}`); } catch (error) { console.error("createChainWithNewModel failed: ", error); - console.log("modelKey:", this.getLangChainParams().modelKey); - } - } - - /* Create a new chain, or update chain with new model */ - createChain(chainType: ChainType, options?: SetChainOptions): void { - this.validateChainType(chainType); - try { - this.setChain(chainType, options); - } catch (error) { - new Notice("Error creating chain:", error); - console.error("Error creating chain:", error); + console.log("modelKey:", newModelKey); } } @@ -174,10 +141,7 @@ export default class ChainManager { this.validateChainType(chainType); // Handle index refresh if needed - if ( - options.refreshIndex && - (chainType === ChainType.VAULT_QA_CHAIN || chainType === ChainType.COPILOT_PLUS_CHAIN) - ) { + if (options.refreshIndex) { await this.vectorStoreManager.indexVaultToVectorStore(); } @@ -188,25 +152,14 @@ export default class ChainManager { switch (chainType) { case ChainType.LLM_CHAIN: { - // For initial load of the plugin - if (options.forceNewCreation) { - ChainManager.chain = ChainFactory.createNewLLMChain({ - llm: chatModel, - memory: memory, - prompt: options.prompt || chatPrompt, - abortController: options.abortController, - }) as RunnableSequence; - } else { - // For navigating back to the plugin view - ChainManager.chain = ChainFactory.getLLMChainFromMap({ - llm: chatModel, - memory: memory, - prompt: options.prompt || chatPrompt, - abortController: options.abortController, - }) as RunnableSequence; - } - - this.setLangChainParam("chainType", ChainType.LLM_CHAIN); + ChainManager.chain = ChainFactory.createNewLLMChain({ + llm: chatModel, + memory: memory, + prompt: options.prompt || chatPrompt, + abortController: options.abortController, + }) as RunnableSequence; + + setChainType(ChainType.LLM_CHAIN); break; } @@ -231,10 +184,10 @@ export default class ChainManager { this.brevilabsClient, { minSimilarityScore: 0.01, - maxK: this.settings.maxSourceChunks, + maxK: getSettings().maxSourceChunks, salientTerms: [], }, - options.debug + getSettings().debug ); // Create new conversational retrieval chain @@ -242,14 +195,14 @@ export default class ChainManager { { llm: chatModel, retriever: retriever, - systemMessage: this.getLangChainParams().systemMessage, + systemMessage: getSettings().userSystemPrompt, }, ChainManager.storeRetrieverDocuments.bind(ChainManager), - options.debug + getSettings().debug ); - this.setLangChainParam("chainType", ChainType.VAULT_QA_CHAIN); - if (options.debug) { + setChainType(ChainType.VAULT_QA_CHAIN); + if (getSettings().debug) { console.log("New Vault QA chain with hybrid retriever created for entire vault"); console.log("Set chain:", ChainType.VAULT_QA_CHAIN); } @@ -259,24 +212,14 @@ export default class ChainManager { case ChainType.COPILOT_PLUS_CHAIN: { // TODO: Create new copilotPlusChain with retriever // For initial load of the plugin - if (options.forceNewCreation) { - ChainManager.chain = ChainFactory.createNewLLMChain({ - llm: chatModel, - memory: memory, - prompt: options.prompt || chatPrompt, - abortController: options.abortController, - }) as RunnableSequence; - } else { - // For navigating back to the plugin view - ChainManager.chain = ChainFactory.getLLMChainFromMap({ - llm: chatModel, - memory: memory, - prompt: options.prompt || chatPrompt, - abortController: options.abortController, - }) as RunnableSequence; - } - - this.setLangChainParam("chainType", ChainType.COPILOT_PLUS_CHAIN); + ChainManager.chain = ChainFactory.createNewLLMChain({ + llm: chatModel, + memory: memory, + prompt: options.prompt || chatPrompt, + abortController: options.abortController, + }) as RunnableSequence; + + setChainType(ChainType.COPILOT_PLUS_CHAIN); break; } @@ -287,7 +230,8 @@ export default class ChainManager { } private getChainRunner(): ChainRunner { - switch (this.getLangChainParams().chainType) { + const chainType = getChainType(); + switch (chainType) { case ChainType.LLM_CHAIN: return new LLMChainRunner(this); case ChainType.VAULT_QA_CHAIN: @@ -295,7 +239,7 @@ export default class ChainManager { case ChainType.COPILOT_PLUS_CHAIN: return new CopilotPlusChainRunner(this); default: - throw new Error(`Unsupported chain type: ${this.getLangChainParams().chainType}`); + throw new Error(`Unsupported chain type: ${chainType}`); } } @@ -323,8 +267,7 @@ export default class ChainManager { new MessagesPlaceholder("history"), HumanMessagePromptTemplate.fromTemplate("{input}"), ]); - this.setChain(this.getLangChainParams().chainType, { - ...this.getLangChainParams().options, + this.setChain(getChainType(), { prompt: effectivePrompt, }); } diff --git a/src/LLMProviders/chainRunner.ts b/src/LLMProviders/chainRunner.ts index a4b81bab..e5b049e9 100644 --- a/src/LLMProviders/chainRunner.ts +++ b/src/LLMProviders/chainRunner.ts @@ -10,6 +10,7 @@ import { import { Notice } from "obsidian"; import ChainManager from "./chainManager"; import { COPILOT_TOOL_NAMES, IntentAnalyzer } from "./intentAnalyzer"; +import { getSettings } from "@/settings/model"; export interface ChainRunner { run( @@ -224,7 +225,7 @@ class CopilotPlusChainRunner extends BaseChainRunner { const messages: any[] = []; // Add system message if available - const systemMessage = this.chainManager.getLangChainParams().systemMessage; + const systemMessage = getSettings().userSystemPrompt; let fullSystemMessage = systemMessage || ""; // Add chat history context to system message if exists @@ -394,7 +395,7 @@ class CopilotPlusChainRunner extends BaseChainRunner { const qaPrompt = await this.chainManager.promptManager.getQAPrompt({ question: standaloneQuestion, context: context, - systemMessage: this.chainManager.getLangChainParams().systemMessage, + systemMessage: getSettings().userSystemPrompt, }); fullAIResponse = await this.streamMultimodalResponse( diff --git a/src/LLMProviders/chatModelManager.ts b/src/LLMProviders/chatModelManager.ts index 991043da..a611d7dd 100644 --- a/src/LLMProviders/chatModelManager.ts +++ b/src/LLMProviders/chatModelManager.ts @@ -1,6 +1,7 @@ -import { CustomModel, LangChainParams, ModelConfig } from "@/aiParams"; +import { CustomModel, ModelConfig, setModelKey } from "@/aiParams"; import { BUILTIN_CHAT_MODELS, ChatModelProviders } from "@/constants"; -import EncryptionService from "@/encryptionService"; +import { getDecryptedKey } from "@/encryptionService"; +import { getSettings, subscribeToSettingsChange } from "@/settings/model"; import { HarmBlockThreshold, HarmCategory } from "@google/generative-ai"; import { ChatCohere } from "@langchain/cohere"; import { BaseChatModel } from "@langchain/core/language_models/chat_models"; @@ -30,10 +31,8 @@ const CHAT_PROVIDER_CONSTRUCTORS = { type ChatProviderConstructMap = typeof CHAT_PROVIDER_CONSTRUCTORS; export default class ChatModelManager { - private encryptionService: EncryptionService; private static instance: ChatModelManager; private static chatModel: BaseChatModel; - private static chatOpenAI: ChatOpenAI; private static modelMap: Record< string, { @@ -44,48 +43,35 @@ export default class ChatModelManager { >; private readonly providerApiKeyMap: Record string> = { - [ChatModelProviders.OPENAI]: () => this.getLangChainParams().openAIApiKey, - [ChatModelProviders.GOOGLE]: () => this.getLangChainParams().googleApiKey, - [ChatModelProviders.AZURE_OPENAI]: () => this.getLangChainParams().azureOpenAIApiKey, - [ChatModelProviders.ANTHROPIC]: () => this.getLangChainParams().anthropicApiKey, - [ChatModelProviders.COHEREAI]: () => this.getLangChainParams().cohereApiKey, - [ChatModelProviders.OPENROUTERAI]: () => this.getLangChainParams().openRouterAiApiKey, - [ChatModelProviders.GROQ]: () => this.getLangChainParams().groqApiKey, + [ChatModelProviders.OPENAI]: () => getSettings().openAIApiKey, + [ChatModelProviders.GOOGLE]: () => getSettings().googleApiKey, + [ChatModelProviders.AZURE_OPENAI]: () => getSettings().azureOpenAIApiKey, + [ChatModelProviders.ANTHROPIC]: () => getSettings().anthropicApiKey, + [ChatModelProviders.COHEREAI]: () => getSettings().cohereApiKey, + [ChatModelProviders.OPENROUTERAI]: () => getSettings().openRouterAiApiKey, + [ChatModelProviders.GROQ]: () => getSettings().groqApiKey, [ChatModelProviders.OLLAMA]: () => "default-key", [ChatModelProviders.LM_STUDIO]: () => "default-key", [ChatModelProviders.OPENAI_FORMAT]: () => "default-key", } as const; - private constructor( - private getLangChainParams: () => LangChainParams, - encryptionService: EncryptionService, - activeModels: CustomModel[] - ) { - this.encryptionService = encryptionService; - this.buildModelMap(activeModels); + private constructor() { + this.buildModelMap(); + subscribeToSettingsChange(() => this.buildModelMap()); } - static getInstance( - getLangChainParams: () => LangChainParams, - encryptionService: EncryptionService, - activeModels: CustomModel[] - ): ChatModelManager { + static getInstance(): ChatModelManager { if (!ChatModelManager.instance) { - ChatModelManager.instance = new ChatModelManager( - getLangChainParams, - encryptionService, - activeModels - ); + ChatModelManager.instance = new ChatModelManager(); } return ChatModelManager.instance; } private getModelConfig(customModel: CustomModel): ModelConfig { - const decrypt = (key: string) => this.encryptionService.getDecryptedKey(key); - const params = this.getLangChainParams(); + const settings = getSettings(); const baseConfig: ModelConfig = { modelName: customModel.name, - temperature: params.temperature, + temperature: settings.temperature, streaming: true, maxRetries: 3, maxConcurrency: 3, @@ -93,23 +79,21 @@ export default class ChatModelManager { }; const providerConfig: { - [K in keyof ChatProviderConstructMap]: ConstructorParameters< - ChatProviderConstructMap[K] - >[0] /*& Record;*/; + [K in keyof ChatProviderConstructMap]: ConstructorParameters[0]; } = { [ChatModelProviders.OPENAI]: { modelName: customModel.name, - openAIApiKey: decrypt(customModel.apiKey || params.openAIApiKey), + openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openAIApiKey), // @ts-ignore - openAIOrgId: decrypt(params.openAIOrgId), - maxTokens: params.maxTokens, + openAIOrgId: getDecryptedKey(settings.openAIOrgId), + maxTokens: settings.maxTokens, configuration: { baseURL: customModel.baseUrl, fetch: customModel.enableCors ? safeFetch : undefined, }, }, [ChatModelProviders.ANTHROPIC]: { - anthropicApiKey: decrypt(customModel.apiKey || params.anthropicApiKey), + anthropicApiKey: getDecryptedKey(customModel.apiKey || settings.anthropicApiKey), modelName: customModel.name, anthropicApiUrl: customModel.baseUrl, clientOptions: { @@ -119,22 +103,22 @@ export default class ChatModelManager { }, }, [ChatModelProviders.AZURE_OPENAI]: { - maxTokens: params.maxTokens, - azureOpenAIApiKey: decrypt(customModel.apiKey || params.azureOpenAIApiKey), - azureOpenAIApiInstanceName: params.azureOpenAIApiInstanceName, - azureOpenAIApiDeploymentName: params.azureOpenAIApiDeploymentName, - azureOpenAIApiVersion: params.azureOpenAIApiVersion, + maxTokens: settings.maxTokens, + azureOpenAIApiKey: getDecryptedKey(customModel.apiKey || settings.azureOpenAIApiKey), + azureOpenAIApiInstanceName: settings.azureOpenAIApiInstanceName, + azureOpenAIApiDeploymentName: settings.azureOpenAIApiDeploymentName, + azureOpenAIApiVersion: settings.azureOpenAIApiVersion, configuration: { baseURL: customModel.baseUrl, fetch: customModel.enableCors ? safeFetch : undefined, }, }, [ChatModelProviders.COHEREAI]: { - apiKey: decrypt(customModel.apiKey || params.cohereApiKey), + apiKey: getDecryptedKey(customModel.apiKey || settings.cohereApiKey), model: customModel.name, }, [ChatModelProviders.GOOGLE]: { - apiKey: decrypt(customModel.apiKey || params.googleApiKey), + apiKey: getDecryptedKey(customModel.apiKey || settings.googleApiKey), model: customModel.name, safetySettings: [ { @@ -158,14 +142,14 @@ export default class ChatModelManager { }, [ChatModelProviders.OPENROUTERAI]: { modelName: customModel.name, - openAIApiKey: decrypt(customModel.apiKey || params.openRouterAiApiKey), + openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openRouterAiApiKey), configuration: { baseURL: customModel.baseUrl || "https://openrouter.ai/api/v1", fetch: customModel.enableCors ? safeFetch : undefined, }, }, [ChatModelProviders.GROQ]: { - apiKey: decrypt(customModel.apiKey || params.groqApiKey), + apiKey: getDecryptedKey(customModel.apiKey || settings.groqApiKey), modelName: customModel.name, }, [ChatModelProviders.OLLAMA]: { @@ -186,8 +170,8 @@ export default class ChatModelManager { }, [ChatModelProviders.OPENAI_FORMAT]: { modelName: customModel.name, - openAIApiKey: decrypt(customModel.apiKey || "default-key"), - maxTokens: params.maxTokens, + openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openAIApiKey), + maxTokens: settings.maxTokens, configuration: { baseURL: customModel.baseUrl, fetch: customModel.enableCors ? safeFetch : undefined, @@ -203,7 +187,8 @@ export default class ChatModelManager { } // Build a map of modelKey to model config - public buildModelMap(activeModels: CustomModel[]) { + public buildModelMap() { + const activeModels = getSettings().activeModels; ChatModelManager.modelMap = {}; const modelMap = ChatModelManager.modelMap; @@ -261,9 +246,7 @@ export default class ChatModelManager { const modelConfig = this.getModelConfig(model); - // MUST update it since chatModelManager is a singleton. - this.getLangChainParams().modelKey = modelKey; - new Notice(`Setting model: ${modelConfig.modelName}`); + setModelKey(`${model.name}|${model.provider}`); try { const newModelInstance = new selectedModel.AIConstructor({ ...modelConfig, diff --git a/src/LLMProviders/embeddingManager.ts b/src/LLMProviders/embeddingManager.ts index ba50af6f..8f43a03b 100644 --- a/src/LLMProviders/embeddingManager.ts +++ b/src/LLMProviders/embeddingManager.ts @@ -1,9 +1,10 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { CustomModel, LangChainParams } from "@/aiParams"; +import { CustomModel } from "@/aiParams"; import { EmbeddingModelProviders } from "@/constants"; -import EncryptionService from "@/encryptionService"; +import { getDecryptedKey } from "@/encryptionService"; import { CustomError } from "@/error"; import { safeFetch } from "@/utils"; +import { getSettings, subscribeToSettingsChange } from "@/settings/model"; import { CohereEmbeddings } from "@langchain/cohere"; import { Embeddings } from "@langchain/core/embeddings"; import { GoogleGenerativeAIEmbeddings } from "@langchain/google-genai"; @@ -24,7 +25,6 @@ const EMBEDDING_PROVIDER_CONSTRUCTORS = { type EmbeddingProviderConstructorMap = typeof EMBEDDING_PROVIDER_CONSTRUCTORS; export default class EmbeddingManager { - private encryptionService: EncryptionService; private activeEmbeddingModels: CustomModel[]; private static instance: EmbeddingManager; private static embeddingModel: Embeddings; @@ -38,35 +38,28 @@ export default class EmbeddingManager { >; private readonly providerApiKeyMap: Record string> = { - [EmbeddingModelProviders.OPENAI]: () => this.getLangChainParams().openAIApiKey, - [EmbeddingModelProviders.COHEREAI]: () => this.getLangChainParams().cohereApiKey, - [EmbeddingModelProviders.GOOGLE]: () => this.getLangChainParams().googleApiKey, - [EmbeddingModelProviders.AZURE_OPENAI]: () => this.getLangChainParams().azureOpenAIApiKey, + [EmbeddingModelProviders.OPENAI]: () => getSettings().openAIApiKey, + [EmbeddingModelProviders.COHEREAI]: () => getSettings().cohereApiKey, + [EmbeddingModelProviders.GOOGLE]: () => getSettings().googleApiKey, + [EmbeddingModelProviders.AZURE_OPENAI]: () => getSettings().azureOpenAIApiKey, [EmbeddingModelProviders.OLLAMA]: () => "default-key", [EmbeddingModelProviders.OPENAI_FORMAT]: () => "", }; - private constructor( - private getLangChainParams: () => LangChainParams, - encryptionService: EncryptionService, - activeEmbeddingModels: CustomModel[] - ) { - this.encryptionService = encryptionService; + private constructor() { + this.initialize(); + subscribeToSettingsChange(() => this.initialize()); + } + + private initialize() { + const activeEmbeddingModels = getSettings().activeEmbeddingModels; this.activeEmbeddingModels = activeEmbeddingModels; this.buildModelMap(activeEmbeddingModels); } - static getInstance( - getLangChainParams: () => LangChainParams, - encryptionService: EncryptionService, - activeEmbeddingModels: CustomModel[] - ): EmbeddingManager { + static getInstance(): EmbeddingManager { if (!EmbeddingManager.instance) { - EmbeddingManager.instance = new EmbeddingManager( - getLangChainParams, - encryptionService, - activeEmbeddingModels - ); + EmbeddingManager.instance = new EmbeddingManager(); } return EmbeddingManager.instance; } @@ -131,7 +124,7 @@ export default class EmbeddingManager { } getEmbeddingsAPI(): Embeddings | undefined { - const { embeddingModelKey } = this.getLangChainParams(); + const { embeddingModelKey } = getSettings(); if (!EmbeddingManager.modelMap.hasOwnProperty(embeddingModelKey)) { throw new CustomError(`No embedding model found for: ${embeddingModelKey}`); @@ -157,9 +150,8 @@ export default class EmbeddingManager { } } - private getEmbeddingConfig(customModel: CustomModel) { - const decrypt = (key: string) => this.encryptionService.getDecryptedKey(key); - const params = this.getLangChainParams(); + private getEmbeddingConfig(customModel: CustomModel): any { + const settings = getSettings(); const modelName = customModel.name; const baseConfig = { @@ -174,7 +166,7 @@ export default class EmbeddingManager { } = { [EmbeddingModelProviders.OPENAI]: { modelName, - openAIApiKey: decrypt(customModel.apiKey || params.openAIApiKey), + openAIApiKey: getDecryptedKey(customModel.apiKey || settings.openAIApiKey), timeout: 10000, configuration: { baseURL: customModel.baseUrl, @@ -183,17 +175,17 @@ export default class EmbeddingManager { }, [EmbeddingModelProviders.COHEREAI]: { model: modelName, - apiKey: decrypt(customModel.apiKey || params.cohereApiKey), + apiKey: getDecryptedKey(customModel.apiKey || settings.cohereApiKey), }, [EmbeddingModelProviders.GOOGLE]: { modelName: modelName, - apiKey: decrypt(params.googleApiKey), + apiKey: getDecryptedKey(settings.googleApiKey), }, [EmbeddingModelProviders.AZURE_OPENAI]: { - azureOpenAIApiKey: decrypt(customModel.apiKey || params.azureOpenAIApiKey), - azureOpenAIApiInstanceName: params.azureOpenAIApiInstanceName, - azureOpenAIApiDeploymentName: params.azureOpenAIApiEmbeddingDeploymentName, - azureOpenAIApiVersion: params.azureOpenAIApiVersion, + azureOpenAIApiKey: getDecryptedKey(customModel.apiKey || settings.azureOpenAIApiKey), + azureOpenAIApiInstanceName: settings.azureOpenAIApiInstanceName, + azureOpenAIApiDeploymentName: settings.azureOpenAIApiEmbeddingDeploymentName, + azureOpenAIApiVersion: settings.azureOpenAIApiVersion, configuration: { baseURL: customModel.baseUrl, fetch: customModel.enableCors ? safeFetch : undefined, @@ -206,7 +198,7 @@ export default class EmbeddingManager { }, [EmbeddingModelProviders.OPENAI_FORMAT]: { modelName, - openAIApiKey: decrypt(customModel.apiKey || ""), + openAIApiKey: getDecryptedKey(customModel.apiKey || ""), configuration: { baseURL: customModel.baseUrl, fetch: customModel.enableCors ? safeFetch : undefined, diff --git a/src/LLMProviders/memoryManager.ts b/src/LLMProviders/memoryManager.ts index a8c3ee54..276239dd 100644 --- a/src/LLMProviders/memoryManager.ts +++ b/src/LLMProviders/memoryManager.ts @@ -1,4 +1,4 @@ -import { LangChainParams } from "@/aiParams"; +import { getSettings, subscribeToSettingsChange } from "@/settings/model"; import { BaseChatMemory, BufferWindowMemory } from "langchain/memory"; export default class MemoryManager { @@ -6,30 +6,29 @@ export default class MemoryManager { private memory: BaseChatMemory; private debug: boolean; - private constructor( - private langChainParams: LangChainParams, - debug = false - ) { - this.debug = debug; + private constructor() { this.initMemory(); + subscribeToSettingsChange(() => this.initMemory()); } - static getInstance(langChainParams: LangChainParams, debug = false): MemoryManager { + static getInstance(): MemoryManager { if (!MemoryManager.instance) { - MemoryManager.instance = new MemoryManager(langChainParams, debug); + MemoryManager.instance = new MemoryManager(); } return MemoryManager.instance; } private initMemory(): void { + const chatContextTurns = getSettings().contextTurns; this.memory = new BufferWindowMemory({ - k: this.langChainParams.chatContextTurns * 2, + k: chatContextTurns * 2, memoryKey: "history", inputKey: "input", returnMessages: true, }); - if (this.debug) - console.log("Memory initialized with context turns:", this.langChainParams.chatContextTurns); + if (this.debug) { + console.log("Memory initialized with context turns:", chatContextTurns); + } } getMemory(): BaseChatMemory { diff --git a/src/LLMProviders/promptManager.ts b/src/LLMProviders/promptManager.ts index 2c90063d..74daf21e 100644 --- a/src/LLMProviders/promptManager.ts +++ b/src/LLMProviders/promptManager.ts @@ -1,4 +1,4 @@ -import { LangChainParams } from "@/aiParams"; +import { getSettings, subscribeToSettingsChange } from "@/settings/model"; import { ChatPromptTemplate, HumanMessagePromptTemplate, @@ -11,21 +11,26 @@ export default class PromptManager { private chatPrompt: ChatPromptTemplate; private qaPrompt: ChatPromptTemplate; - private constructor(private langChainParams: LangChainParams) { + private constructor() { this.initChatPrompt(); this.initQAPrompt(); + + subscribeToSettingsChange(() => { + this.initChatPrompt(); + this.initQAPrompt(); + }); } - static getInstance(langChainParams: LangChainParams): PromptManager { + static getInstance(): PromptManager { if (!PromptManager.instance) { - PromptManager.instance = new PromptManager(langChainParams); + PromptManager.instance = new PromptManager(); } return PromptManager.instance; } private initChatPrompt(): void { // Escape curly braces in the system message - const escapedSystemMessage = this.escapeTemplateString(this.langChainParams.systemMessage); + const escapedSystemMessage = this.escapeTemplateString(getSettings().userSystemPrompt); this.chatPrompt = ChatPromptTemplate.fromMessages([ SystemMessagePromptTemplate.fromTemplate(escapedSystemMessage), diff --git a/src/VectorStoreManager.ts b/src/VectorStoreManager.ts index 48dd012f..bcf3c25b 100644 --- a/src/VectorStoreManager.ts +++ b/src/VectorStoreManager.ts @@ -1,25 +1,21 @@ -import EncryptionService from "@/encryptionService"; import { CustomError } from "@/error"; import EmbeddingsManager from "@/LLMProviders/embeddingManager"; -import { CopilotSettings } from "@/settings/SettingsPage"; +import { getSettings } from "@/settings/model"; import { areEmbeddingModelsSame, getFilePathsFromPatterns } from "@/utils"; import VectorDBManager from "@/vectorDBManager"; import { Embeddings } from "@langchain/core/embeddings"; import { create, load, Orama, remove, removeMultiple, save, search } from "@orama/orama"; import { MD5 } from "crypto-js"; import { App, Notice, Platform, TAbstractFile, TFile, Vault } from "obsidian"; -import { LangChainParams } from "./aiParams"; import { ChainType } from "./chainFactory"; import { VAULT_VECTOR_STORE_STRATEGY } from "./constants"; +import { getChainType } from "./aiParams"; class VectorStoreManager { private app: App; - private settings: CopilotSettings; - private encryptionService: EncryptionService; private oramaDb: Orama | undefined; private dbPath: string; private embeddingsManager: EmbeddingsManager; - private getLangChainParams: () => LangChainParams; private isIndexingPaused = false; private isIndexingCancelled = false; @@ -37,23 +33,11 @@ class VectorStoreManager { private saveDBDelay = 30000; // Save full DB every 30 seconds private hasUnsavedChanges = false; - constructor( - app: App, - settings: CopilotSettings, - encryptionService: EncryptionService, - getLangChainParams: () => LangChainParams - ) { + constructor(app: App) { this.app = app; - this.settings = settings; - this.encryptionService = encryptionService; - this.getLangChainParams = getLangChainParams; this.dbPath = this.getDbPath(); - this.embeddingsManager = EmbeddingsManager.getInstance( - this.getLangChainParams, - this.encryptionService, - this.settings.activeEmbeddingModels - ); + this.embeddingsManager = EmbeddingsManager.getInstance(); // Initialize the database asynchronously this.initializationPromise = this.initializeDB() @@ -72,12 +56,6 @@ class VectorStoreManager { console.error("Failed to initialize Copilot database:", error); }); - // Initialize the rate limiter - VectorDBManager.initialize({ - getEmbeddingRequestsPerSecond: () => this.settings.embeddingRequestsPerSecond, - debug: this.settings.debug, - }); - this.updateExcludedFiles(); // Initialize periodic save @@ -101,7 +79,7 @@ class VectorStoreManager { private async performPostInitializationTasks() { // Optionally index the vault on startup - if (this.settings.indexVaultToVectorStore === VAULT_VECTOR_STORE_STRATEGY.ON_STARTUP) { + if (getSettings().indexVaultToVectorStore === VAULT_VECTOR_STORE_STRATEGY.ON_STARTUP) { try { await this.indexVaultToVectorStore(); } catch (err) { @@ -133,7 +111,7 @@ class VectorStoreManager { private async initializeDB(): Promise | undefined> { // Check if we should skip index loading on mobile - if (Platform.isMobile && this.settings.disableIndexOnMobile) { + if (Platform.isMobile && getSettings().disableIndexOnMobile) { console.log("Index loading disabled on mobile device"); this.isIndexLoaded = false; this.oramaDb = undefined; @@ -170,7 +148,7 @@ class VectorStoreManager { } } catch (error) { console.error(`Error initializing Orama database:`, error); - if (Platform.isMobile && this.settings.disableIndexOnMobile) { + if (Platform.isMobile && getSettings().disableIndexOnMobile) { return; } return await this.createNewDb(); @@ -178,7 +156,7 @@ class VectorStoreManager { } public async getIsIndexLoaded(): Promise { - await this.waitForInitialization(); + await this.initializationPromise; return this.isIndexLoaded; } @@ -218,10 +196,6 @@ class VectorStoreManager { return this.app.vault; } - public getSettings(): CopilotSettings { - return this.settings; - } - private async getVectorLength(embeddingInstance: Embeddings): Promise { try { const sampleText = "Sample text for embedding"; @@ -256,7 +230,7 @@ class VectorStoreManager { } private async saveDB() { - if (Platform.isMobile && this.settings.disableIndexOnMobile) { + if (Platform.isMobile && getSettings().disableIndexOnMobile) { return; } @@ -274,7 +248,7 @@ class VectorStoreManager { const saveOperation = async () => { try { await this.app.vault.adapter.write(this.dbPath, JSON.stringify(dataToSave)); - if (this.settings.debug) { + if (getSettings().debug) { console.log(`Saved Orama database to ${this.dbPath}.`); } } catch (error) { @@ -320,9 +294,9 @@ class VectorStoreManager { const status = this.isIndexingPaused ? " (Paused)" : ""; const folders = this.extractAppIgnoreSettings(); - const filterType = this.settings.qaInclusions - ? `Inclusions: ${this.settings.qaInclusions}` - : `Exclusions: ${folders.join(",") + (folders.length ? ", " : "") + this.settings.qaExclusions || "None"}`; + const filterType = getSettings().qaInclusions + ? `Inclusions: ${getSettings().qaInclusions}` + : `Exclusions: ${folders.join(",") + (folders.length ? ", " : "") + getSettings().qaExclusions || "None"}`; this.indexNoticeMessage.textContent = `Copilot is indexing your vault...\n` + @@ -366,14 +340,20 @@ class VectorStoreManager { exclusions.push(...this.extractAppIgnoreSettings()); - if (this.settings.qaExclusions) { - exclusions.push(...this.settings.qaExclusions.split(",").map((item) => item.trim())); + if (getSettings().qaExclusions) { + exclusions.push( + ...getSettings() + .qaExclusions.split(",") + .map((item) => item.trim()) + ); } const excludedFilePaths = await getFilePathsFromPatterns(exclusions, this.app.vault); excludedFilePaths.forEach((filePath) => targetFiles.add(filePath)); - } else if (filterType === "inclusions" && this.settings.qaInclusions) { - const inclusions = this.settings.qaInclusions.split(",").map((item) => item.trim()); + } else if (filterType === "inclusions" && getSettings().qaInclusions) { + const inclusions = getSettings() + .qaInclusions.split(",") + .map((item) => item.trim()); const includedFilePaths = await getFilePathsFromPatterns(inclusions, this.app.vault); includedFilePaths.forEach((filePath) => targetFiles.add(filePath)); } @@ -467,12 +447,12 @@ class VectorStoreManager { public async indexVaultToVectorStore(overwrite?: boolean): Promise { // Add check at the start of the method - if ((Platform.isMobile && this.settings.disableIndexOnMobile) || !this.oramaDb) { + await this.waitForInitialization(); + if ((Platform.isMobile && getSettings().disableIndexOnMobile) || !this.oramaDb) { new Notice("Indexing is disabled on mobile devices"); return 0; } - await this.waitForInitialization(); let rateLimitNoticeShown = false; try { @@ -712,7 +692,7 @@ class VectorStoreManager { searchResult.hits.map((hit) => hit.id), 500 ); - if (this.settings.debug) { + if (getSettings().debug) { console.log(`Deleted document from local Copilot index: ${filePath}`); } } @@ -739,8 +719,8 @@ class VectorStoreManager { return result.hits[0]?.document; } - public async initializeEventListeners() { - if (this.settings.debug) { + public initializeEventListeners() { + if (getSettings().debug) { console.log("Copilot Plus: Initializing event listeners"); } this.app.vault.on("modify", this.handleFileModify); @@ -752,7 +732,7 @@ class VectorStoreManager { window.clearTimeout(this.debounceTimer); } this.debounceTimer = window.setTimeout(() => { - if (this.settings.debug) { + if (getSettings().debug) { console.log("Copilot Plus: Triggering reindex for file ", file.path); } this.reindexFile(file); @@ -762,7 +742,7 @@ class VectorStoreManager { private handleFileModify = async (file: TAbstractFile) => { await this.updateExcludedFiles(); - const currentChainType = this.getLangChainParams().chainType; + const currentChainType = getChainType(); if ( file instanceof TFile && file.extension === "md" && @@ -822,7 +802,7 @@ class VectorStoreManager { // Mark that we have unsaved changes instead of saving immediately this.hasUnsavedChanges = true; - if (this.settings.debug) { + if (getSettings().debug) { console.log(`Reindexed file: ${file.path}`); } } catch (error) { diff --git a/src/aiParams.ts b/src/aiParams.ts index b258da69..efdb6a2b 100644 --- a/src/aiParams.ts +++ b/src/aiParams.ts @@ -2,6 +2,37 @@ import { ChainType } from "@/chainFactory"; import { BaseChatModel } from "@langchain/core/language_models/chat_models"; import { ChatPromptTemplate } from "@langchain/core/prompts"; +import { atom, getDefaultStore, useAtom } from "jotai"; +import { settingsAtom } from "@/settings/model"; + +const userModelKeyAtom = atom(null); +const modelKeyAtom = atom( + (get) => { + const userValue = get(userModelKeyAtom); + if (userValue !== null) { + return userValue; + } + return get(settingsAtom).defaultModelKey; + }, + (get, set, newValue) => { + set(userModelKeyAtom, newValue); + } +); + +const userChainTypeAtom = atom(null); +const chainTypeAtom = atom( + (get) => { + const userValue = get(userChainTypeAtom); + if (userValue !== null) { + return userValue; + } + return get(settingsAtom).defaultChainType; + }, + (get, set, newValue) => { + set(userChainTypeAtom, newValue); + } +); + export interface ModelConfig { modelName: string; temperature: number; @@ -24,42 +55,11 @@ export interface ModelConfig { enableCors?: boolean; } -export interface LangChainParams { - modelKey: string; // name | provider, e.g. "gpt-4o|openai" - openAIApiKey: string; - openAIOrgId: string; - huggingfaceApiKey: string; - cohereApiKey: string; - anthropicApiKey: string; - azureOpenAIApiKey: string; - azureOpenAIApiInstanceName: string; - azureOpenAIApiDeploymentName: string; - azureOpenAIApiVersion: string; - azureOpenAIApiEmbeddingDeploymentName: string; - googleApiKey: string; - openRouterAiApiKey: string; - embeddingModelKey: string; // name | provider, e.g. "text-embedding-3-large|openai" - temperature: number; - maxTokens: number; - systemMessage: string; - chatContextTurns: number; - chainType: ChainType; // Default ChainType is set in main.ts getLangChainParams - options: SetChainOptions; - openAIProxyBaseUrl?: string; - enableCors?: boolean; - openAIProxyModelName?: string; - openAIEmbeddingProxyBaseUrl?: string; - openAIEmbeddingProxyModelName?: string; - groqApiKey: string; -} - export interface SetChainOptions { prompt?: ChatPromptTemplate; chatModel?: BaseChatModel; noteFile?: any; - forceNewCreation?: boolean; abortController?: AbortController; - debug?: boolean; refreshIndex?: boolean; } @@ -74,3 +74,35 @@ export interface CustomModel { enableCors?: boolean; core?: boolean; } + +export function setModelKey(modelKey: string) { + getDefaultStore().set(modelKeyAtom, modelKey); +} + +export function getModelKey(): string { + return getDefaultStore().get(modelKeyAtom); +} + +export function subscribeToModelKeyChange(callback: () => void): () => void { + return getDefaultStore().sub(modelKeyAtom, callback); +} + +export function useModelKey() { + return useAtom(modelKeyAtom); +} + +export function getChainType(): ChainType { + return getDefaultStore().get(chainTypeAtom); +} + +export function setChainType(chainType: ChainType) { + getDefaultStore().set(chainTypeAtom, chainType); +} + +export function subscribeToChainTypeChange(callback: () => void): () => void { + return getDefaultStore().sub(chainTypeAtom, callback); +} + +export function useChainType() { + return useAtom(chainTypeAtom); +} diff --git a/src/aiState.ts b/src/aiState.ts deleted file mode 100644 index 15cc4152..00000000 --- a/src/aiState.ts +++ /dev/null @@ -1,40 +0,0 @@ -import ChainManager from "@/LLMProviders/chainManager"; -import { SetChainOptions } from "@/aiParams"; -import { ChainType } from "@/chainFactory"; -import { BaseChatMemory } from "langchain/memory"; -import { useState } from "react"; - -/** - * React hook to manage state related to model, chain and memory in Chat component. - */ -export function useAIState( - chainManager: ChainManager -): [ - string, - (model: string) => void, - ChainType, - (chain: ChainType, options?: SetChainOptions) => void, - () => void, -] { - const langChainParams = chainManager.getLangChainParams(); - const [currentModelKey, setCurrentModelKey] = useState(langChainParams.modelKey); - const [currentChain, setCurrentChain] = useState(langChainParams.chainType); - const [, setChatMemory] = useState(chainManager.memoryManager.getMemory()); - - const clearChatMemory = () => { - chainManager.memoryManager.clearChatMemory(); - setChatMemory(chainManager.memoryManager.getMemory()); - }; - - const setModelKey = (newModelKey: string) => { - chainManager.createChainWithNewModel(newModelKey); - setCurrentModelKey(newModelKey); - }; - - const setChain = (newChain: ChainType, options?: SetChainOptions) => { - chainManager.setChain(newChain, options); - setCurrentChain(newChain); - }; - - return [currentModelKey, setModelKey, currentChain, setChain, clearChatMemory]; -} diff --git a/src/commands.ts b/src/commands.ts index 1c041320..1b33be3c 100644 --- a/src/commands.ts +++ b/src/commands.ts @@ -3,10 +3,18 @@ import { ToneModal } from "@/components/modals/ToneModal"; import CopilotPlugin from "@/main"; import { Editor, Notice } from "obsidian"; import { COMMAND_IDS } from "./constants"; +import { getSettings } from "@/settings/model"; export function registerBuiltInCommands(plugin: CopilotPlugin) { + // Remove all built in commands first + Object.values(COMMAND_IDS).forEach((id) => { + // removeCommand is not available in TypeScript for some reasons + // https://docs.obsidian.md/Reference/TypeScript+API/Plugin/removeCommand + (plugin as any).removeCommand(id); + }); + const addCommandIfEnabled = (id: string, callback: (editor: Editor) => void) => { - const commandSettings = plugin.settings.enabledCommands[id]; + const commandSettings = getSettings().enabledCommands[id]; if (commandSettings && commandSettings.enabled) { plugin.addCommand({ id, @@ -89,7 +97,7 @@ export function registerBuiltInCommands(plugin: CopilotPlugin) { }); plugin.addCommand({ - id: "count-tokens", + id: COMMAND_IDS.COUNT_TOKENS, name: "Count words and tokens in selection", editorCallback: (editor: Editor) => { plugin.processSelection(editor, "countTokensSelection"); @@ -97,7 +105,7 @@ export function registerBuiltInCommands(plugin: CopilotPlugin) { }); plugin.addCommand({ - id: "count-total-vault-tokens", + id: COMMAND_IDS.COUNT_TOTAL_VAULT_TOKENS, name: "Count total tokens in your vault", callback: async () => { const totalTokens = await plugin.countTotalTokens(); diff --git a/src/components/Chat.tsx b/src/components/Chat.tsx index 3ac3ebb5..fa599507 100644 --- a/src/components/Chat.tsx +++ b/src/components/Chat.tsx @@ -1,6 +1,6 @@ -import { useAIState } from "@/aiState"; -import { ChainType } from "@/chainFactory"; +import { useChainType, useModelKey } from "@/aiParams"; import { updateChatMemory } from "@/chatUtils"; +import { ChainType } from "@/chainFactory"; import ChatInput from "@/components/chat-components/ChatInput"; import ChatMessages from "@/components/chat-components/ChatMessages"; import { ABORT_REASON, AI_SENDER, EVENT_NAMES, LOADING_MESSAGES, USER_SENDER } from "@/constants"; @@ -11,7 +11,7 @@ import { getAIResponse } from "@/langchainStream"; import ChainManager from "@/LLMProviders/chainManager"; import CopilotPlugin from "@/main"; import { Mention } from "@/mentions/Mention"; -import { useSettingsValueContext } from "@/settings/contexts/SettingsValueContext"; +import { useSettingsValue } from "@/settings/model"; import SharedState, { ChatMessage, useSharedState } from "@/sharedState"; import { FileParserManager } from "@/tools/FileParserManager"; import { @@ -45,28 +45,25 @@ interface ChatProps { sharedState: SharedState; chainManager: ChainManager; emitter: EventTarget; - defaultSaveFolder: string; onSaveChat: (saveAsNote: () => Promise) => void; updateUserMessageHistory: (newMessage: string) => void; fileParserManager: FileParserManager; plugin: CopilotPlugin; - debug: boolean; } const Chat: React.FC = ({ sharedState, chainManager, emitter, - defaultSaveFolder, onSaveChat, updateUserMessageHistory, fileParserManager, plugin, - debug, }) => { + const settings = useSettingsValue(); const [chatHistory, addMessage, clearMessages] = useSharedState(sharedState); - const [currentModelKey, setModelKey, currentChain, setChain, clearChatMemory] = - useAIState(chainManager); + const [currentModelKey] = useModelKey(); + const [currentChain] = useChainType(); const [currentAiMessage, setCurrentAiMessage] = useState(""); const [inputMessage, setInputMessage] = useState(""); const [abortController, setAbortController] = useState(null); @@ -77,11 +74,10 @@ const Chat: React.FC = ({ const [includeActiveNote, setIncludeActiveNote] = useState(false); const [selectedImages, setSelectedImages] = useState([]); - const mention = Mention.getInstance(plugin.settings.plusLicenseKey); + const mention = Mention.getInstance(); const contextProcessor = ContextProcessor.getInstance(); const inputRef = useRef(null); - const settings = useSettingsValueContext(); useEffect(() => { const handleChatVisibility = () => { @@ -162,7 +158,7 @@ const Chat: React.FC = ({ setLoadingMessage(LOADING_MESSAGES.DEFAULT); // First, process the original user message for custom prompts - const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault, settings); + const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault); let processedUserMessage = await customPromptProcessor.processCustomPrompt( inputMessage || "", "", @@ -211,7 +207,7 @@ const Chat: React.FC = ({ addMessage, setCurrentAiMessage, setAbortController, - { debug, updateLoadingMessage: setLoadingMessage } + { debug: settings.debug, updateLoadingMessage: setLoadingMessage } ); setLoading(false); setLoadingMessage(LOADING_MESSAGES.DEFAULT); @@ -256,9 +252,9 @@ const Chat: React.FC = ({ try { // Check if the default folder exists or create it - const folder = app.vault.getAbstractFileByPath(defaultSaveFolder); + const folder = app.vault.getAbstractFileByPath(settings.defaultSaveFolder); if (!folder) { - await app.vault.createFolder(defaultSaveFolder); + await app.vault.createFolder(settings.defaultSaveFolder); } const { fileName: timestampFileName } = formatDateTime(new Date(firstMessageEpoch)); @@ -281,7 +277,7 @@ const Chat: React.FC = ({ /\s+/g, "_" ); - const noteFileName = `${defaultSaveFolder}/${sanitizedFileName}.md`; + const noteFileName = `${settings.defaultSaveFolder}/${sanitizedFileName}.md`; // Add the timestamp and model properties to the note content const noteContentWithTimestamp = `--- @@ -382,9 +378,9 @@ ${chatContent}`; new AbortController(), setCurrentAiMessage, addMessage, - { debug } + { debug: settings.debug } ); - if (regeneratedResponse && debug) { + if (regeneratedResponse && settings.debug) { console.log("Message regenerated successfully"); } } catch (error) { @@ -485,7 +481,7 @@ ${chatContent}`; setCurrentAiMessage, setAbortController, { - debug, + debug: settings.debug, ignoreSystemMessage, } ); @@ -535,7 +531,7 @@ ${chatContent}`; [] ); - const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault, settings); + const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault); useEffect( createEffect( "applyCustomPrompt", @@ -549,7 +545,7 @@ ${chatContent}`; app.workspace.getActiveFile() as TFile | undefined ); }, - { isVisible: debug, ignoreSystemMessage: true, custom_temperature: 0.1 } + { isVisible: settings.debug, ignoreSystemMessage: true, custom_temperature: 0.1 } ), [] ); @@ -567,7 +563,7 @@ ${chatContent}`; app.workspace.getActiveFile() as TFile | undefined ); }, - { isVisible: debug, ignoreSystemMessage: true, custom_temperature: 0.1 } + { isVisible: settings.debug, ignoreSystemMessage: true, custom_temperature: 0.1 } ), [] ); @@ -615,7 +611,6 @@ ${chatContent}`; return (
handleStopGenerating(ABORT_REASON.USER_STOPPED)} app={app} navigateHistory={navigateHistory} - currentModelKey={currentModelKey} - setCurrentModelKey={setModelKey} - currentChain={currentChain} - setCurrentChain={setChain} onNewChat={async (openNote: boolean) => { handleStopGenerating(ABORT_REASON.NEW_CHAT); if (settings.autosaveChat && chatHistory.length > 0) { await handleSaveAsNote(openNote); } clearMessages(); - clearChatMemory(); + chainManager.memoryManager.clearChatMemory(); clearCurrentAiMessage(); }} onSaveAsNote={() => handleSaveAsNote(true)} @@ -664,7 +655,6 @@ ${chatContent}`; onAddImage={(files: File[]) => setSelectedImages((prev) => [...prev, ...files])} setSelectedImages={setSelectedImages} chatHistory={chatHistory} - debug={debug} />
diff --git a/src/components/CopilotView.tsx b/src/components/CopilotView.tsx index 84fc5ba2..2c60f443 100644 --- a/src/components/CopilotView.tsx +++ b/src/components/CopilotView.tsx @@ -3,26 +3,20 @@ import Chat from "@/components/Chat"; import { CHAT_VIEWTYPE } from "@/constants"; import { AppContext } from "@/context"; import CopilotPlugin from "@/main"; -import { CopilotSettings } from "@/settings/SettingsPage"; import SharedState from "@/sharedState"; import { FileParserManager } from "@/tools/FileParserManager"; import * as Tooltip from "@radix-ui/react-tooltip"; import { ItemView, WorkspaceLeaf } from "obsidian"; import * as React from "react"; import { Root, createRoot } from "react-dom/client"; -import { SettingsValueProvider } from "@/settings/contexts/SettingsValueContext"; export default class CopilotView extends ItemView { private chainManager: ChainManager; private fileParserManager: FileParserManager; private root: Root | null = null; - private settings: CopilotSettings; - private defaultSaveFolder: string; private handleSaveAsNote: (() => Promise) | null = null; - private debug = false; sharedState: SharedState; emitter: EventTarget; - userSystemPrompt = ""; constructor( leaf: WorkspaceLeaf, @@ -30,15 +24,11 @@ export default class CopilotView extends ItemView { ) { super(leaf); this.sharedState = plugin.sharedState; - this.settings = plugin.settings; this.app = plugin.app; this.chainManager = plugin.chainManager; this.fileParserManager = plugin.fileParserManager; - this.debug = plugin.settings.debug; this.emitter = new EventTarget(); - this.userSystemPrompt = plugin.settings.userSystemPrompt; this.plugin = plugin; - this.defaultSaveFolder = plugin.settings.defaultSaveFolder; } getViewType(): string { @@ -65,23 +55,19 @@ export default class CopilotView extends ItemView { - - { - this.plugin.updateUserMessageHistory(newMessage); - }} - fileParserManager={this.fileParserManager} - plugin={this.plugin} - debug={this.debug} - onSaveChat={(saveFunction) => { - this.handleSaveAsNote = saveFunction; - }} - /> - + { + this.plugin.updateUserMessageHistory(newMessage); + }} + fileParserManager={this.fileParserManager} + plugin={this.plugin} + onSaveChat={(saveFunction) => { + this.handleSaveAsNote = saveFunction; + }} + /> diff --git a/src/components/chat-components/ChatControls.tsx b/src/components/chat-components/ChatControls.tsx index 5f98a416..95a39f55 100644 --- a/src/components/chat-components/ChatControls.tsx +++ b/src/components/chat-components/ChatControls.tsx @@ -1,13 +1,10 @@ -import { SetChainOptions } from "@/aiParams"; -import { VAULT_VECTOR_STORE_STRATEGY } from "@/constants"; -import { CustomError } from "@/error"; -import { App, Notice } from "obsidian"; +import { useChainType } from "@/aiParams"; +import { App } from "obsidian"; import React, { useEffect, useState } from "react"; import { ChainType } from "@/chainFactory"; import { TooltipActionButton } from "@/components/chat-components/TooltipActionButton"; import { AddContextNoteModal } from "@/components/modals/AddContextNoteModal"; -import { useSettingsValueContext } from "@/settings/contexts/SettingsValueContext"; import { stringToChainType } from "@/utils"; import * as DropdownMenu from "@radix-ui/react-dropdown-menu"; import { ChevronDown, Download, MessageCirclePlus, Puzzle } from "lucide-react"; @@ -16,10 +13,9 @@ import { NewChatConfirmModal } from "@/components/modals/NewChatConfirmModal"; import { ChatMessage } from "@/sharedState"; import { TFile } from "obsidian"; import { ChatContextMenu } from "./ChatContextMenu"; +import { useSettingsValue } from "@/settings/model"; interface ChatControlsProps { - currentChain: ChainType; - setCurrentChain: (chain: ChainType, options?: SetChainOptions) => void; onNewChat: (openNote: boolean) => void; onSaveAsNote: () => void; onRefreshVaultContext: () => void; @@ -32,12 +28,9 @@ interface ChatControlsProps { contextUrls: string[]; onRemoveUrl: (url: string) => void; chatHistory: ChatMessage[]; - debug?: boolean; } const ChatControls: React.FC = ({ - currentChain, - setCurrentChain, onNewChat, onSaveAsNote, onRefreshVaultContext, @@ -50,9 +43,8 @@ const ChatControls: React.FC = ({ contextUrls, onRemoveUrl, chatHistory, - debug, }) => { - const [selectedChain, setSelectedChain] = useState(currentChain); + const [selectedChain, setSelectedChain] = useChainType(); const [isIndexLoaded, setIsIndexLoaded] = useState(false); const activeNote = app.workspace.getActiveFile(); @@ -61,8 +53,7 @@ const ChatControls: React.FC = ({ setIsIndexLoaded(loaded); }); }, [isIndexLoadedPromise]); - const settings = useSettingsValueContext(); - const indexVaultToVectorStore = settings.indexVaultToVectorStore; + const settings = useSettingsValue(); const handleChainChange = async ({ value }: { value: string }) => { const newChain = stringToChainType(value); @@ -76,30 +67,7 @@ const ChatControls: React.FC = ({ return; } - try { - if ( - (selectedChain === ChainType.VAULT_QA_CHAIN || - selectedChain === ChainType.COPILOT_PLUS_CHAIN) && - indexVaultToVectorStore === VAULT_VECTOR_STORE_STRATEGY.ON_MODE_SWITCH - ) { - await setCurrentChain(selectedChain, { - debug, - refreshIndex: true, - }); - } else { - await setCurrentChain(selectedChain, { debug }); - } - } catch (error) { - if (error instanceof CustomError) { - console.error("Error setting chain:", error.msg); - new Notice(`Error: ${error.msg}. Please check your embedding model settings.`); - } else { - console.error("Unexpected error setting chain:", error); - new Notice( - "An unexpected error occurred while setting up the chain. Please check the console for details." - ); - } - } + setSelectedChain(selectedChain); }; handleChainSelection(); @@ -150,7 +118,7 @@ const ChatControls: React.FC = ({ return (
- {currentChain === ChainType.COPILOT_PLUS_CHAIN && ( + {selectedChain === ChainType.COPILOT_PLUS_CHAIN && ( = ({
- {currentChain === "llm_chain" && "chat"} - {currentChain === "vault_qa" && "vault QA (basic)"} - {currentChain === "copilot_plus" && "copilot plus (alpha)"} + {selectedChain === ChainType.LLM_CHAIN && "chat"} + {selectedChain === ChainType.VAULT_QA_CHAIN && "vault QA (basic)"} + {selectedChain === ChainType.COPILOT_PLUS_CHAIN && "copilot plus (alpha)"} diff --git a/src/components/chat-components/ChatInput.tsx b/src/components/chat-components/ChatInput.tsx index a4ebf28e..01d6079e 100644 --- a/src/components/chat-components/ChatInput.tsx +++ b/src/components/chat-components/ChatInput.tsx @@ -1,5 +1,6 @@ -import { CustomModel, SetChainOptions } from "@/aiParams"; +import { CustomModel, useChainType, useModelKey } from "@/aiParams"; import { ChainType } from "@/chainFactory"; +import { useSettingsValue } from "@/settings/model"; import { AddImageModal } from "@/components/modals/AddImageModal"; import { ListPromptModal } from "@/components/modals/ListPromptModal"; import { NoteTitleModal } from "@/components/modals/NoteTitleModal"; @@ -7,7 +8,6 @@ import { ContextProcessor } from "@/contextProcessor"; import { CustomPromptProcessor } from "@/customPromptProcessor"; import { COPILOT_TOOL_NAMES } from "@/LLMProviders/intentAnalyzer"; import { Mention } from "@/mentions/Mention"; -import { useSettingsValueContext } from "@/settings/contexts/SettingsValueContext"; import { ChatMessage } from "@/sharedState"; import { getToolDescription } from "@/tools/toolManager"; import { extractNoteTitles } from "@/utils"; @@ -26,10 +26,6 @@ interface ChatInputProps { onStopGenerating: () => void; app: App; navigateHistory: (direction: "up" | "down") => string; - currentModelKey: string; - setCurrentModelKey: (modelKey: string) => void; - currentChain: ChainType; - setCurrentChain: (chain: ChainType, options?: SetChainOptions) => void; onNewChat: (openNote: boolean) => void; onSaveAsNote: () => void; onRefreshVaultContext: () => void; @@ -43,7 +39,6 @@ interface ChatInputProps { onAddImage: (files: File[]) => void; setSelectedImages: React.Dispatch>; chatHistory: ChatMessage[]; - debug?: boolean; } const getModelKey = (model: CustomModel) => `${model.name}|${model.provider}`; @@ -58,10 +53,6 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>( onStopGenerating, app, navigateHistory, - currentModelKey, - setCurrentModelKey, - currentChain, - setCurrentChain, onNewChat, onSaveAsNote, onRefreshVaultContext, @@ -75,7 +66,6 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>( onAddImage, setSelectedImages, chatHistory, - debug, }, ref ) => { @@ -85,7 +75,9 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>( const [contextUrls, setContextUrls] = useState([]); const textAreaRef = useRef(null); const containerRef = useRef(null); - const settings = useSettingsValueContext(); + const [currentModelKey, setCurrentModelKey] = useModelKey(); + const [currentChain] = useChainType(); + const settings = useSettingsValue(); useImperativeHandle(ref, () => ({ focus: () => { @@ -222,14 +214,14 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>( }; const showCustomPromptModal = async () => { - const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault, settings); + const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault); const prompts = await customPromptProcessor.getAllPrompts(); const promptTitles = prompts.map((prompt) => prompt.title); new ListPromptModal(app, promptTitles, async (promptTitle: string) => { const selectedPrompt = prompts.find((prompt) => prompt.title === promptTitle); if (selectedPrompt) { - await customPromptProcessor.recordPromptUsage(selectedPrompt.title); + customPromptProcessor.recordPromptUsage(selectedPrompt.title); setInputMessage(selectedPrompt.content); } }).open(); @@ -382,8 +374,6 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>( return (
void }, ChatInputProps>( contextUrls={contextUrls} onRemoveUrl={(url: string) => setContextUrls((prev) => prev.filter((u) => u !== url))} chatHistory={chatHistory} - debug={debug} /> {selectedImages.length > 0 && ( diff --git a/src/components/chat-components/ChatMessages.tsx b/src/components/chat-components/ChatMessages.tsx index bdd106c7..6dfad2c9 100644 --- a/src/components/chat-components/ChatMessages.tsx +++ b/src/components/chat-components/ChatMessages.tsx @@ -1,4 +1,3 @@ -import { ChainType } from "@/chainFactory"; import ChatSingleMessage from "@/components/chat-components/ChatSingleMessage"; import { SuggestedPrompts } from "@/components/chat-components/SuggestedPrompts"; import { ChatMessage } from "@/sharedState"; @@ -11,7 +10,6 @@ interface ChatMessagesProps { loading?: boolean; loadingMessage?: string; app: App; - currentChain: ChainType; onInsertAtCursor: (message: string) => void; onRegenerate: (messageIndex: number) => void; onEdit: (messageIndex: number, newMessage: string) => void; @@ -23,7 +21,6 @@ const ChatMessages: React.FC = ({ chatHistory, currentAiMessage, loading, - currentChain, loadingMessage, app, onInsertAtCursor, @@ -62,7 +59,7 @@ const ChatMessages: React.FC = ({ if (!chatHistory.filter((message) => message.isVisible).length && !currentAiMessage) { return (
- +
); } diff --git a/src/components/chat-components/SuggestedPrompts.tsx b/src/components/chat-components/SuggestedPrompts.tsx index f8662a4f..375b793d 100644 --- a/src/components/chat-components/SuggestedPrompts.tsx +++ b/src/components/chat-components/SuggestedPrompts.tsx @@ -1,6 +1,7 @@ +import { useChainType } from "@/aiParams"; import { ChainType } from "@/chainFactory"; import { VAULT_VECTOR_STORE_STRATEGY } from "@/constants"; -import { useSettingsValueContext } from "@/settings/contexts/SettingsValueContext"; +import { useSettingsValue } from "@/settings/model"; import React, { useMemo } from "react"; interface NotePrompt { @@ -80,13 +81,13 @@ function getRandomPrompt(chainType: ChainType = ChainType.LLM_CHAIN) { } interface SuggestedPromptsProps { - chainType: ChainType; onClick: (text: string) => void; } -export const SuggestedPrompts: React.FC = ({ chainType, onClick }) => { +export const SuggestedPrompts: React.FC = ({ onClick }) => { + const [chainType] = useChainType(); const prompts = useMemo(() => getRandomPrompt(chainType), [chainType]); - const settings = useSettingsValueContext(); + const settings = useSettingsValue(); const indexVaultToVectorStore = settings.indexVaultToVectorStore as VAULT_VECTOR_STORE_STRATEGY; const showSuggestedPrompts = settings.showSuggestedPrompts; diff --git a/src/constants.ts b/src/constants.ts index 43489adc..1024086f 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -1,5 +1,5 @@ import { CustomModel } from "@/aiParams"; -import { CopilotSettings } from "@/settings/SettingsPage"; +import { type CopilotSettings } from "@/settings/model"; import { ChainType } from "./chainFactory"; export const BREVILABS_API_BASE_URL = "https://api.brevilabs.com/v1"; diff --git a/src/customPromptProcessor.ts b/src/customPromptProcessor.ts index 23cc6c49..e859fe93 100644 --- a/src/customPromptProcessor.ts +++ b/src/customPromptProcessor.ts @@ -1,6 +1,6 @@ import { CustomError } from "@/error"; -import { PromptUsageStrategy } from "@/promptUsageStrategy"; -import { CopilotSettings } from "@/settings/SettingsPage"; +import { TimestampUsageStrategy } from "@/promptUsageStrategy"; +import { getSettings } from "@/settings/model"; import { extractNoteTitles, getFileContent, @@ -19,33 +19,29 @@ export interface CustomPrompt { export class CustomPromptProcessor { private static instance: CustomPromptProcessor; + private usageStrategy: TimestampUsageStrategy; - private constructor( - private vault: Vault, - private settings: CopilotSettings, - private usageStrategy?: PromptUsageStrategy - ) {} - - static getInstance( - vault: Vault, - settings: CopilotSettings, - usageStrategy?: PromptUsageStrategy - ): CustomPromptProcessor { + private constructor(private vault: Vault) { + this.usageStrategy = new TimestampUsageStrategy(); + } + + get customPromptsFolder(): string { + return getSettings().customPromptsFolder; + } + + static getInstance(vault: Vault): CustomPromptProcessor { if (!CustomPromptProcessor.instance) { - if (!usageStrategy) { - console.warn("PromptUsageStrategy not initialize"); - } - CustomPromptProcessor.instance = new CustomPromptProcessor(vault, settings, usageStrategy); + CustomPromptProcessor.instance = new CustomPromptProcessor(vault); } return CustomPromptProcessor.instance; } - async recordPromptUsage(title: string) { - return this.usageStrategy?.recordUsage(title).save(); + recordPromptUsage(title: string) { + this.usageStrategy.recordUsage(title); } async getAllPrompts(): Promise { - const folder = this.settings.customPromptsFolder; + const folder = this.customPromptsFolder; const files = this.vault .getFiles() .filter((file) => file.path.startsWith(folder) && file.extension === "md"); @@ -60,13 +56,13 @@ export class CustomPromptProcessor { } // Clean up promptUsageTimestamps - this.usageStrategy?.removeUnusedPrompts(prompts.map((prompt) => prompt.title)).save(); + this.usageStrategy.removeUnusedPrompts(prompts.map((prompt) => prompt.title)); - return prompts.sort((a, b) => this.usageStrategy?.compare(b.title, a.title) || 0); + return prompts.sort((a, b) => this.usageStrategy.compare(b.title, a.title) || 0); } async getPrompt(title: string): Promise { - const filePath = `${this.settings.customPromptsFolder}/${title}.md`; + const filePath = `${this.customPromptsFolder}/${title}.md`; const file = this.vault.getAbstractFileByPath(filePath); if (file instanceof TFile) { const content = await this.vault.read(file); @@ -76,7 +72,7 @@ export class CustomPromptProcessor { } async savePrompt(title: string, content: string): Promise { - const folderPath = normalizePath(this.settings.customPromptsFolder); + const folderPath = normalizePath(this.customPromptsFolder); const filePath = `${folderPath}/${title}.md`; // Check if the folder exists and create it if it doesn't @@ -90,12 +86,12 @@ export class CustomPromptProcessor { } async updatePrompt(originTitle: string, newTitle: string, content: string): Promise { - const filePath = `${this.settings.customPromptsFolder}/${originTitle}.md`; + const filePath = `${this.customPromptsFolder}/${originTitle}.md`; const file = this.vault.getAbstractFileByPath(filePath); if (file instanceof TFile) { if (originTitle !== newTitle) { - const newFilePath = `${this.settings.customPromptsFolder}/${newTitle}.md`; + const newFilePath = `${this.customPromptsFolder}/${newTitle}.md`; const newFileExists = this.vault.getAbstractFileByPath(newFilePath); if (newFileExists) { @@ -104,23 +100,19 @@ export class CustomPromptProcessor { ); } - await Promise.all([ - this.usageStrategy?.updateUsage(originTitle, newTitle).save(), - this.vault.rename(file, newFilePath), - ]); + this.usageStrategy.updateUsage(originTitle, newTitle); + await this.vault.rename(file, newFilePath); } await this.vault.modify(file, content); } } async deletePrompt(title: string): Promise { - const filePath = `${this.settings.customPromptsFolder}/${title}.md`; + const filePath = `${this.customPromptsFolder}/${title}.md`; const file = this.vault.getAbstractFileByPath(filePath); if (file instanceof TFile) { - await Promise.all([ - this.usageStrategy?.removeUnusedPrompts([title]).save(), - this.vault.delete(file), - ]); + this.usageStrategy.removeUnusedPrompts([title]); + await this.vault.delete(file); } } diff --git a/src/encryptionService.ts b/src/encryptionService.ts index efc413e7..b0ea3fdb 100644 --- a/src/encryptionService.ts +++ b/src/encryptionService.ts @@ -1,119 +1,115 @@ -import { CopilotSettings } from "@/settings/SettingsPage"; +import { type CopilotSettings } from "@/settings/model"; import { Platform } from "obsidian"; -// Dynamically import electron to access safeStorage // @ts-ignore -let safeStorage: Electron.SafeStorage | null = null; +let safeStorageInternal: Electron.SafeStorage | null = null; -if (Platform.isDesktop) { +function getSafeStorage() { + if (Platform.isDesktop && safeStorageInternal) { + return safeStorageInternal; + } + // Dynamically import electron to access safeStorage // eslint-disable-next-line @typescript-eslint/no-var-requires - safeStorage = require("electron")?.remote?.safeStorage; + safeStorageInternal = require("electron")?.remote?.safeStorage; + return safeStorageInternal; } -export default class EncryptionService { - private settings: CopilotSettings; - private static ENCRYPTION_PREFIX = "enc_"; - private static DECRYPTION_PREFIX = "dec_"; +const ENCRYPTION_PREFIX = "enc_"; +const DECRYPTION_PREFIX = "dec_"; - constructor(settings: CopilotSettings) { - this.settings = settings; +export function encryptAllKeys(settings: Readonly): Readonly { + if (!settings.enableEncryption) { + return settings; } - - private isPlainText(key: string): boolean { - return ( - !key.startsWith(EncryptionService.ENCRYPTION_PREFIX) && - !key.startsWith(EncryptionService.DECRYPTION_PREFIX) - ); + const newSettings = { ...settings }; + const keysToEncrypt = Object.keys(settings).filter( + (key) => key.toLowerCase().includes("apikey") || key === "plusLicenseKey" + ); + + for (const key of keysToEncrypt) { + const apiKey = settings[key as keyof CopilotSettings] as string; + (newSettings[key as keyof CopilotSettings] as any) = getEncryptedKey(apiKey); } - private isDecrypted(keyBuffer: string): boolean { - return keyBuffer.startsWith(EncryptionService.DECRYPTION_PREFIX); + if (Array.isArray(settings.activeModels)) { + newSettings.activeModels = settings.activeModels.map((model) => ({ + ...model, + apiKey: getEncryptedKey(model.apiKey || ""), + })); } - public encryptAllKeys(): void { - const keysToEncrypt = Object.keys(this.settings).filter( - (key) => key.toLowerCase().includes("apikey") || key === "plusLicenseKey" - ); + return newSettings; +} - for (const key of keysToEncrypt) { - const apiKey = this.settings[key as keyof CopilotSettings] as string; - (this.settings[key as keyof CopilotSettings] as any) = this.getEncryptedKey(apiKey); - } +export function getEncryptedKey(apiKey: string): string { + if (!apiKey || apiKey.startsWith(ENCRYPTION_PREFIX)) { + return apiKey; + } - if (Array.isArray(this.settings.activeModels)) { - this.settings.activeModels = this.settings.activeModels.map((model) => ({ - ...model, - apiKey: this.getEncryptedKey(model.apiKey || ""), - })); - } + if (isDecrypted(apiKey)) { + apiKey = apiKey.replace(DECRYPTION_PREFIX, ""); } - public getEncryptedKey(apiKey: string): string { - if ( - !apiKey || - !this.settings.enableEncryption || - apiKey.startsWith(EncryptionService.ENCRYPTION_PREFIX) - ) { - return apiKey; - } + if (getSafeStorage() && getSafeStorage().isEncryptionAvailable()) { + // Convert the encrypted buffer to a Base64 string and prepend the prefix + const encryptedBuffer = getSafeStorage().encryptString(apiKey) as Buffer; + // Convert the encrypted buffer to a Base64 string and prepend the prefix + return ENCRYPTION_PREFIX + encryptedBuffer.toString("base64"); + } else { + // Simple fallback for mobile (just for demonstration) + const encoder = new TextEncoder(); + const data = encoder.encode(apiKey); + return ENCRYPTION_PREFIX + arrayBufferToBase64(data); + } +} - if (this.isDecrypted(apiKey)) { - apiKey = apiKey.replace(EncryptionService.DECRYPTION_PREFIX, ""); - } +export function getDecryptedKey(apiKey: string): string { + if (!apiKey || isPlainText(apiKey)) { + return apiKey; + } + if (isDecrypted(apiKey)) { + return apiKey.replace(DECRYPTION_PREFIX, ""); + } - if (safeStorage && safeStorage.isEncryptionAvailable()) { - // Convert the encrypted buffer to a Base64 string and prepend the prefix - const encryptedBuffer = safeStorage.encryptString(apiKey) as Buffer; - // Convert the encrypted buffer to a Base64 string and prepend the prefix - return EncryptionService.ENCRYPTION_PREFIX + encryptedBuffer.toString("base64"); + const base64Data = apiKey.replace(ENCRYPTION_PREFIX, ""); + try { + if (getSafeStorage() && getSafeStorage().isEncryptionAvailable()) { + const buffer = Buffer.from(base64Data, "base64"); + return getSafeStorage().decryptString(buffer) as string; } else { // Simple fallback for mobile (just for demonstration) - const encoder = new TextEncoder(); - const data = encoder.encode(apiKey); - return EncryptionService.ENCRYPTION_PREFIX + this.arrayBufferToBase64(data); + const data = base64ToArrayBuffer(base64Data); + const decoder = new TextDecoder(); + return decoder.decode(data); } + } catch (err) { + console.error("Decryption failed:", err); + return "Copilot failed to decrypt API keys!"; } +} - public getDecryptedKey(apiKey: string): string { - if (!apiKey || this.isPlainText(apiKey)) { - return apiKey; - } - if (this.isDecrypted(apiKey)) { - return apiKey.replace(EncryptionService.DECRYPTION_PREFIX, ""); - } +function isPlainText(key: string): boolean { + return !key.startsWith(ENCRYPTION_PREFIX) && !key.startsWith(DECRYPTION_PREFIX); +} - const base64Data = apiKey.replace(EncryptionService.ENCRYPTION_PREFIX, ""); - try { - if (safeStorage && safeStorage.isEncryptionAvailable()) { - const buffer = Buffer.from(base64Data, "base64"); - return safeStorage.decryptString(buffer) as string; - } else { - // Simple fallback for mobile (just for demonstration) - const data = this.base64ToArrayBuffer(base64Data); - const decoder = new TextDecoder(); - return decoder.decode(data); - } - } catch (err) { - console.error("Decryption failed:", err); - return "Copilot failed to decrypt API keys!"; - } - } +function isDecrypted(keyBuffer: string): boolean { + return keyBuffer.startsWith(DECRYPTION_PREFIX); +} - private arrayBufferToBase64(buffer: ArrayBuffer): string { - const bytes = new Uint8Array(buffer); - let binary = ""; - for (let i = 0; i < bytes.byteLength; i++) { - binary += String.fromCharCode(bytes[i]); - } - return window.btoa(binary); +function arrayBufferToBase64(buffer: ArrayBuffer): string { + const bytes = new Uint8Array(buffer); + let binary = ""; + for (let i = 0; i < bytes.byteLength; i++) { + binary += String.fromCharCode(bytes[i]); } + return window.btoa(binary); +} - private base64ToArrayBuffer(base64: string): ArrayBuffer { - const binaryString = window.atob(base64); - const bytes = new Uint8Array(binaryString.length); - for (let i = 0; i < binaryString.length; i++) { - bytes[i] = binaryString.charCodeAt(i); - } - return bytes.buffer; +function base64ToArrayBuffer(base64: string): ArrayBuffer { + const binaryString = window.atob(base64); + const bytes = new Uint8Array(binaryString.length); + for (let i = 0; i < binaryString.length; i++) { + bytes[i] = binaryString.charCodeAt(i); } + return bytes.buffer; } diff --git a/src/main.ts b/src/main.ts index d164749b..77d36eab 100644 --- a/src/main.ts +++ b/src/main.ts @@ -1,8 +1,8 @@ import { BrevilabsClient } from "@/LLMProviders/brevilabsClient"; +import { encryptAllKeys } from "@/encryptionService"; import ChainManager from "@/LLMProviders/chainManager"; import VectorStoreManager from "@/VectorStoreManager"; -import { CustomModel, LangChainParams, SetChainOptions } from "@/aiParams"; -import { ChainType } from "@/chainFactory"; +import { CustomModel } from "@/aiParams"; import { parseChatContent, updateChatMemory } from "@/chatUtils"; import { registerBuiltInCommands } from "@/commands"; import CopilotView from "@/components/CopilotView"; @@ -13,27 +13,19 @@ import { ListPromptModal } from "@/components/modals/ListPromptModal"; import { LoadChatHistoryModal } from "@/components/modals/LoadChatHistoryModal"; import { OramaSearchModal } from "@/components/modals/OramaSearchModal"; import { SimilarNotesModal } from "@/components/modals/SimilarNotesModal"; -import { - BUILTIN_CHAT_MODELS, - BUILTIN_EMBEDDING_MODELS, - CHAT_VIEWTYPE, - CHUNK_SIZE, - DEFAULT_OPEN_AREA, - DEFAULT_SETTINGS, - DEFAULT_SYSTEM_PROMPT, - EVENT_NAMES, - VAULT_VECTOR_STORE_STRATEGY, -} from "@/constants"; +import { CHAT_VIEWTYPE, CHUNK_SIZE, DEFAULT_OPEN_AREA, EVENT_NAMES } from "@/constants"; import { CustomPromptProcessor } from "@/customPromptProcessor"; -import EncryptionService from "@/encryptionService"; import { CustomError } from "@/error"; -import { TimestampUsageStrategy } from "@/promptUsageStrategy"; import { HybridRetriever } from "@/search/hybridRetriever"; -import { CopilotSettings, CopilotSettingTab } from "@/settings/SettingsPage"; +import { CopilotSettingTab } from "@/settings/SettingsPage"; +import { + getSettings, + sanitizeSettings, + setSettings, + subscribeToSettingsChange, +} from "@/settings/model"; import SharedState from "@/sharedState"; import { FileParserManager } from "@/tools/FileParserManager"; -import { sanitizeSettings } from "@/utils"; -import VectorDBManager from "@/vectorDBManager"; import { Embeddings } from "@langchain/core/embeddings"; import { search } from "@orama/orama"; import { @@ -48,61 +40,40 @@ import { } from "obsidian"; export default class CopilotPlugin extends Plugin { - settings: CopilotSettings; // A chat history that stores the messages sent and received // Only reset when the user explicitly clicks "New Chat" sharedState: SharedState; chainManager: ChainManager; brevilabsClient: BrevilabsClient; - encryptionService: EncryptionService; userMessageHistory: string[] = []; vectorStoreManager: VectorStoreManager; - langChainParams: LangChainParams; fileParserManager: FileParserManager; + settingsUnsubscriber?: () => void; async onload(): Promise { await this.loadSettings(); + this.settingsUnsubscriber = subscribeToSettingsChange(() => { + const settings = getSettings(); + if (settings.enableEncryption) { + this.saveData(encryptAllKeys(settings)); + } else { + this.saveData(settings); + } + registerBuiltInCommands(this); + }); this.addSettingTab(new CopilotSettingTab(this.app, this)); // Always have one instance of sharedState and chainManager in the plugin this.sharedState = new SharedState(); - this.langChainParams = this.getLangChainParams(); - - this.encryptionService = new EncryptionService(this.settings); - this.vectorStoreManager = new VectorStoreManager( - this.app, - this.settings, - this.encryptionService, - () => this.langChainParams - ); - - // Initialize event listeners for the VectorStoreManager, e.g. onModify triggers reindexing - await this.vectorStoreManager.initializeEventListeners(); - if (this.settings.enableEncryption) { - await this.saveSettings(); - } + this.vectorStoreManager = new VectorStoreManager(this.app); - // Initialize the rate limiter - VectorDBManager.initialize({ - getEmbeddingRequestsPerSecond: () => this.settings.embeddingRequestsPerSecond, - debug: this.settings.debug, - }); + // Initialize event listeners for the VectorStoreManager, e.g. onModify triggers reindexing + this.vectorStoreManager.initializeEventListeners(); // Initialize BrevilabsClient - this.brevilabsClient = BrevilabsClient.getInstance(this.settings.plusLicenseKey, { - debug: this.settings.debug, - }); + this.brevilabsClient = BrevilabsClient.getInstance(); - // Ensure activeModels always includes core models - this.mergeAllActiveModelsWithCoreModels(); - this.chainManager = new ChainManager( - this.app, - () => this.langChainParams, - this.encryptionService, - this.settings, - this.vectorStoreManager, - this.brevilabsClient - ); + this.chainManager = new ChainManager(this.app, this.vectorStoreManager, this.brevilabsClient); // Initialize FileParserManager early with other core services this.fileParserManager = new FileParserManager(this.brevilabsClient); @@ -133,11 +104,7 @@ export default class CopilotPlugin extends Plugin { registerBuiltInCommands(this); - const promptProcessor = CustomPromptProcessor.getInstance( - this.app.vault, - this.settings, - new TimestampUsageStrategy(this.settings, () => this.saveSettings()) - ); + const promptProcessor = CustomPromptProcessor.getInstance(this.app.vault); this.addCommand({ id: "add-custom-prompt", @@ -379,17 +346,6 @@ export default class CopilotPlugin extends Plugin { }, }); - // Index vault to Copilot index on startup and after loading all commands - // This can take a while, so we don't want to block the startup process - if (this.settings.indexVaultToVectorStore === VAULT_VECTOR_STORE_STRATEGY.ON_STARTUP) { - try { - await this.vectorStoreManager.indexVaultToVectorStore(); - } catch (err) { - console.error("Error saving vault to Copilot index:", err); - new Notice("An error occurred while saving vault to Copilot index."); - } - } - this.registerEvent(this.app.workspace.on("editor-menu", this.handleContextMenu)); } @@ -398,6 +354,7 @@ export default class CopilotPlugin extends Plugin { if (this.vectorStoreManager) { this.vectorStoreManager.onunload(); } + this.settingsUnsubscriber?.(); console.log("Copilot plugin unloaded"); } @@ -407,7 +364,7 @@ export default class CopilotPlugin extends Plugin { } async autosaveCurrentChat() { - if (this.settings.autosaveChat) { + if (getSettings().autosaveChat) { const chatView = this.app.workspace.getLeavesOfType(CHAT_VIEWTYPE)[0]?.view as CopilotView; if (chatView && chatView.sharedState.chatHistory.length > 0) { await chatView.saveChat(); @@ -496,7 +453,7 @@ export default class CopilotPlugin extends Plugin { async activateView(): Promise { const leaves = this.app.workspace.getLeavesOfType(CHAT_VIEWTYPE); if (leaves.length === 0) { - if (this.settings.defaultOpenArea === DEFAULT_OPEN_AREA.VIEW) { + if (getSettings().defaultOpenArea === DEFAULT_OPEN_AREA.VIEW) { await this.app.workspace.getRightLeaf(false).setViewState({ type: CHAT_VIEWTYPE, active: true, @@ -507,8 +464,9 @@ export default class CopilotPlugin extends Plugin { active: true, }); } + } else { + this.app.workspace.revealLeaf(leaves[0]); } - this.app.workspace.revealLeaf(leaves[0]); this.emitChatIsVisible(); } @@ -517,10 +475,9 @@ export default class CopilotPlugin extends Plugin { } async loadSettings() { - this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData()); - - // Ensure activeModels always includes core models - this.mergeAllActiveModelsWithCoreModels(); + const savedSettings = await this.loadData(); + const sanitizedSettings = sanitizeSettings(savedSettings); + setSettings(sanitizedSettings); } mergeActiveModels( @@ -550,27 +507,6 @@ export default class CopilotPlugin extends Plugin { return Array.from(modelMap.values()); } - mergeAllActiveModelsWithCoreModels(): void { - this.settings.activeModels = this.mergeActiveModels( - this.settings.activeModels, - BUILTIN_CHAT_MODELS - ); - this.settings.activeEmbeddingModels = this.mergeActiveModels( - this.settings.activeEmbeddingModels, - BUILTIN_EMBEDDING_MODELS - ); - } - - async saveSettings(): Promise { - if (this.settings.enableEncryption) { - // Encrypt all API keys before saving - this.encryptionService.encryptAllKeys(); - } - // Ensure activeModels always includes core models - this.mergeAllActiveModelsWithCoreModels(); - await this.saveData(this.settings); - } - async countTotalTokens(): Promise { try { const allContent = await this.vectorStoreManager.getAllQAMarkdownContent(); @@ -597,61 +533,6 @@ export default class CopilotPlugin extends Plugin { }); }; - getLangChainParams(): LangChainParams { - if (!this.settings) { - throw new Error("Settings are not loaded"); - } - - const { - openAIApiKey, - openAIOrgId, - huggingfaceApiKey, - cohereApiKey, - anthropicApiKey, - azureOpenAIApiKey, - azureOpenAIApiInstanceName, - azureOpenAIApiDeploymentName, - azureOpenAIApiVersion, - azureOpenAIApiEmbeddingDeploymentName, - googleApiKey, - openRouterAiApiKey, - embeddingModelKey, - temperature, - maxTokens, - contextTurns, - groqApiKey, - } = sanitizeSettings(this.settings); - return { - openAIApiKey, - openAIOrgId, - huggingfaceApiKey, - cohereApiKey, - anthropicApiKey, - groqApiKey, - azureOpenAIApiKey, - azureOpenAIApiInstanceName, - azureOpenAIApiDeploymentName, - azureOpenAIApiVersion, - azureOpenAIApiEmbeddingDeploymentName, - googleApiKey, - openRouterAiApiKey, - modelKey: this.settings.defaultModelKey, - embeddingModelKey: embeddingModelKey || DEFAULT_SETTINGS.embeddingModelKey, - temperature: Number(temperature), - maxTokens: Number(maxTokens), - systemMessage: this.settings.userSystemPrompt || DEFAULT_SYSTEM_PROMPT, - chatContextTurns: Number(contextTurns), - chainType: this.settings.defaultChainType || ChainType.LLM_CHAIN, - options: { forceNewCreation: true, debug: this.settings.debug } as SetChainOptions, - openAIProxyBaseUrl: this.settings.openAIProxyBaseUrl, - openAIEmbeddingProxyBaseUrl: this.settings.openAIEmbeddingProxyBaseUrl, - }; - } - - getEncryptionService(): EncryptionService { - return this.encryptionService; - } - async loadCopilotChatHistory() { const chatFiles = await this.getChatHistoryFiles(); if (chatFiles.length === 0) { @@ -662,7 +543,7 @@ export default class CopilotPlugin extends Plugin { } async getChatHistoryFiles(): Promise { - const folder = this.app.vault.getAbstractFileByPath(this.settings.defaultSaveFolder); + const folder = this.app.vault.getAbstractFileByPath(getSettings().defaultSaveFolder); if (!(folder instanceof TFolder)) { return []; } @@ -723,7 +604,7 @@ export default class CopilotPlugin extends Plugin { maxK: 20, salientTerms: [], }, - this.settings.debug + getSettings().debug ); const truncatedContent = content.length > CHUNK_SIZE ? content.slice(0, CHUNK_SIZE) : content; @@ -757,7 +638,7 @@ export default class CopilotPlugin extends Plugin { salientTerms: salientTerms, textWeight: textWeight, }, - this.settings.debug + getSettings().debug ); const results = await hybridRetriever.getOramaChunks(query, salientTerms); diff --git a/src/mentions/Mention.ts b/src/mentions/Mention.ts index 7293a617..59650a8f 100644 --- a/src/mentions/Mention.ts +++ b/src/mentions/Mention.ts @@ -12,14 +12,14 @@ export class Mention { private mentions: Map; private brevilabsClient: BrevilabsClient; - private constructor(licenseKey: string) { + private constructor() { this.mentions = new Map(); - this.brevilabsClient = BrevilabsClient.getInstance(licenseKey); + this.brevilabsClient = BrevilabsClient.getInstance(); } - static getInstance(licenseKey: string): Mention { + static getInstance(): Mention { if (!Mention.instance) { - Mention.instance = new Mention(licenseKey); + Mention.instance = new Mention(); } return Mention.instance; } diff --git a/src/promptUsageStrategy.ts b/src/promptUsageStrategy.ts index c4f1fee5..dad2884d 100644 --- a/src/promptUsageStrategy.ts +++ b/src/promptUsageStrategy.ts @@ -1,53 +1,42 @@ -import { CopilotSettings } from "@/settings/SettingsPage"; +import { getSettings, updateSetting } from "@/settings/model"; export interface PromptUsageStrategy { - recordUsage: (promptTitle: string) => PromptUsageStrategy; + recordUsage: (promptTitle: string) => void; - updateUsage: (oldTitle: string, newTitle: string) => PromptUsageStrategy; + updateUsage: (oldTitle: string, newTitle: string) => void; - removeUnusedPrompts: (existingPromptTitles: Array) => PromptUsageStrategy; + removeUnusedPrompts: (existingPromptTitles: Array) => void; compare: (aKey: string, bKey: string) => number; - - save: () => Promise; } export class TimestampUsageStrategy implements PromptUsageStrategy { - private usageData: Record = {}; - - constructor( - private settings: CopilotSettings, - private saveSettings: () => Promise - ) { - this.usageData = { ...settings.promptUsageTimestamps }; + get usageData(): Readonly> { + return getSettings().promptUsageTimestamps; } - recordUsage(promptTitle: string): PromptUsageStrategy { - this.usageData[promptTitle] = Date.now(); - return this; + recordUsage(promptTitle: string) { + updateSetting("promptUsageTimestamps", { ...this.usageData, [promptTitle]: Date.now() }); } - updateUsage(oldTitle: string, newTitle: string): PromptUsageStrategy { - this.usageData[newTitle] = this.usageData[oldTitle]; - delete this.usageData[oldTitle]; - return this; + updateUsage(oldTitle: string, newTitle: string) { + const newUsageData = { ...this.usageData }; + newUsageData[newTitle] = newUsageData[oldTitle]; + delete newUsageData[oldTitle]; + updateSetting("promptUsageTimestamps", newUsageData); } - removeUnusedPrompts(existingPromptTitles: Array): PromptUsageStrategy { - for (const key in this.usageData) { + removeUnusedPrompts(existingPromptTitles: Array) { + const newUsageData = { ...this.usageData }; + for (const key of Object.keys(newUsageData)) { if (!existingPromptTitles.includes(key)) { - delete this.usageData[key]; + delete newUsageData[key]; } } - return this; + updateSetting("promptUsageTimestamps", newUsageData); } compare(aKey: string, bKey: string): number { return (this.usageData[aKey] || 0) - (this.usageData[bKey] || 0); } - - async save(): Promise { - this.settings.promptUsageTimestamps = { ...this.usageData }; - await this.saveSettings(); - } } diff --git a/src/settings/SettingsPage.tsx b/src/settings/SettingsPage.tsx index c0e0020b..57b4dff7 100644 --- a/src/settings/SettingsPage.tsx +++ b/src/settings/SettingsPage.tsx @@ -1,60 +1,11 @@ -import { CustomModel } from "@/aiParams"; -import { ChainType } from "@/chainFactory"; import CopilotView from "@/components/CopilotView"; -import { CHAT_VIEWTYPE, DEFAULT_OPEN_AREA } from "@/constants"; +import { CHAT_VIEWTYPE } from "@/constants"; import CopilotPlugin from "@/main"; import { App, Notice, PluginSettingTab, Setting } from "obsidian"; import React from "react"; import { createRoot } from "react-dom/client"; import SettingsMain from "./components/SettingsMain"; -import { SettingsProvider } from "./contexts/SettingsContext"; - -export interface CopilotSettings { - plusLicenseKey: string; - openAIApiKey: string; - openAIOrgId: string; - huggingfaceApiKey: string; - cohereApiKey: string; - anthropicApiKey: string; - azureOpenAIApiKey: string; - azureOpenAIApiInstanceName: string; - azureOpenAIApiDeploymentName: string; - azureOpenAIApiVersion: string; - azureOpenAIApiEmbeddingDeploymentName: string; - googleApiKey: string; - openRouterAiApiKey: string; - defaultChainType: ChainType; - defaultModelKey: string; - embeddingModelKey: string; - temperature: number; - maxTokens: number; - contextTurns: number; - userSystemPrompt: string; - openAIProxyBaseUrl: string; - openAIEmbeddingProxyBaseUrl: string; - stream: boolean; - defaultSaveFolder: string; - defaultConversationTag: string; - autosaveChat: boolean; - customPromptsFolder: string; - indexVaultToVectorStore: string; - chatNoteContextPath: string; - chatNoteContextTags: string[]; - debug: boolean; - enableEncryption: boolean; - maxSourceChunks: number; - qaExclusions: string; - qaInclusions: string; - groqApiKey: string; - enabledCommands: Record; - activeModels: Array; - activeEmbeddingModels: Array; - promptUsageTimestamps: Record; - embeddingRequestsPerSecond: number; - defaultOpenArea: DEFAULT_OPEN_AREA; - disableIndexOnMobile: boolean; - showSuggestedPrompts: boolean; -} +import { getSettings, updateSetting } from "@/settings/model"; export class CopilotSettingTab extends PluginSettingTab { plugin: CopilotPlugin; @@ -66,12 +17,9 @@ export class CopilotSettingTab extends PluginSettingTab { async reloadPlugin() { try { - // Save the settings before reloading - await this.plugin.saveSettings(); - // Autosave the current chat before reloading const chatView = this.app.workspace.getLeavesOfType(CHAT_VIEWTYPE)[0]?.view as CopilotView; - if (chatView && this.plugin.settings.autosaveChat) { + if (chatView && getSettings().autosaveChat) { await this.plugin.autosaveCurrentChat(); } @@ -96,11 +44,7 @@ export class CopilotSettingTab extends PluginSettingTab { const div = containerEl.createDiv("div"); const sections = createRoot(div); - sections.render( - - - - ); + sections.render(); const devModeHeader = containerEl.createEl("h1", { text: "Additional Settings" }); devModeHeader.style.marginTop = "40px"; @@ -113,9 +57,8 @@ export class CopilotSettingTab extends PluginSettingTab { }) ) .addToggle((toggle) => - toggle.setValue(this.plugin.settings.enableEncryption).onChange(async (value) => { - this.plugin.settings.enableEncryption = value; - await this.plugin.saveSettings(); + toggle.setValue(getSettings().enableEncryption).onChange(async (value) => { + updateSetting("enableEncryption", value); }) ); @@ -127,9 +70,8 @@ export class CopilotSettingTab extends PluginSettingTab { }) ) .addToggle((toggle) => - toggle.setValue(this.plugin.settings.debug).onChange(async (value) => { - this.plugin.settings.debug = value; - await this.plugin.saveSettings(); + toggle.setValue(getSettings().debug).onChange(async (value) => { + updateSetting("debug", value); }) ); } diff --git a/src/settings/components/AdvancedSettings.tsx b/src/settings/components/AdvancedSettings.tsx index 8f6b687b..096c54d4 100644 --- a/src/settings/components/AdvancedSettings.tsx +++ b/src/settings/components/AdvancedSettings.tsx @@ -1,27 +1,19 @@ import { DEFAULT_SYSTEM_PROMPT } from "@/constants"; import React from "react"; import { TextAreaComponent } from "./SettingBlocks"; +import { updateSetting, useSettingsValue } from "@/settings/model"; -interface AdvancedSettingsProps { - userSystemPrompt: string; - setUserSystemPrompt: (value: string) => void; -} - -const AdvancedSettings: React.FC = ({ - userSystemPrompt, - setUserSystemPrompt, -}) => { +const AdvancedSettings: React.FC = () => { + const settings = useSettingsValue(); return (
-
-

Advanced Settings

updateSetting("userSystemPrompt", value)} + placeholder={settings.userSystemPrompt || "Default: " + DEFAULT_SYSTEM_PROMPT} rows={10} />
diff --git a/src/settings/components/ApiSettings.tsx b/src/settings/components/ApiSettings.tsx index c0c2de82..29dad747 100644 --- a/src/settings/components/ApiSettings.tsx +++ b/src/settings/components/ApiSettings.tsx @@ -1,64 +1,12 @@ import React from "react"; import ApiSetting from "./ApiSetting"; import Collapsible from "./Collapsible"; +import { updateSetting, useSettingsValue } from "@/settings/model"; -interface ApiSettingsProps { - openAIApiKey: string; - setOpenAIApiKey: (value: string) => void; - openAIOrgId: string; - setOpenAIOrgId: (value: string) => void; - googleApiKey: string; - setGoogleApiKey: (value: string) => void; - anthropicApiKey: string; - setAnthropicApiKey: (value: string) => void; - openRouterAiApiKey: string; - setOpenRouterAiApiKey: (value: string) => void; - azureOpenAIApiKey: string; - setAzureOpenAIApiKey: (value: string) => void; - azureOpenAIApiInstanceName: string; - setAzureOpenAIApiInstanceName: (value: string) => void; - azureOpenAIApiDeploymentName: string; - setAzureOpenAIApiDeploymentName: (value: string) => void; - azureOpenAIApiVersion: string; - setAzureOpenAIApiVersion: (value: string) => void; - azureOpenAIApiEmbeddingDeploymentName: string; - setAzureOpenAIApiEmbeddingDeploymentName: (value: string) => void; - groqApiKey: string; - setGroqApiKey: (value: string) => void; - cohereApiKey: string; - setCohereApiKey: (value: string) => void; -} - -const ApiSettings: React.FC = ({ - openAIApiKey, - setOpenAIApiKey, - openAIOrgId, - setOpenAIOrgId, - googleApiKey, - setGoogleApiKey, - anthropicApiKey, - setAnthropicApiKey, - openRouterAiApiKey, - setOpenRouterAiApiKey, - azureOpenAIApiKey, - setAzureOpenAIApiKey, - azureOpenAIApiInstanceName, - setAzureOpenAIApiInstanceName, - azureOpenAIApiDeploymentName, - setAzureOpenAIApiDeploymentName, - azureOpenAIApiVersion, - setAzureOpenAIApiVersion, - azureOpenAIApiEmbeddingDeploymentName, - setAzureOpenAIApiEmbeddingDeploymentName, - groqApiKey, - setGroqApiKey, - cohereApiKey, - setCohereApiKey, -}) => { +const ApiSettings: React.FC = () => { + const settings = useSettingsValue(); return (
-
-

API Settings

All your API keys are stored locally.

@@ -71,8 +19,8 @@ const ApiSettings: React.FC = ({
updateSetting("openAIApiKey", value)} placeholder="Enter OpenAI API Key" />

@@ -87,8 +35,8 @@ const ApiSettings: React.FC = ({

updateSetting("openAIOrgId", value)} placeholder="Enter OpenAI Organization ID if applicable" />
@@ -109,8 +57,8 @@ const ApiSettings: React.FC = ({
updateSetting("googleApiKey", value)} placeholder="Enter Google API Key" />

@@ -133,8 +81,8 @@ const ApiSettings: React.FC = ({

updateSetting("anthropicApiKey", value)} placeholder="Enter Anthropic API Key" />

@@ -158,8 +106,8 @@ const ApiSettings: React.FC = ({

updateSetting("openRouterAiApiKey", value)} placeholder="Enter OpenRouter AI API Key" />

@@ -182,37 +130,37 @@ const ApiSettings: React.FC = ({

updateSetting("azureOpenAIApiKey", value)} placeholder="Enter Azure OpenAI API Key" /> updateSetting("azureOpenAIApiInstanceName", value)} placeholder="Enter Azure OpenAI API Instance Name" type="text" /> updateSetting("azureOpenAIApiDeploymentName", value)} placeholder="Enter Azure OpenAI API Deployment Name" type="text" /> updateSetting("azureOpenAIApiVersion", value)} placeholder="Enter Azure OpenAI API Version" type="text" /> updateSetting("azureOpenAIApiEmbeddingDeploymentName", value)} placeholder="Enter Azure OpenAI API Embedding Deployment Name" type="text" /> @@ -223,8 +171,8 @@ const ApiSettings: React.FC = ({
updateSetting("groqApiKey", value)} placeholder="Enter Groq API Key" />

@@ -242,8 +190,8 @@ const ApiSettings: React.FC = ({ updateSetting("cohereApiKey", value)} placeholder="Enter Cohere API Key" />

diff --git a/src/settings/components/CopilotPlusSettings.tsx b/src/settings/components/CopilotPlusSettings.tsx index 09a7a1f6..72e0e3ee 100644 --- a/src/settings/components/CopilotPlusSettings.tsx +++ b/src/settings/components/CopilotPlusSettings.tsx @@ -1,14 +1,12 @@ import React from "react"; -import { useSettingsContext } from "../contexts/SettingsContext"; import ApiSetting from "./ApiSetting"; +import { updateSetting, useSettingsValue } from "@/settings/model"; const CopilotPlusSettings: React.FC = () => { - const { settings, updateSettings } = useSettingsContext(); + const settings = useSettingsValue(); return (

-
-

Copilot Plus (Alpha)

Copilot Plus brings powerful AI agent capabilities to Obsidian. Alpha access is limited to @@ -21,7 +19,7 @@ const CopilotPlusSettings: React.FC = () => { title="License Key" description="Enter your Copilot Plus license key" value={settings.plusLicenseKey} - setValue={(value) => updateSettings({ plusLicenseKey: value })} + setValue={(value) => updateSetting("plusLicenseKey", value)} placeholder="Enter your license key" />

diff --git a/src/settings/components/GeneralSettings.tsx b/src/settings/components/GeneralSettings.tsx index ef3f5be4..c7ebd4c4 100644 --- a/src/settings/components/GeneralSettings.tsx +++ b/src/settings/components/GeneralSettings.tsx @@ -1,9 +1,7 @@ -import { CustomModel, LangChainParams } from "@/aiParams"; +import { CustomModel } from "@/aiParams"; import { ChainType } from "@/chainFactory"; import { ChatModelProviders, DEFAULT_OPEN_AREA } from "@/constants"; -import EncryptionService from "@/encryptionService"; import React from "react"; -import { useSettingsContext } from "../contexts/SettingsContext"; import CommandToggleSettings from "./CommandToggleSettings"; import { ModelSettingsComponent, @@ -11,17 +9,10 @@ import { TextComponent, ToggleComponent, } from "./SettingBlocks"; +import { updateSetting, setSettings, useSettingsValue } from "@/settings/model"; -interface GeneralSettingsProps { - getLangChainParams: () => LangChainParams; - encryptionService: EncryptionService; -} - -const GeneralSettings: React.FC = ({ - getLangChainParams, - encryptionService, -}) => { - const { settings, updateSettings } = useSettingsContext(); +const GeneralSettings: React.FC = () => { + const settings = useSettingsValue(); const handleUpdateModels = (models: Array) => { const updatedActiveModels = models.map((model) => ({ @@ -29,12 +20,12 @@ const GeneralSettings: React.FC = ({ baseUrl: model.baseUrl || "", apiKey: model.apiKey || "", })); - updateSettings({ activeModels: updatedActiveModels }); + updateSetting("activeModels", updatedActiveModels); }; // modelKey is name | provider, e.g. "gpt-4o|openai" const onSetDefaultModelKey = (modelKey: string) => { - updateSettings({ defaultModelKey: modelKey }); + updateSetting("defaultModelKey", modelKey); }; const onDeleteModel = (modelKey: string) => { @@ -54,8 +45,7 @@ const GeneralSettings: React.FC = ({ } } - // Update both activeModels and defaultModelKey in a single operation - updateSettings({ + setSettings({ activeModels: updatedActiveModels, defaultModelKey: newDefaultModelKey, }); @@ -80,7 +70,7 @@ const GeneralSettings: React.FC = ({ id="defaultChainSelect" className="default-chain-selection" value={settings.defaultChainType} - onChange={(e) => updateSettings({ defaultChainType: e.target.value as ChainType })} + onChange={(e) => updateSetting("defaultChainType", e.target.value as ChainType)} > @@ -93,26 +83,26 @@ const GeneralSettings: React.FC = ({ description="The default folder name where chat conversations will be saved. Default is 'copilot-conversations'" placeholder="copilot-conversations" value={settings.defaultSaveFolder} - onChange={(value) => updateSettings({ defaultSaveFolder: value })} + onChange={(value) => updateSetting("defaultSaveFolder", value)} /> updateSettings({ defaultConversationTag: value })} + onChange={(value) => updateSetting("defaultConversationTag", value)} /> updateSettings({ autosaveChat: value })} + onChange={(value) => updateSetting("autosaveChat", value)} /> updateSettings({ showSuggestedPrompts: value })} + onChange={(value) => updateSetting("showSuggestedPrompts", value)} />

Open Plugin In

@@ -120,9 +110,7 @@ const GeneralSettings: React.FC = ({