From 36615e8417972cd55ff5cd713d263d35eb8ed934 Mon Sep 17 00:00:00 2001 From: JinmingYang <2214962083@qq.com> Date: Tue, 17 Dec 2024 23:01:34 +0800 Subject: [PATCH] feat: refactor mention plugin system --- .vscode/settings.json | 3 +- .../ai/model-providers/helpers/factory.ts | 23 +- .../registers/controller-register.ts | 46 +++ src/extension/registers/index.ts | 2 + .../chat-context-processor/index.ts | 2 +- .../strategies/base/base-agent.ts | 62 ++++ .../strategies/base/base-node.ts | 250 +++++++++++++++++ .../nodes/state.ts => base/base-state.ts} | 39 ++- .../strategies/{ => base}/base-strategy.ts | 0 .../strategies/base/index.ts | 4 + .../strategies/chat-strategy/chat-workflow.ts | 28 +- .../strategies/chat-strategy/index.ts | 7 +- .../chat-messages-constructor.ts | 5 +- .../conversation-message-constructor.ts | 2 +- .../chat-strategy/nodes/agent-node.ts | 120 ++++---- .../chat-strategy/nodes/generate-node.ts | 29 +- .../strategies/chat-strategy/state.ts | 37 +++ .../utils/conversation-utils.ts | 36 +++ .../webview-api/controllers/index.ts | 4 +- .../controllers/mention-controller.ts | 112 ++++++++ src/extension/webview-api/index.ts | 36 +-- src/shared/entities/conversation-entity.ts | 48 +++- src/shared/plugins/agents/agent-names.ts | 5 + .../plugins/agents/codebase-search-agent.ts | 78 ++++++ .../plugins/agents/doc-retriever-agent.ts | 96 +++++++ src/shared/plugins/agents/fs-visit-agent.ts | 43 +++ src/shared/plugins/agents/web-search-agent.ts | 154 ++++++++++ src/shared/plugins/agents/web-visit-agent.ts | 62 ++++ src/shared/plugins/base/base-to-state.ts | 120 ++++++++ .../base/client/client-plugin-types.ts | 25 +- .../base/server/create-provider-manager.ts | 23 +- src/shared/plugins/base/strategies.ts | 2 + .../doc-plugin/client/doc-client-plugin.tsx | 29 +- .../doc-plugin/client/doc-log-preview.tsx | 43 ++- src/shared/plugins/doc-plugin/doc-to-state.ts | 20 ++ .../doc-chat-strategy-provider.ts | 75 +++-- .../chat-strategy/doc-retriever-node.ts | 206 +++----------- .../server/doc-mention-utils-provider.ts | 31 ++ .../doc-plugin/server/doc-server-plugin.ts | 6 + src/shared/plugins/doc-plugin/types.ts | 22 +- .../fs-plugin/client/fs-client-plugin.tsx | 114 ++------ .../fs-plugin/client/fs-log-preview.tsx | 52 +++- src/shared/plugins/fs-plugin/fs-to-state.ts | 30 ++ .../chat-strategy/codebase-search-node.ts | 180 +++--------- .../fs-chat-strategy-provider.ts | 161 +++++------ .../server/chat-strategy/fs-visit-node.ts | 133 ++------- .../server/fs-mention-utils-provider.ts | 73 +++++ .../fs-plugin/server/fs-server-plugin.ts | 6 + src/shared/plugins/fs-plugin/types.ts | 55 ++-- .../git-plugin/client/git-client-plugin.tsx | 43 +-- src/shared/plugins/git-plugin/git-to-state.ts | 20 ++ .../git-chat-strategy-provider.ts | 61 ++-- .../server/git-mention-utils-provider.ts | 34 +++ .../git-plugin/server/git-server-plugin.ts | 6 + src/shared/plugins/git-plugin/types.ts | 22 +- .../client/terminal-client-plugin.tsx | 25 +- .../terminal-chat-strategy-provider.ts | 40 ++- .../server/terminal-mention-utils-provider.ts | 32 +++ .../server/terminal-server-plugin.ts | 6 + .../terminal-mentions-to-state.ts | 14 + src/shared/plugins/terminal-plugin/types.ts | 16 +- .../web-plugin/client/web-client-plugin.tsx | 46 +-- .../web-plugin/client/web-log-preview.tsx | 72 +++-- .../web-chat-strategy-provider.ts | 81 ++++-- .../server/chat-strategy/web-search-node.ts | 265 +++--------------- .../server/chat-strategy/web-visit-node.ts | 152 +++------- .../server/web-mention-utils-provider.ts | 22 ++ .../web-plugin/server/web-server-plugin.ts | 6 + src/shared/plugins/web-plugin/types.ts | 24 +- src/shared/plugins/web-plugin/web-to-state.ts | 26 ++ .../convert-to-langchain-message-contents.ts | 11 +- .../utils/merge-langchain-message-contents.ts | 25 +- .../components/chat/editor/chat-editor.tsx | 22 ++ .../components/chat/editor/chat-input.tsx | 182 ++++++++---- .../chat/editor/file-attachments.tsx | 64 ++++- .../chat/messages/roles/chat-ai-message.tsx | 2 +- .../chat/messages/roles/chat-log-preview.tsx | 12 +- .../chat/messages/toolbars/base-toolbar.tsx | 2 +- .../chat/selectors/context-selector.tsx | 19 +- .../mention-selector/mention-selector.tsx | 6 +- .../components/content-preview-popover.tsx | 2 +- .../global-search/global-search.tsx | 4 +- src/webview/components/ui/command.tsx | 6 +- src/webview/components/ui/popover.tsx | 2 +- src/webview/components/ui/tabs.tsx | 7 +- .../hooks/chat/use-plugin-providers.tsx | 32 +-- src/webview/lexical/hooks/use-drop-handler.ts | 92 ++++++ .../lexical/hooks/use-paste-handler.ts | 151 ++++++++++ src/webview/lexical/nodes/mention-node.tsx | 80 ++---- .../lexical/plugins/mention-plugin.tsx | 7 +- src/webview/styles/global.css | 7 + src/webview/types/chat.ts | 2 - src/webview/utils/plugin-states.ts | 78 ++++-- 93 files changed, 2953 insertions(+), 1614 deletions(-) create mode 100644 src/extension/registers/controller-register.ts create mode 100644 src/extension/webview-api/chat-context-processor/strategies/base/base-agent.ts create mode 100644 src/extension/webview-api/chat-context-processor/strategies/base/base-node.ts rename src/extension/webview-api/chat-context-processor/strategies/{chat-strategy/nodes/state.ts => base/base-state.ts} (53%) rename src/extension/webview-api/chat-context-processor/strategies/{ => base}/base-strategy.ts (100%) create mode 100644 src/extension/webview-api/chat-context-processor/strategies/base/index.ts create mode 100644 src/extension/webview-api/chat-context-processor/strategies/chat-strategy/state.ts create mode 100644 src/extension/webview-api/chat-context-processor/utils/conversation-utils.ts create mode 100644 src/extension/webview-api/controllers/mention-controller.ts create mode 100644 src/shared/plugins/agents/agent-names.ts create mode 100644 src/shared/plugins/agents/codebase-search-agent.ts create mode 100644 src/shared/plugins/agents/doc-retriever-agent.ts create mode 100644 src/shared/plugins/agents/fs-visit-agent.ts create mode 100644 src/shared/plugins/agents/web-search-agent.ts create mode 100644 src/shared/plugins/agents/web-visit-agent.ts create mode 100644 src/shared/plugins/base/base-to-state.ts create mode 100644 src/shared/plugins/base/strategies.ts create mode 100644 src/shared/plugins/doc-plugin/doc-to-state.ts create mode 100644 src/shared/plugins/doc-plugin/server/doc-mention-utils-provider.ts create mode 100644 src/shared/plugins/fs-plugin/fs-to-state.ts create mode 100644 src/shared/plugins/fs-plugin/server/fs-mention-utils-provider.ts create mode 100644 src/shared/plugins/git-plugin/git-to-state.ts create mode 100644 src/shared/plugins/git-plugin/server/git-mention-utils-provider.ts create mode 100644 src/shared/plugins/terminal-plugin/server/terminal-mention-utils-provider.ts create mode 100644 src/shared/plugins/terminal-plugin/terminal-mentions-to-state.ts create mode 100644 src/shared/plugins/web-plugin/server/web-mention-utils-provider.ts create mode 100644 src/shared/plugins/web-plugin/web-to-state.ts create mode 100644 src/webview/lexical/hooks/use-drop-handler.ts create mode 100644 src/webview/lexical/hooks/use-paste-handler.ts diff --git a/.vscode/settings.json b/.vscode/settings.json index 2239d5e..d9e61df 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -134,5 +134,6 @@ "xsss", "Zhipu", "zustand" - ] + ], + "testing.automaticallyOpenTestResults": "neverOpen" } diff --git a/src/extension/ai/model-providers/helpers/factory.ts b/src/extension/ai/model-providers/helpers/factory.ts index ce8c8b0..f7589e1 100644 --- a/src/extension/ai/model-providers/helpers/factory.ts +++ b/src/extension/ai/model-providers/helpers/factory.ts @@ -3,10 +3,12 @@ import { aiModelDB } from '@extension/webview-api/lowdb/ai-model-db' import { aiProviderDB } from '@extension/webview-api/lowdb/ai-provider-db' import { globalSettingsDB } from '@extension/webview-api/lowdb/settings-db' import type { MessageContent } from '@langchain/core/messages' +import type { ChatContext } from '@shared/entities' import type { AIModel } from '@shared/entities/ai-model-entity' import { AIProvider, AIProviderType, + chatContextTypeModelSettingKeyMap, FeatureModelSettingKey, FeatureModelSettingValue } from '@shared/entities/ai-provider-entity' @@ -23,6 +25,7 @@ export class ModelProviderFactory { const provider = (await aiProviderDB.getAll()).find( p => p.id === providerId ) + if (!provider) { throw new Error(`Provider not found: ${providerId}`) } @@ -90,6 +93,12 @@ export class ModelProviderFactory { return await this.create(setting) } + static async getModelProviderForChatContext(chatContext: ChatContext) { + const chatContextType = chatContext.type + const modelSettingKey = chatContextTypeModelSettingKeyMap[chatContextType] + return await this.getModelProvider(modelSettingKey) + } + static async getModelSettingForFeature( key: FeatureModelSettingKey, useDefault = true @@ -99,9 +108,19 @@ export class ModelProviderFactory { FeatureModelSettingValue > | null = await globalSettingsDB.getSetting('models') + const isExtendsDefault = + !settings?.[key]?.providerId && !settings?.[key]?.modelName + + const defaultSetting = settings?.[FeatureModelSettingKey.Default] + + if (isExtendsDefault && useDefault && !defaultSetting) { + throw new Error( + 'You forgot to set provider or model in your settings, please check your settings.' + ) + } + const setting = - settings?.[key] || - (useDefault ? settings?.[FeatureModelSettingKey.Default] : undefined) + isExtendsDefault && useDefault ? defaultSetting : settings?.[key] return setting } diff --git a/src/extension/registers/controller-register.ts b/src/extension/registers/controller-register.ts new file mode 100644 index 0000000..279bf1e --- /dev/null +++ b/src/extension/registers/controller-register.ts @@ -0,0 +1,46 @@ +import type { CommandManager } from '@extension/commands/command-manager' +import { + controllers, + type Controllers +} from '@extension/webview-api/controllers' +import type { Controller } from '@extension/webview-api/types' +import * as vscode from 'vscode' + +import { BaseRegister } from './base-register' +import type { RegisterManager } from './register-manager' + +export class ControllerRegister extends BaseRegister { + public controllers: Map = new Map() + + constructor( + protected context: vscode.ExtensionContext, + protected registerManager: RegisterManager, + protected commandManager: CommandManager + ) { + super(context, registerManager, commandManager) + } + + async register(): Promise { + for (const ControllerClass of controllers) { + const controller = new ControllerClass( + this.registerManager, + this.commandManager + ) + this.controllers.set(controller.name, controller) + } + } + + api['name']>( + apiName: T + ): Extract, { name: T }> { + // Type assertion needed since Map.get doesn't preserve the exact type + return this.controllers.get(apiName) as Extract< + InstanceType, + { name: T } + > + } + + dispose(): void { + this.controllers.clear() + } +} diff --git a/src/extension/registers/index.ts b/src/extension/registers/index.ts index 28147e8..42d27e1 100644 --- a/src/extension/registers/index.ts +++ b/src/extension/registers/index.ts @@ -2,6 +2,7 @@ import { AideKeyUsageStatusBarRegister } from './aide-key-usage-statusbar-regist import { AutoOpenCorrespondingFilesRegister } from './auto-open-corresponding-files-register' import { BaseRegister } from './base-register' import { CodebaseWatcherRegister } from './codebase-watcher-register' +import { ControllerRegister } from './controller-register' import { InlineDiffRegister } from './inline-diff-register' import { ModelRegister } from './model-register' import { RegisterManager } from './register-manager' @@ -22,6 +23,7 @@ export const setupRegisters = async (registerManager: RegisterManager) => { InlineDiffRegister, TerminalWatcherRegister, ServerPluginRegister, + ControllerRegister, WebviewRegister, ModelRegister, CodebaseWatcherRegister diff --git a/src/extension/webview-api/chat-context-processor/index.ts b/src/extension/webview-api/chat-context-processor/index.ts index 2ca3240..b8b4a96 100644 --- a/src/extension/webview-api/chat-context-processor/index.ts +++ b/src/extension/webview-api/chat-context-processor/index.ts @@ -10,7 +10,7 @@ import { AutoTaskStrategy } from './strategies/auto-task-strategy' import type { BaseStrategy, BaseStrategyOptions -} from './strategies/base-strategy' +} from './strategies/base/base-strategy' import { ChatStrategy } from './strategies/chat-strategy' import { ComposerStrategy } from './strategies/composer-strategy' import { UIDesignerStrategy } from './strategies/ui-designer-strategy' diff --git a/src/extension/webview-api/chat-context-processor/strategies/base/base-agent.ts b/src/extension/webview-api/chat-context-processor/strategies/base/base-agent.ts new file mode 100644 index 0000000..3222b8c --- /dev/null +++ b/src/extension/webview-api/chat-context-processor/strategies/base/base-agent.ts @@ -0,0 +1,62 @@ +import type { BaseGraphState } from '@extension/webview-api/chat-context-processor/strategies/base/base-state' +import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base/base-strategy' +import { DynamicStructuredTool } from '@langchain/core/tools' +import type { Agent } from '@shared/entities' +import { z } from 'zod' + +export interface AgentContext< + State extends BaseGraphState = BaseGraphState, + CreateToolOptions extends Record = Record, + StrategyOptions extends BaseStrategyOptions = BaseStrategyOptions +> { + state: State + createToolOptions: CreateToolOptions + strategyOptions: StrategyOptions +} + +export abstract class BaseAgent< + State extends BaseGraphState = BaseGraphState, + CreateToolOptions extends Record = Record, + StrategyOptions extends BaseStrategyOptions = BaseStrategyOptions, + TInput extends z.ZodType = z.ZodType, + TOutput extends z.ZodType = z.ZodType +> { + abstract inputSchema: TInput + + abstract outputSchema: TOutput + + abstract name: string + + abstract logTitle: string + + abstract description: string + + constructor( + public context: AgentContext + ) {} + + // Abstract method that needs to be implemented by concrete agents + abstract execute(input: z.infer): Promise> + + // Create the Langchain tool + public async createTool(): Promise { + return new DynamicStructuredTool({ + name: this.name, + description: this.description, + schema: this.inputSchema as any, + func: async (input: z.infer) => { + const result = await this.execute(input) + return this.outputSchema.parse(result) + } + }) + } +} + +export type GetAgentInput = z.infer + +export type GetAgentOutput = z.infer + +export type GetAgent = Agent< + GetAgentInput, + GetAgentOutput +> diff --git a/src/extension/webview-api/chat-context-processor/strategies/base/base-node.ts b/src/extension/webview-api/chat-context-processor/strategies/base/base-node.ts new file mode 100644 index 0000000..9919c07 --- /dev/null +++ b/src/extension/webview-api/chat-context-processor/strategies/base/base-node.ts @@ -0,0 +1,250 @@ +/* eslint-disable new-cap */ +import type { + BaseGraphNode, + BaseGraphState +} from '@extension/webview-api/chat-context-processor/strategies/base/base-state' +import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base/base-strategy' +import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params' +import type { ZodObjectAny } from '@langchain/core/dist/types/zod' +import type { ToolMessage } from '@langchain/core/messages' +import type { DynamicStructuredTool } from '@langchain/core/tools' +import type { Agent, Conversation, ConversationLog } from '@shared/entities' +import { settledPromiseResults } from '@shared/utils/common' +import { produce } from 'immer' +import { v4 as uuidv4 } from 'uuid' + +import type { BaseAgent, GetAgentInput, GetAgentOutput } from './base-agent' + +export interface BaseNodeContext< + StrategyOptions extends BaseStrategyOptions = BaseStrategyOptions +> { + strategyOptions: StrategyOptions +} + +type AgentConstructor = new (...args: any[]) => T + +export type AgentConfig = { + agentClass: AgentConstructor + agentContext?: T['context'] + processAgentOutput?: (agentOutput: GetAgentOutput) => GetAgentOutput +} + +export type AgentsConfig = { + [K: string]: AgentConfig +} + +type ExecuteAgentToolResult = { + agents: Agent, GetAgentOutput>[] + logs: ConversationLog[] +} + +type BuildAgentConfig = ( + state: State +) => AgentConfig + +export abstract class BaseNode< + State extends BaseGraphState = BaseGraphState, + StrategyOptions extends BaseStrategyOptions = BaseStrategyOptions +> { + constructor(protected context: BaseNodeContext) { + this.onInit() + } + + abstract onInit(): void + + protected agentNameBuildAgentConfigMap: Record< + string, + BuildAgentConfig + > = {} + + protected createAgentConfig( + agentConfig: AgentConfig + ): AgentConfig { + return agentConfig + } + + protected registerAgentConfig( + agentName: string, + buildAgentConfig: BuildAgentConfig + ) { + this.agentNameBuildAgentConfigMap[agentName] = + buildAgentConfig as unknown as BuildAgentConfig + } + + protected getAgentsConfig(state: State): AgentsConfig { + return Object.fromEntries( + Object.entries(this.agentNameBuildAgentConfigMap).map( + ([agentName, buildAgentConfig]) => { + const agentConfig = buildAgentConfig(state) + return [agentName, agentConfig] + } + ) + ) + } + + protected getAgentConfig( + agentName: string, + state: State + ): AgentConfig | null { + return ( + (this.agentNameBuildAgentConfigMap[agentName]?.( + state + ) as unknown as AgentConfig) || null + ) + } + + abstract execute(state: State): Promise> + + protected async createAgentToolByName( + agentName: string, + state: State, + overrideAgentContext?: T['context'] + ): Promise<{ + tool: DynamicStructuredTool | null + agentConfig: AgentConfig | null + agentInstance: T | null + }> { + const agentConfig = this.getAgentConfig(agentName, state) + + if (!agentConfig) + return { tool: null, agentConfig: null, agentInstance: null } + + const finalAgentContext = { + ...agentConfig.agentContext, + ...overrideAgentContext + } + + const agentInstance = new agentConfig.agentClass(finalAgentContext) + const tool = await agentInstance.createTool() + return { tool, agentConfig, agentInstance } + } + + // Helper method to execute tool calls + protected async executeAgentTool( + state: State, + props: AgentConfig + ): Promise> { + const { agentClass: AgentClass, agentContext, processAgentOutput } = props + + const results: ExecuteAgentToolResult = { + agents: [], + logs: [] + } + + const { tool, agentConfig, agentInstance } = + await this.createAgentToolByName(AgentClass.name, state, agentContext) + + if (!tool || !agentConfig || !agentInstance) return results + + const messages = agentConfig.agentContext?.state.messages || [] + + if (!messages.length) return results + + const toolCalls = findCurrentToolsCallParams(messages.at(-1), [tool]) + + if (!toolCalls.length) return results + + const toolCallsPromises = toolCalls.map(async toolCall => { + const toolMessage = (await tool.invoke(toolCall)) as ToolMessage + const agentOutput = JSON.parse(toolMessage?.lc_kwargs.content) + + const agent: Agent, GetAgentOutput> = { + id: uuidv4(), + name: tool.name, + input: toolCall.args, + output: processAgentOutput + ? processAgentOutput(agentOutput) + : agentOutput + } + + const log = this.createAgentLog(agentInstance.logTitle, agent.id) + + results.agents.push(agent) + results.logs.push(log) + }) + + await settledPromiseResults(toolCallsPromises) + return results + } + + // Helper method to add agent to conversation + protected addAgentsToConversation( + conversation: Conversation, + agents: Agent, GetAgentOutput>[] + ) { + conversation.agents = produce(conversation.agents, draft => { + draft.push(...agents) + }) + } + + protected addLogsToConversation( + conversation: Conversation, + logs: ConversationLog[] + ) { + conversation.logs = produce(conversation.logs, draft => { + draft.push(...logs) + }) + } + + // Helper method to create agent log + protected createAgentLog(title: string, agentId: string): ConversationLog { + return { + id: uuidv4(), + createdAt: Date.now(), + title, + agentId + } + } + + async createTools( + state: State + ): Promise[]> { + const agentsConfig = this.getAgentsConfig(state) + const tools = await settledPromiseResults( + Object.entries(agentsConfig).map(async ([_, agentConfig]) => { + const agentInstance = await new agentConfig.agentClass( + agentConfig.agentContext + ) + const tool = await agentInstance.createTool() + return tool + }) + ) + + return tools.filter(Boolean) as DynamicStructuredTool[] + } + + createGraphNode(): T { + return ((state: State) => this.execute(state)) as T + } +} + +export const createToolsFromNodes = async < + T extends BaseNode, + StrategyOptions extends BaseStrategyOptions = BaseStrategyOptions, + State extends BaseGraphState = BaseGraphState +>(props: { + nodeClasses: (new (...args: any[]) => T)[] + state: State + strategyOptions: StrategyOptions +}) => + ( + await Promise.all( + props.nodeClasses.map(async NodeClass => { + const nodeInstance = new NodeClass(props.strategyOptions) + return await nodeInstance.createTools(props.state) + }) + ) + ).flat() + +export const createGraphNodeFromNodes = async < + T extends BaseNode, + StrategyOptions extends BaseStrategyOptions = BaseStrategyOptions +>(props: { + nodeClasses: (new (...args: any[]) => T)[] + strategyOptions: StrategyOptions +}) => + await Promise.all( + props.nodeClasses.map(NodeClass => + new NodeClass(props.strategyOptions).createGraphNode() + ) + ) diff --git a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state.ts b/src/extension/webview-api/chat-context-processor/strategies/base/base-state.ts similarity index 53% rename from src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state.ts rename to src/extension/webview-api/chat-context-processor/strategies/base/base-state.ts index acf34de..1c682a8 100644 --- a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state.ts +++ b/src/extension/webview-api/chat-context-processor/strategies/base/base-state.ts @@ -8,15 +8,9 @@ import { } from '@shared/entities' import { cloneDeep } from 'es-toolkit' -import type { BaseStrategyOptions } from '../../base-strategy' +import type { BaseStrategyOptions } from './base-strategy' -export enum ChatGraphNodeName { - Agent = 'agent', - Tools = 'tools', - Generate = 'generate' -} - -export const chatGraphState = Annotation.Root({ +export const baseGraphStateConfig = { messages: Annotation({ reducer: (x, y) => x.concat(y), default: () => [] @@ -28,28 +22,27 @@ export const chatGraphState = Annotation.Root({ reducer: (x, y) => y ?? x, default: () => [new ConversationEntity({ role: 'ai' }).entity] }), - shouldContinue: Annotation({ - reducer: (x, y) => y ?? x, - default: () => true - }), abortController: Annotation({ reducer: (x, y) => y ?? x, default: () => new AbortController() }) -}) +} + +export const baseGraphState = Annotation.Root(baseGraphStateConfig) -export type ChatGraphState = typeof chatGraphState.State +export type BaseGraphState = typeof baseGraphState.State -export type ChatGraphNode = ( - state: ChatGraphState -) => Promise> +export type BaseGraphNode = ( + state: State +) => Promise> -export type CreateChatGraphNode = ( - options: BaseStrategyOptions -) => ChatGraphNode +export type CreateBaseGraphNode< + StrategyOptions extends BaseStrategyOptions = BaseStrategyOptions, + State extends BaseGraphState = BaseGraphState +> = (strategyOptions: StrategyOptions) => BaseGraphNode -export const chatGraphStateEventName = 'stream-chat-graph-state' -export const dispatchChatGraphState = (state: Partial) => { +export const baseGraphStateEventName = 'stream-base-graph-state' +export const dispatchBaseGraphState = (state: Partial) => { const deepClonedState = cloneDeep(state) - dispatchCustomEvent(chatGraphStateEventName, deepClonedState) + dispatchCustomEvent(baseGraphStateEventName, deepClonedState) } diff --git a/src/extension/webview-api/chat-context-processor/strategies/base-strategy.ts b/src/extension/webview-api/chat-context-processor/strategies/base/base-strategy.ts similarity index 100% rename from src/extension/webview-api/chat-context-processor/strategies/base-strategy.ts rename to src/extension/webview-api/chat-context-processor/strategies/base/base-strategy.ts diff --git a/src/extension/webview-api/chat-context-processor/strategies/base/index.ts b/src/extension/webview-api/chat-context-processor/strategies/base/index.ts new file mode 100644 index 0000000..11342f4 --- /dev/null +++ b/src/extension/webview-api/chat-context-processor/strategies/base/index.ts @@ -0,0 +1,4 @@ +export * from './base-agent' +export * from './base-state' +export * from './base-strategy' +export * from './base-node' diff --git a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/chat-workflow.ts b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/chat-workflow.ts index 2d24d6a..018c99e 100644 --- a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/chat-workflow.ts +++ b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/chat-workflow.ts @@ -2,14 +2,10 @@ import { ServerPluginRegister } from '@extension/registers/server-plugin-registe import { END, START, StateGraph } from '@langchain/langgraph' import { combineNode } from '../../utils/combine-node' -import type { BaseStrategyOptions } from '../base-strategy' -import { createAgentNode } from './nodes/agent-node' -import { createGenerateNode } from './nodes/generate-node' -import { - ChatGraphNodeName, - chatGraphState, - type ChatGraphState -} from './nodes/state' +import type { BaseStrategyOptions } from '../base/base-strategy' +import { AgentNode } from './nodes/agent-node' +import { GenerateNode } from './nodes/generate-node' +import { ChatGraphNodeName, chatGraphState, type ChatGraphState } from './state' const createSmartRoute = (nextNodeName: ChatGraphNodeName) => (state: ChatGraphState) => { @@ -28,10 +24,20 @@ export const createChatWorkflow = async (options: BaseStrategyOptions) => { const toolNodes = (await chatStrategyProvider?.buildLanggraphToolNodes?.(options)) || [] + const combinedToolsNode = combineNode(toolNodes, chatGraphState) + + const agentNode = new AgentNode({ + strategyOptions: options + }).createGraphNode() + + const generateNode = new GenerateNode({ + strategyOptions: options + }).createGraphNode() + const chatWorkflow = new StateGraph(chatGraphState) - .addNode(ChatGraphNodeName.Agent, createAgentNode(options)) - .addNode(ChatGraphNodeName.Tools, combineNode(toolNodes, chatGraphState)) - .addNode(ChatGraphNodeName.Generate, createGenerateNode(options)) + .addNode(ChatGraphNodeName.Agent, agentNode) + .addNode(ChatGraphNodeName.Tools, combinedToolsNode) + .addNode(ChatGraphNodeName.Generate, generateNode) chatWorkflow .addConditionalEdges(START, createSmartRoute(ChatGraphNodeName.Agent)) diff --git a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/index.ts b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/index.ts index 59e59e3..7e0f7fe 100644 --- a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/index.ts +++ b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/index.ts @@ -2,9 +2,10 @@ import { type ChatContext, type Conversation } from '@shared/entities' import { UnPromise } from '@shared/types/common' import { produce } from 'immer' -import { BaseStrategy } from '../base-strategy' +import { baseGraphStateEventName } from '../base/base-state' +import { BaseStrategy } from '../base/base-strategy' import { createChatWorkflow } from './chat-workflow' -import { chatGraphStateEventName, type ChatGraphState } from './nodes/state' +import { type ChatGraphState } from './state' export class ChatStrategy extends BaseStrategy { private _chatWorkflow: UnPromise< @@ -40,7 +41,7 @@ export class ChatStrategy extends BaseStrategy { const state: Partial = {} for await (const { event, name, data } of eventStream) { - if (event === 'on_custom_event' && name === chatGraphStateEventName) { + if (event === 'on_custom_event' && name === baseGraphStateEventName) { const returnsState = data as Partial Object.assign(state, returnsState) const currentChatContext = state.chatContext || chatContext diff --git a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/chat-messages-constructor.ts b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/chat-messages-constructor.ts index 4db7156..d774f0e 100644 --- a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/chat-messages-constructor.ts +++ b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/chat-messages-constructor.ts @@ -1,11 +1,12 @@ import type { CommandManager } from '@extension/commands/command-manager' import type { RegisterManager } from '@extension/registers/register-manager' import { ServerPluginRegister } from '@extension/registers/server-plugin-register' +import { processConversationsWithAgents } from '@extension/webview-api/chat-context-processor/utils/conversation-utils' import { HumanMessage, SystemMessage } from '@langchain/core/messages' import type { ChatContext, LangchainMessage } from '@shared/entities' import { settledPromiseResults } from '@shared/utils/common' -import type { BaseStrategyOptions } from '../../base-strategy' +import type { BaseStrategyOptions } from '../../base/base-strategy' import { ConversationMessageConstructor } from './conversation-message-constructor' interface ChatMessagesConstructorOptions extends BaseStrategyOptions { @@ -26,7 +27,7 @@ export class ChatMessagesConstructor { } constructor(options: ChatMessagesConstructorOptions) { - this.chatContext = options.chatContext + this.chatContext = processConversationsWithAgents(options.chatContext) this.registerManager = options.registerManager this.commandManager = options.commandManager } diff --git a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/conversation-message-constructor.ts b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/conversation-message-constructor.ts index 8f8529f..467fdb3 100644 --- a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/conversation-message-constructor.ts +++ b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/conversation-message-constructor.ts @@ -83,7 +83,7 @@ ${prompt} const imageContents: LangchainMessageContents = imageUrls.map(url => ({ type: 'image_url', - image_url: url + image_url: { url } })) || [] let isEnhanced = false diff --git a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/agent-node.ts b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/agent-node.ts index 6999421..30051be 100644 --- a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/agent-node.ts +++ b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/agent-node.ts @@ -2,74 +2,88 @@ import { ModelProviderFactory } from '@extension/ai/model-providers/helpers/fact import { ServerPluginRegister } from '@extension/registers/server-plugin-register' import { getToolCallsFromMessage } from '@extension/webview-api/chat-context-processor/utils/get-tool-calls-from-message' import type { AIMessageChunk } from '@langchain/core/messages' -import { FeatureModelSettingKey, type LangchainTool } from '@shared/entities' +import { type LangchainTool } from '@shared/entities' import { convertToLangchainMessageContents } from '@shared/utils/convert-to-langchain-message-contents' +import { mergeLangchainMessageContents } from '@shared/utils/merge-langchain-message-contents' import { produce } from 'immer' +import { BaseNode } from '../../base/base-node' +import { dispatchBaseGraphState } from '../../base/base-state' import { ChatMessagesConstructor } from '../messages-constructors/chat-messages-constructor' -import { dispatchChatGraphState, type CreateChatGraphNode } from './state' +import { type ChatGraphState } from '../state' -export const createAgentNode: CreateChatGraphNode = options => async state => { - const modelProvider = await ModelProviderFactory.getModelProvider( - FeatureModelSettingKey.Chat - ) - const aiModel = await modelProvider.createLangChainModel() - const chatStrategyProvider = options.registerManager - .getRegister(ServerPluginRegister) - ?.serverPluginRegistry?.providerManagers.chatStrategy.mergeAll() +export class AgentNode extends BaseNode { + onInit() {} - const tools = [ - ...((await chatStrategyProvider?.buildAgentTools?.(options, state)) || []) - ].filter(Boolean) as LangchainTool[] + async execute(state: ChatGraphState) { + const modelProvider = + await ModelProviderFactory.getModelProviderForChatContext( + state.chatContext + ) + const aiModel = await modelProvider.createLangChainModel() + const chatStrategyProvider = this.context.strategyOptions.registerManager + .getRegister(ServerPluginRegister) + ?.serverPluginRegistry?.providerManagers.chatStrategy.mergeAll() - const chatMessagesConstructor = new ChatMessagesConstructor({ - ...options, - chatContext: state.chatContext - }) + const tools = [ + ...((await chatStrategyProvider?.buildAgentTools?.( + this.context.strategyOptions, + state + )) || []) + ].filter(Boolean) as LangchainTool[] - const messagesFromChatContext = - await chatMessagesConstructor.constructMessages() + const chatMessagesConstructor = new ChatMessagesConstructor({ + ...this.context.strategyOptions, + chatContext: state.chatContext + }) - const stream = await aiModel - .bindTools(tools) - .bind({ signal: state.abortController?.signal }) - .stream(messagesFromChatContext) + const messagesFromChatContext = + await chatMessagesConstructor.constructMessages() - let message: AIMessageChunk | undefined - let shouldContinue = true - let { newConversations } = state + const stream = await aiModel + .bindTools(tools) + .bind({ signal: state.abortController?.signal }) + .stream(messagesFromChatContext) - for await (const chunk of stream) { - if (!message) { - message = chunk - } else { - message = message.concat(chunk) - // stream with tool calls not need to concat content - message.content = chunk.content - } + let message: AIMessageChunk | undefined + let shouldContinue = true + let { newConversations } = state - const toolCalls = getToolCallsFromMessage(message) - const contents = convertToLangchainMessageContents(message.content) + for await (const chunk of stream) { + if (!message) { + message = chunk + } else { + message = message.concat(chunk) + // stream with tool calls not need to concat content + message.content = chunk.content + } - if (!toolCalls.length && contents.length) { - // no tool calls - shouldContinue = false - newConversations = produce(newConversations, draft => { - draft.at(-1)!.contents.push(...contents) - }) - } + const toolCalls = getToolCallsFromMessage(message) + const contents = convertToLangchainMessageContents(message.content) + + if (!toolCalls.length && contents.length) { + // no tool calls + shouldContinue = false + newConversations = produce(newConversations, draft => { + draft.at(-1)!.contents = mergeLangchainMessageContents([ + ...draft.at(-1)!.contents, + ...contents + ]) + }) + } - if (contents.length) { - dispatchChatGraphState({ - newConversations, - chatContext: state.chatContext - }) + if (contents.length) { + dispatchBaseGraphState({ + newConversations, + chatContext: state.chatContext + }) + } } - } - return { - shouldContinue, - newConversations, - messages: message ? [message] : undefined + return { + shouldContinue, + newConversations, + messages: message ? [message] : undefined + } } } diff --git a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/generate-node.ts b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/generate-node.ts index 93dc996..ec38166 100644 --- a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/generate-node.ts +++ b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/generate-node.ts @@ -1,21 +1,26 @@ import { ModelProviderFactory } from '@extension/ai/model-providers/helpers/factory' import type { AIMessageChunk } from '@langchain/core/messages' -import { FeatureModelSettingKey } from '@shared/entities' import { convertToLangchainMessageContents } from '@shared/utils/convert-to-langchain-message-contents' +import { mergeLangchainMessageContents } from '@shared/utils/merge-langchain-message-contents' import { produce } from 'immer' +import { BaseNode } from '../../base/base-node' +import { dispatchBaseGraphState } from '../../base/base-state' import { ChatMessagesConstructor } from '../messages-constructors/chat-messages-constructor' -import { dispatchChatGraphState, type CreateChatGraphNode } from './state' +import { type ChatGraphState } from '../state' -export const createGenerateNode: CreateChatGraphNode = - options => async state => { - const modelProvider = await ModelProviderFactory.getModelProvider( - FeatureModelSettingKey.Chat - ) +export class GenerateNode extends BaseNode { + onInit() {} + + async execute(state: ChatGraphState) { + const modelProvider = + await ModelProviderFactory.getModelProviderForChatContext( + state.chatContext + ) const aiModel = await modelProvider.createLangChainModel() const chatMessagesConstructor = new ChatMessagesConstructor({ - ...options, + ...this.context.strategyOptions, chatContext: state.chatContext }) @@ -40,11 +45,14 @@ export const createGenerateNode: CreateChatGraphNode = if (contents.length) { newConversations = produce(state.newConversations, draft => { - draft.at(-1)!.contents.push(...contents) + draft.at(-1)!.contents = mergeLangchainMessageContents([ + ...draft.at(-1)!.contents, + ...contents + ]) }) } - dispatchChatGraphState({ + dispatchBaseGraphState({ newConversations, chatContext: state.chatContext }) @@ -55,3 +63,4 @@ export const createGenerateNode: CreateChatGraphNode = newConversations } } +} diff --git a/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/state.ts b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/state.ts new file mode 100644 index 0000000..9d87ffe --- /dev/null +++ b/src/extension/webview-api/chat-context-processor/strategies/chat-strategy/state.ts @@ -0,0 +1,37 @@ +import { Annotation } from '@langchain/langgraph' + +import { + baseGraphStateConfig, + type BaseGraphNode, + type CreateBaseGraphNode +} from '../base/base-state' +import type { BaseStrategyOptions } from '../base/base-strategy' + +export enum ChatGraphNodeName { + Agent = 'agent', + Tools = 'tools', + Generate = 'generate' +} + +export const chatGraphState = Annotation.Root({ + ...baseGraphStateConfig, + shouldContinue: Annotation({ + reducer: (x, y) => y ?? x, + default: () => true + }) +}) + +export type ChatGraphState = typeof chatGraphState.State + +export type ChatGraphNode = BaseGraphNode + +export type CreateChatGraphNode = CreateBaseGraphNode< + BaseStrategyOptions, + ChatGraphState +> + +// export const chatGraphStateEventName = 'stream-chat-graph-state' +// export const dispatchChatGraphState = (state: Partial) => { +// const deepClonedState = cloneDeep(state) +// dispatchCustomEvent(chatGraphStateEventName, deepClonedState) +// } diff --git a/src/extension/webview-api/chat-context-processor/utils/conversation-utils.ts b/src/extension/webview-api/chat-context-processor/utils/conversation-utils.ts new file mode 100644 index 0000000..35f2c8b --- /dev/null +++ b/src/extension/webview-api/chat-context-processor/utils/conversation-utils.ts @@ -0,0 +1,36 @@ +import type { Agent, ChatContext } from '@shared/entities' + +/** + * Process all conversations in chatContext and collect AI agents for each human conversation + * Returns a new ChatContext with updated conversations + */ +export const processConversationsWithAgents = ( + chatContext: ChatContext +): ChatContext => { + const { conversations } = chatContext + const updatedConversations = conversations.map((conv, index) => { + // Return early if not human conversation + if (conv.role !== 'human') return conv + + // Collect AI agents until next human message + const aiAgents: Agent[] = [] + let i = index + 1 + + while (i < conversations.length && conversations[i]!.role !== 'human') { + if (conversations[i]!.role === 'ai') { + aiAgents.push(...(conversations[i]!.agents || [])) + } + i++ + } + + return { + ...conv, + agents: aiAgents + } + }) + + return { + ...chatContext, + conversations: updatedConversations + } +} diff --git a/src/extension/webview-api/controllers/index.ts b/src/extension/webview-api/controllers/index.ts index 2d2da4e..fad4133 100644 --- a/src/extension/webview-api/controllers/index.ts +++ b/src/extension/webview-api/controllers/index.ts @@ -8,6 +8,7 @@ import { CodebaseController } from './codebase-controller' import { DocController } from './doc-controller' import { FileController } from './file-controller' import { GitController } from './git-controller' +import { MentionController } from './mention-controller' import { SettingsController } from './settings-controller' import { SystemController } from './system-controller' import { TerminalController } from './terminal-controller' @@ -24,6 +25,7 @@ export const controllers = [ ApplyController, SettingsController, AIProviderController, - AIModelController + AIModelController, + MentionController ] as const satisfies (typeof Controller)[] export type Controllers = typeof controllers diff --git a/src/extension/webview-api/controllers/mention-controller.ts b/src/extension/webview-api/controllers/mention-controller.ts new file mode 100644 index 0000000..b45b7c0 --- /dev/null +++ b/src/extension/webview-api/controllers/mention-controller.ts @@ -0,0 +1,112 @@ +import { ControllerRegister } from '@extension/registers/controller-register' +import { ServerPluginRegister } from '@extension/registers/server-plugin-register' +import { tryParseJSON } from '@extension/utils' +import type { Conversation, Mention } from '@shared/entities' +import type { RefreshMentionFn } from '@shared/plugins/base/server/create-provider-manager' +import { settledPromiseResults } from '@shared/utils/common' + +import { Controller } from '../types' + +export class MentionController extends Controller { + readonly name = 'mention' + + private async createCompositeRefreshFunction(): Promise { + // Get mention utils providers + const serverPluginRegister = + this.registerManager.getRegister(ServerPluginRegister) + const mentionUtilsProviders = + serverPluginRegister?.serverPluginRegistry?.providerManagers.mentionUtils.getValues() + + if (!mentionUtilsProviders) { + throw new Error('MentionUtilsProviders not found') + } + + // Get controller register + const controllerRegister = + this.registerManager.getRegister(ControllerRegister) + if (!controllerRegister) { + throw new Error('ControllerRegister not found') + } + + // Create refresh functions from all providers + const refreshFunctions = await settledPromiseResults( + mentionUtilsProviders.map( + async provider => + await provider.createRefreshMentionFn(controllerRegister) + ) + ) + + return (mention: Mention): Mention => + refreshFunctions.reduce( + (updatedMention, refreshFn) => refreshFn(updatedMention), + mention + ) + } + + async refreshConversationMentions( + conversation: Conversation + ): Promise { + // Get and compose refresh functions + const compositeRefreshFn = await this.createCompositeRefreshFunction() + + // Parse and process the lexical editor tree + const editorState = tryParseJSON(conversation.richText || '{}') as { + root: LexicalNode + } + const collectedMentions: Mention[] = [] + + const updatedEditorStateRoot = traverseLexicalMentionNode( + editorState?.root, + mention => { + const updatedMention = compositeRefreshFn(mention) + collectedMentions.push(updatedMention) + return updatedMention + } + ) + const updatedEditorState = { + ...editorState, + root: updatedEditorStateRoot + } + + // Update conversation with processed mentions + return { + ...conversation, + richText: JSON.stringify(updatedEditorState), + mentions: collectedMentions + } + } +} + +interface LexicalNode { + type: string | 'mention' + version: number + children?: LexicalNode[] + mention?: Mention +} + +/** + * Traverses a Lexical editor tree and processes mention nodes + * @param node Current node in the tree + * @param processMention Function to process each mention + * @returns Updated node + */ +const traverseLexicalMentionNode = ( + node: LexicalNode, + processMention: (mention: Mention) => Mention +): LexicalNode => { + // Process current node if it's a mention + if (node.type === 'mention' && node.mention) { + return { + ...node, + mention: processMention(node.mention) + } + } + + // Process children recursively + return { + ...node, + children: node.children?.map(child => + traverseLexicalMentionNode(child, processMention) + ) + } +} diff --git a/src/extension/webview-api/index.ts b/src/extension/webview-api/index.ts index c3d4fc7..8318303 100644 --- a/src/extension/webview-api/index.ts +++ b/src/extension/webview-api/index.ts @@ -1,20 +1,22 @@ import type { CommandManager } from '@extension/commands/command-manager' +import { ControllerRegister } from '@extension/registers/controller-register' import type { RegisterManager } from '@extension/registers/register-manager' import { getErrorMsg, isAbortError } from '@shared/utils/common' import findFreePorts from 'find-free-ports' import { Server } from 'socket.io' import * as vscode from 'vscode' -import { controllers } from './controllers' -import type { - Controller, - ControllerClass, - ControllerMethod, - WebviewPanel -} from './types' +import type { ControllerMethod, WebviewPanel } from './types' class APIManager { - private controllers: Map = new Map() + private get controllers() { + const controller = + this.registerManager.getRegister(ControllerRegister)?.controllers + + if (!controller) throw new Error('ControllerRegister not found') + + return controller + } private io!: Server @@ -28,12 +30,8 @@ class APIManager { private commandManager: CommandManager ) {} - public async initialize( - panel: WebviewPanel, - controllerClasses: ControllerClass[] - ) { + public async initialize(panel: WebviewPanel) { await this.initializeServer() - this.registerControllers(controllerClasses) const listenerDispose = panel.webview.onDidReceiveMessage(e => { if (e.type === 'getVSCodeSocketPort') { @@ -64,16 +62,6 @@ class APIManager { }) } - private registerControllers(controllerClasses: ControllerClass[]) { - for (const ControllerClass of controllerClasses) { - const controller = new ControllerClass( - this.registerManager, - this.commandManager - ) - this.controllers.set(controller.name, controller) - } - } - private async handleMessage(socket: any, message: any) { const { id, controller: controllerName, method, data } = message const controller = this.controllers.get(controllerName) @@ -138,7 +126,7 @@ export const setupWebviewAPIManager = async ( ): Promise => { const apiManager = new APIManager(context, registerManager, commandManager) - await apiManager.initialize(panel, controllers as any as ControllerClass[]) + await apiManager.initialize(panel) return { dispose: () => { diff --git a/src/shared/entities/conversation-entity.ts b/src/shared/entities/conversation-entity.ts index b390dae..dced2bf 100644 --- a/src/shared/entities/conversation-entity.ts +++ b/src/shared/entities/conversation-entity.ts @@ -1,8 +1,10 @@ +import type { FileInfo } from '@extension/file-utils/traverse-fs' import type { AIMessage, ChatMessage, FunctionMessage, HumanMessage, + ImageDetail, MessageType, SystemMessage, ToolMessage @@ -14,13 +16,27 @@ import { v4 as uuidv4 } from 'uuid' import { BaseEntity, type IBaseEntity } from './base-entity' +export interface ImageInfo { + url: string + name?: string +} + +export interface ConversationState { + selectedFilesFromFileSelector: FileInfo[] + currentFilesFromVSCode: FileInfo[] + selectedImagesFromOutsideUrl: ImageInfo[] +} + export interface Conversation extends IBaseEntity { createdAt: number role: MessageType contents: LangchainMessageContents richText?: string // JSON stringified pluginStates: Record + mentions: Mention[] + agents: Agent[] logs: ConversationLog[] + state: ConversationState } export class ConversationEntity extends BaseEntity { @@ -31,22 +47,37 @@ export class ConversationEntity extends BaseEntity { role: 'human', contents: [], pluginStates: {}, + mentions: [], + agents: [], logs: [], + state: { + selectedFilesFromFileSelector: [], + currentFilesFromVSCode: [], + selectedImagesFromOutsideUrl: [] + }, ...data } } } -export type BaseConversationLog = { +export interface Mention { + type: Type + data: Data +} + +export interface Agent { + id: string + name: string + input: Input + output: Output +} + +export type ConversationLog = { id: string - pluginId?: string createdAt: number title: string content?: string -} - -export type ConversationLog = BaseConversationLog & { - [key: string]: any + agentId?: string } export type LangchainMessage = @@ -64,7 +95,10 @@ export type LangchainMessageContents = ( } | { type: 'image_url' - image_url: string + image_url: { + url: string + detail?: ImageDetail + } } )[] diff --git a/src/shared/plugins/agents/agent-names.ts b/src/shared/plugins/agents/agent-names.ts new file mode 100644 index 0000000..d06efd7 --- /dev/null +++ b/src/shared/plugins/agents/agent-names.ts @@ -0,0 +1,5 @@ +export const codebaseSearchAgentName = 'codebaseSearch' +export const docRetrieverAgentName = 'docRetriever' +export const fsVisitAgentName = 'fsVisit' +export const webSearchAgentName = 'webSearch' +export const webVisitAgentName = 'webVisit' diff --git a/src/shared/plugins/agents/codebase-search-agent.ts b/src/shared/plugins/agents/codebase-search-agent.ts new file mode 100644 index 0000000..c9b68a6 --- /dev/null +++ b/src/shared/plugins/agents/codebase-search-agent.ts @@ -0,0 +1,78 @@ +import { CodebaseWatcherRegister } from '@extension/registers/codebase-watcher-register' +import { BaseAgent } from '@extension/webview-api/chat-context-processor/strategies/base/base-agent' +import type { BaseGraphState } from '@extension/webview-api/chat-context-processor/strategies/base/base-state' +import { settledPromiseResults } from '@shared/utils/common' +import { z } from 'zod' + +import { mergeCodeSnippets } from '../fs-plugin/server/merge-code-snippets' +import type { CodeSnippet } from '../fs-plugin/types' +import { codebaseSearchAgentName } from './agent-names' + +export class CodebaseSearchAgent extends BaseAgent< + BaseGraphState, + { enableCodebaseAgent: boolean } +> { + static name = codebaseSearchAgentName + + name = CodebaseSearchAgent.name + + logTitle = 'Search Codebase' + + description = 'Search for relevant code in the codebase.' + + inputSchema = z.object({ + queryParts: z + .array(z.string()) + .describe('List of search terms to find relevant code in the codebase') + }) + + outputSchema = z.object({ + codeSnippets: z.array( + z.object({ + fileHash: z.string(), + relativePath: z.string(), + fullPath: z.string(), + startLine: z.number(), + startCharacter: z.number(), + endLine: z.number(), + endCharacter: z.number(), + code: z.string() + }) satisfies z.ZodType + ) + }) + + async execute(input: z.infer) { + const { enableCodebaseAgent } = this.context.createToolOptions + + if (!enableCodebaseAgent) { + return { codeSnippets: [] } + } + + const indexer = this.context.strategyOptions.registerManager.getRegister( + CodebaseWatcherRegister + )?.indexer + + if (!indexer) { + return { codeSnippets: [] } + } + + const searchResults = await settledPromiseResults( + input.queryParts?.map(query => indexer.searchSimilarRow(query)) || [] + ) + + const searchCodeSnippets = searchResults + .flat() + .slice(0, 8) + .map(row => { + // eslint-disable-next-line unused-imports/no-unused-vars + const { embedding, ...others } = row + return { ...others, code: '' } + }) + + const codeSnippets = await mergeCodeSnippets(searchCodeSnippets, { + mode: 'expanded' + }) + + return { codeSnippets } + } +} diff --git a/src/shared/plugins/agents/doc-retriever-agent.ts b/src/shared/plugins/agents/doc-retriever-agent.ts new file mode 100644 index 0000000..e6e420e --- /dev/null +++ b/src/shared/plugins/agents/doc-retriever-agent.ts @@ -0,0 +1,96 @@ +import { aidePaths } from '@extension/file-utils/paths' +import { BaseAgent } from '@extension/webview-api/chat-context-processor/strategies/base/base-agent' +import type { BaseGraphState } from '@extension/webview-api/chat-context-processor/strategies/base/base-state' +import { DocCrawler } from '@extension/webview-api/chat-context-processor/utils/doc-crawler' +import { DocIndexer } from '@extension/webview-api/chat-context-processor/vectordb/doc-indexer' +import { docSitesDB } from '@extension/webview-api/lowdb/doc-sites-db' +import { removeDuplicates, settledPromiseResults } from '@shared/utils/common' +import { z } from 'zod' + +import type { DocInfo } from '../doc-plugin/types' +import { docRetrieverAgentName } from './agent-names' + +export class DocRetrieverAgent extends BaseAgent< + BaseGraphState, + { allowSearchDocSiteNames: string[] } +> { + static name = docRetrieverAgentName + + name = DocRetrieverAgent.name + + logTitle = 'Search documentation' + + description = + 'Search for relevant information in specified documentation sites.' + + inputSchema = z.object({ + queryParts: z + .array( + z.object({ + siteName: z + .string() + .describe('The name of the documentation site to search'), + keywords: z + .array(z.string()) + .describe( + 'List of keywords to search for in the specified doc site' + ) + }) + ) + .describe( + "The AI should break down the user's query into multiple parts, each targeting a specific doc site with relevant keywords. This allows for a more comprehensive search across multiple documentation sources." + ) + }) + + outputSchema = z.object({ + relevantDocs: z.array( + z.object({ + content: z.string(), + path: z.string() + }) satisfies z.ZodType + ) + }) + + async execute(input: z.infer) { + const { allowSearchDocSiteNames } = this.context.createToolOptions + const docSites = await docSitesDB.getAll() + + const docPromises = input.queryParts.map(async ({ siteName, keywords }) => { + const docSite = docSites.find(site => site.name === siteName) + + if (!docSite?.isIndexed || !allowSearchDocSiteNames.includes(siteName)) { + return [] + } + + const docIndexer = new DocIndexer( + DocCrawler.getDocCrawlerFolderPath(docSite.url), + aidePaths.getGlobalLanceDbPath() + ) + + await docIndexer.initialize() + + const searchResults = await settledPromiseResults( + keywords.map(keyword => docIndexer.searchSimilarRow(keyword)) + ) + + const searchRows = removeDuplicates( + searchResults.flatMap(result => result), + ['fullPath'] + ).slice(0, 3) + + const docInfoResults = await settledPromiseResults( + searchRows.map(async row => ({ + content: await docIndexer.getRowFileContent(row), + path: docSite.url + })) + ) + + return docInfoResults + }) + + const results = await settledPromiseResults(docPromises) + return { + relevantDocs: results.flatMap(result => result) + } + } +} diff --git a/src/shared/plugins/agents/fs-visit-agent.ts b/src/shared/plugins/agents/fs-visit-agent.ts new file mode 100644 index 0000000..b0daeb9 --- /dev/null +++ b/src/shared/plugins/agents/fs-visit-agent.ts @@ -0,0 +1,43 @@ +import { getValidFiles } from '@extension/file-utils/get-valid-files' +import type { FileInfo } from '@extension/file-utils/traverse-fs' +import { BaseAgent } from '@extension/webview-api/chat-context-processor/strategies/base/base-agent' +import type { BaseGraphState } from '@extension/webview-api/chat-context-processor/strategies/base/base-state' +import { z } from 'zod' + +import { fsVisitAgentName } from './agent-names' + +export class FsVisitAgent extends BaseAgent { + static name = fsVisitAgentName + + name = FsVisitAgent.name + + logTitle = 'Visit Files' + + description = 'Access specific files in the workspace.' + + inputSchema = z.object({ + relativePaths: z + .array(z.string()) + .describe( + 'An array of relative file paths to read from the workspace root' + ) + }) + + outputSchema = z.object({ + files: z.array( + z.object({ + type: z.literal('file'), + fullPath: z.string(), + relativePath: z.string(), + content: z.string() + }) satisfies z.ZodType + ) + }) + + async execute(input: z.infer) { + const files = await getValidFiles(input.relativePaths, { + isGetFileContent: false + }) + return { files } + } +} diff --git a/src/shared/plugins/agents/web-search-agent.ts b/src/shared/plugins/agents/web-search-agent.ts new file mode 100644 index 0000000..eca2959 --- /dev/null +++ b/src/shared/plugins/agents/web-search-agent.ts @@ -0,0 +1,154 @@ +import { ModelProviderFactory } from '@extension/ai/model-providers/helpers/factory' +import { logger } from '@extension/logger' +import { BaseAgent } from '@extension/webview-api/chat-context-processor/strategies/base/base-agent' +import type { BaseGraphState } from '@extension/webview-api/chat-context-processor/strategies/base/base-state' +import { ChatMessagesConstructor } from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/chat-messages-constructor' +import { searxngSearch } from '@extension/webview-api/chat-context-processor/utils/searxng-search' +import { CheerioWebBaseLoader } from '@langchain/community/document_loaders/web/cheerio' +import type { Document } from '@langchain/core/documents' +import { HumanMessage } from '@langchain/core/messages' +import { settledPromiseResults } from '@shared/utils/common' +import { z } from 'zod' + +import { webSearchAgentName } from './agent-names' + +const MAX_CONTENT_LENGTH = 16 * 1000 + +export class WebSearchAgent extends BaseAgent< + BaseGraphState, + { enableWebSearchAgent: boolean } +> { + static name = webSearchAgentName + + name = WebSearchAgent.name + + logTitle = 'Search web' + + description = + 'IMPORTANT: Proactively use this web search tool whenever you:\n' + + '1. Need to verify or update your knowledge about recent developments, versions, or current facts\n' + + '2. Are unsure about specific technical details or best practices\n' + + '3. Need real-world examples or implementation details\n' + + '4. Encounter questions about:\n' + + ' - Current events or recent updates\n' + + ' - Latest software versions or features\n' + + ' - Modern best practices or trends\n' + + ' - Specific technical implementations\n' + + '5. Want to provide evidence-based recommendations\n\n' + + 'DO NOT rely solely on your training data when users ask about:\n' + + '- Recent technologies or updates\n' + + '- Current best practices\n' + + '- Specific implementation details\n' + + '- Version-specific features or APIs\n' + + 'Instead, use this tool to get up-to-date information.' + + inputSchema = z.object({ + keywords: z.string().describe('Keywords to search web') + }) + + outputSchema = z.object({ + relevantContent: z.string(), + webSearchResults: z.array( + z.object({ + content: z.string(), + url: z.string() + }) + ) + }) + + async execute(input: z.infer) { + const { enableWebSearchAgent } = this.context.createToolOptions + + if (!enableWebSearchAgent) { + return { relevantContent: '', webSearchResults: [] } + } + + const searxngSearchResult = await searxngSearch(input.keywords, { + abortController: this.context.state.abortController + }) + const urls = searxngSearchResult.results.map(result => result.url) + + const docsLoadResult = await settledPromiseResults( + urls.map(url => new CheerioWebBaseLoader(url).load()) + ) + + const docs: Document>[] = docsLoadResult.flat() + + const docsContent = docs + .map(doc => doc.pageContent) + .join('\n') + .slice(0, MAX_CONTENT_LENGTH) + + if (!docsContent) { + logger.warn('No content found in web search results', { + keywords: input.keywords, + docs + }) + return { relevantContent: '', webSearchResults: [] } + } + + const chatMessagesConstructor = new ChatMessagesConstructor({ + ...this.context.strategyOptions, + chatContext: this.context.state.chatContext + }) + const messagesFromChatContext = + await chatMessagesConstructor.constructMessages() + + const modelProvider = + await ModelProviderFactory.getModelProviderForChatContext( + this.context.state.chatContext + ) + const aiModel = await modelProvider.createLangChainModel() + + const response = await aiModel + .bind({ signal: this.context.state.abortController?.signal }) + .invoke([ + ...messagesFromChatContext.slice(-2), + new HumanMessage({ + content: ` +You are an expert information analyst. Your task is to process web search results and create a high-quality, focused summary that will be used in a subsequent AI conversation. Follow these critical guidelines: + +1. RELEVANCE & FOCUS +- Identify and extract ONLY information that directly addresses the user's query +- Eliminate tangential or loosely related content +- Preserve technical details and specific examples when relevant + +2. INFORMATION QUALITY +- Prioritize factual, verifiable information +- Include specific technical details, numbers, or metrics when present +- Maintain technical accuracy in specialized topics + +3. STRUCTURE & CLARITY +- Present information in a logical, well-structured format +- Use clear, precise language +- Preserve important technical terms and concepts + +4. BALANCED PERSPECTIVE +- Include multiple viewpoints when present +- Note any significant disagreements or contradictions +- Indicate if information seems incomplete or uncertain + +5. CONTEXT PRESERVATION +- Maintain crucial context that affects meaning +- Include relevant dates or version information for technical content +- Preserve attribution for significant claims or findings + +Here's the content to analyze: + +""" +${docsContent} +""" + +Provide a focused, technical summary that will serve as high-quality context for the next phase of AI conversation.` + }) + ]) + + return { + relevantContent: + typeof response.content === 'string' + ? response.content + : JSON.stringify(response.content), + webSearchResults: searxngSearchResult.results + } + } +} diff --git a/src/shared/plugins/agents/web-visit-agent.ts b/src/shared/plugins/agents/web-visit-agent.ts new file mode 100644 index 0000000..096a868 --- /dev/null +++ b/src/shared/plugins/agents/web-visit-agent.ts @@ -0,0 +1,62 @@ +import { BaseAgent } from '@extension/webview-api/chat-context-processor/strategies/base/base-agent' +import type { BaseGraphState } from '@extension/webview-api/chat-context-processor/strategies/base/base-state' +import { DocCrawler } from '@extension/webview-api/chat-context-processor/utils/doc-crawler' +import { settledPromiseResults } from '@shared/utils/common' +import { z } from 'zod' + +import { webVisitAgentName } from './agent-names' + +export class WebVisitAgent extends BaseAgent< + BaseGraphState, + { enableWebVisitAgent: boolean } +> { + static name = webVisitAgentName + + name = WebVisitAgent.name + + logTitle = 'Visit web' + + description = + 'A tool for visiting and extracting content from web pages. Use this tool when you need to:\n' + + '1. Analyze specific webpage content in detail\n' + + '2. Extract information from known URLs\n' + + '3. Compare content across multiple web pages\n' + + '4. Verify or fact-check information from web sources\n' + + 'Note: Only use this for specific URLs you want to analyze, not for general web searches.' + + inputSchema = z.object({ + urls: z + .array(z.string().url()) + .describe( + 'An array of URLs to visit and retrieve content from. Each URL should be a valid web address.' + ) + }) + + outputSchema = z.object({ + contents: z.array( + z.object({ + content: z.string(), + url: z.string() + }) + ) + }) + + async execute(input: z.infer) { + const { enableWebVisitAgent } = this.context.createToolOptions + + if (!enableWebVisitAgent) { + return { contents: [] } + } + + const docCrawler = new DocCrawler(input.urls[0]!) + const contents = await settledPromiseResults( + input.urls.map(async url => ({ + url, + content: + (await docCrawler.getPageContent(url)) || 'Failed to retrieve content' + })) + ) + + return { contents } + } +} diff --git a/src/shared/plugins/base/base-to-state.ts b/src/shared/plugins/base/base-to-state.ts new file mode 100644 index 0000000..2d15f1f --- /dev/null +++ b/src/shared/plugins/base/base-to-state.ts @@ -0,0 +1,120 @@ +import type { + BaseAgent, + GetAgentOutput +} from '@extension/webview-api/chat-context-processor/strategies/base/base-agent' +import type { + Agent, + Conversation, + ConversationLog, + Mention +} from '@shared/entities' + +export type LogWithAgent = ConversationLog & { + agent?: A +} + +export abstract class BaseToState { + mentions?: M[] + + agents?: Agent[] + + conversation?: Conversation + + constructor(conversation?: Conversation) { + this.mentions = (conversation?.mentions || []) as M[] + this.agents = (conversation?.agents || []) as Agent[] + this.conversation = conversation + } + + abstract toMentionsState(): unknown + + abstract toAgentsState(): unknown + + toLogWithAgent(): LogWithAgent[] { + return toLogWithAgent(this.conversation) + } + + getMentionDataByType( + type: T + ): Extract['data'][] { + if (!this.mentions?.length) return [] + const data: Mention['data'][] = [] + + this.mentions?.forEach(mention => { + if (mention.type === type && mention.data) { + data.push(mention.data) + } + }) + + return data + } + + isMentionExit(type: T): boolean { + if (!this.mentions?.length) return false + return this.mentions?.some(mention => mention.type === type) || false + } + + getAgentOutputs( + agentName: T['name'] + ): GetAgentOutput[] { + const outputs: GetAgentOutput[] = [] + + this.agents?.forEach(agent => { + if (agent.name === agentName && agent.output) { + outputs.push(agent.output) + } + }) + + return outputs + } + + getAgentOutputsByKey>( + agentName: T['name'], + key: K + ): GetAgentOutput[K][] { + const outputs: GetAgentOutput[K][] = [] + + this.agents?.forEach(agent => { + if ( + agent.name === agentName && + agent.output && + key in agent.output && + agent.output[key] + ) { + outputs.push(agent.output[key]) + } + }) + + return outputs + } +} + +export const toLogWithAgent = ( + conversation: Conversation | undefined +): LogWithAgent[] => { + if (!conversation) return [] + + const idAgentMap = new Map() + conversation.agents?.forEach(agent => { + idAgentMap.set(agent.id, agent) + }) + + return ( + conversation.logs?.map(log => { + if (!log.agentId) return log + + return { + ...log, + agent: idAgentMap.get(log.agentId) + } + }) || [] + ) +} + +export type GetMentionState> = ReturnType< + T['toMentionsState'] +> + +export type GetAgentState> = ReturnType< + T['toAgentsState'] +> diff --git a/src/shared/plugins/base/client/client-plugin-types.ts b/src/shared/plugins/base/client/client-plugin-types.ts index 9d3f634..5b6b1c0 100644 --- a/src/shared/plugins/base/client/client-plugin-types.ts +++ b/src/shared/plugins/base/client/client-plugin-types.ts @@ -1,30 +1,15 @@ import type { FC } from 'react' -import type { BaseConversationLog, ConversationLog } from '@shared/entities' -import type { ImageInfo } from '@shared/plugins/fs-plugin/types' -import type { FileInfo, MentionOption } from '@webview/types/chat' +import type { MentionOption } from '@webview/types/chat' -export type UseMentionOptionsReturns = MentionOption[] - -export type UseSelectedFilesReturns = { - selectedFiles: FileInfo[] - setSelectedFiles: (files: FileInfo[]) => void -} +import type { LogWithAgent } from '../base-to-state' -export type UseSelectedImagesReturns = { - selectedImages: ImageInfo[] - addSelectedImage: (image: ImageInfo) => void - removeSelectedImage: (image: ImageInfo) => void -} +export type UseMentionOptionsReturns = MentionOption[] -export type CustomRenderLogPreviewProps< - T extends BaseConversationLog = ConversationLog -> = { - log: T +export type CustomRenderLogPreviewProps = { + log: LogWithAgent } export type ClientPluginProviderMap = { useMentionOptions: () => UseMentionOptionsReturns - useSelectedFiles: () => UseSelectedFilesReturns - useSelectedImages: () => UseSelectedImagesReturns CustomRenderLogPreview: FC } diff --git a/src/shared/plugins/base/server/create-provider-manager.ts b/src/shared/plugins/base/server/create-provider-manager.ts index 554c647..2bc733a 100644 --- a/src/shared/plugins/base/server/create-provider-manager.ts +++ b/src/shared/plugins/base/server/create-provider-manager.ts @@ -1,10 +1,11 @@ -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' +import type { ControllerRegister } from '@extension/registers/controller-register' +import type { StructuredTool } from '@langchain/core/tools' +import type { ChatContext, Conversation, Mention } from '@shared/entities' import type { + BaseStrategyOptions, ChatGraphNode, ChatGraphState -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' -import type { StructuredTool } from '@langchain/core/tools' -import type { ChatContext, Conversation } from '@shared/entities' +} from '@shared/plugins/base/strategies' import { ProviderManager } from '../provider-manager' @@ -27,15 +28,23 @@ export interface ChatStrategyProvider { chatContext: ChatContext ) => Promise buildAgentTools?: ( - options: BaseStrategyOptions, + strategyOptions: BaseStrategyOptions, graphState: ChatGraphState ) => Promise buildLanggraphToolNodes?: ( - options: BaseStrategyOptions + strategyOptions: BaseStrategyOptions ) => Promise } +export type RefreshMentionFn = (mention: Mention) => Mention +export interface MentionUtilsProvider { + createRefreshMentionFn: ( + controllerRegister: ControllerRegister + ) => Promise +} + export const createProviderManagers = () => ({ - chatStrategy: new ProviderManager() + chatStrategy: new ProviderManager(), + mentionUtils: new ProviderManager() }) as const satisfies Record> diff --git a/src/shared/plugins/base/strategies.ts b/src/shared/plugins/base/strategies.ts new file mode 100644 index 0000000..7c0fb4d --- /dev/null +++ b/src/shared/plugins/base/strategies.ts @@ -0,0 +1,2 @@ +export * from '@extension/webview-api/chat-context-processor/strategies/base' +export * from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/state' diff --git a/src/shared/plugins/doc-plugin/client/doc-client-plugin.tsx b/src/shared/plugins/doc-plugin/client/doc-client-plugin.tsx index 469e212..e489d30 100644 --- a/src/shared/plugins/doc-plugin/client/doc-client-plugin.tsx +++ b/src/shared/plugins/doc-plugin/client/doc-client-plugin.tsx @@ -11,7 +11,7 @@ import { api } from '@webview/services/api-client' import { type MentionOption } from '@webview/types/chat' import { useNavigate } from 'react-router' -import type { DocPluginState } from '../types' +import { DocMentionType, type DocPluginState } from '../types' import { DocLogPreview } from './doc-log-preview' export const DocClientPlugin = createClientPlugin({ @@ -19,10 +19,7 @@ export const DocClientPlugin = createClientPlugin({ version: pkg.version, getInitialState() { - return { - allowSearchDocSiteNamesFromEditor: [], - relevantDocsFromAgent: [] - } + return {} }, setup(props) { @@ -35,7 +32,6 @@ export const DocClientPlugin = createClientPlugin({ const createUseMentionOptions = (props: SetupProps) => (): UseMentionOptionsReturns => { - const { setState } = props const navigate = useNavigate() const { data: docSites = [] } = useQuery({ @@ -43,9 +39,9 @@ const createUseMentionOptions = queryFn: () => api.doc.getDocSites({}) }) - const docSiteNamesSettingMentionOption: MentionOption = { - id: `${PluginId.Doc}#doc#setting`, - type: `${PluginId.Doc}#doc`, + const docSiteNamesSettingMentionOption: MentionOption = { + id: DocMentionType.DocSetting, + type: DocMentionType.DocSetting, label: 'docs setting', disableAddToEditor: true, onSelect: () => { @@ -59,18 +55,13 @@ const createUseMentionOptions = } } - const docSiteNamesMentionOptions: MentionOption[] = docSites.map( + const docSiteNamesMentionOptions: MentionOption[] = docSites.map( site => ({ - id: `${PluginId.Doc}#doc#${site.id}`, - type: `${PluginId.Doc}#doc`, + id: `${DocMentionType.Doc}#${site.id}`, + type: DocMentionType.Doc, label: site.name, data: site.name, - onUpdatePluginState: dataArr => { - setState(draft => { - draft.allowSearchDocSiteNamesFromEditor = dataArr - }) - }, searchKeywords: [site.name, site.url], itemLayoutProps: { icon: , @@ -82,8 +73,8 @@ const createUseMentionOptions = return [ { - id: `${PluginId.Doc}#docs`, - type: `${PluginId.Doc}#docs`, + id: DocMentionType.Docs, + type: DocMentionType.Docs, label: 'Docs', topLevelSort: 4, searchKeywords: ['docs'], diff --git a/src/shared/plugins/doc-plugin/client/doc-log-preview.tsx b/src/shared/plugins/doc-plugin/client/doc-log-preview.tsx index a7945a4..a9328bd 100644 --- a/src/shared/plugins/doc-plugin/client/doc-log-preview.tsx +++ b/src/shared/plugins/doc-plugin/client/doc-log-preview.tsx @@ -1,32 +1,45 @@ -import { FC } from 'react' +import { FC, type ReactNode } from 'react' import { FileTextIcon } from '@radix-ui/react-icons' +import { docRetrieverAgentName } from '@shared/plugins/agents/agent-names' +import type { DocRetrieverAgent } from '@shared/plugins/agents/doc-retriever-agent' import type { CustomRenderLogPreviewProps } from '@shared/plugins/base/client/client-plugin-types' -import { PluginId } from '@shared/plugins/base/types' +import type { GetAgent } from '@shared/plugins/base/strategies' import { ChatLogPreview } from '@webview/components/chat/messages/roles/chat-log-preview' import type { PreviewContent } from '@webview/components/content-preview' import { ContentPreviewPopover } from '@webview/components/content-preview-popover' import { api } from '@webview/services/api-client' import { cn } from '@webview/utils/common' -import type { DocInfo, DocPluginLog } from '../types' +import type { DocInfo } from '../types' export const DocLogPreview: FC = props => { - if (props.log.pluginId !== PluginId.Doc) return null - const log = props.log as DocPluginLog + const { log } = props + const { agent } = log - return ( + const renderWrapper = (children: ReactNode) => ( -
- {log.relevantDocsFromAgent?.map((doc, index) => ( - - ))} -
+
{children}
) + + if (!agent) return null + + switch (agent.name) { + case docRetrieverAgentName: + return renderWrapper( + (agent as GetAgent).output.relevantDocs?.map( + (doc, index) => ( + + ) + ) + ) + default: + return null + } } interface DocItemProps { diff --git a/src/shared/plugins/doc-plugin/doc-to-state.ts b/src/shared/plugins/doc-plugin/doc-to-state.ts new file mode 100644 index 0000000..0e58dd6 --- /dev/null +++ b/src/shared/plugins/doc-plugin/doc-to-state.ts @@ -0,0 +1,20 @@ +import { DocRetrieverAgent } from '../agents/doc-retriever-agent' +import { BaseToState } from '../base/base-to-state' +import { DocMentionType, type DocMention } from './types' + +export class DocToState extends BaseToState { + toMentionsState() { + return { + allowSearchDocSiteNames: this.getMentionDataByType(DocMentionType.Doc) + } + } + + toAgentsState() { + return { + relevantDocs: this.getAgentOutputsByKey< + DocRetrieverAgent, + 'relevantDocs' + >(DocRetrieverAgent.name, 'relevantDocs').flat() + } + } +} diff --git a/src/shared/plugins/doc-plugin/server/chat-strategy/doc-chat-strategy-provider.ts b/src/shared/plugins/doc-plugin/server/chat-strategy/doc-chat-strategy-provider.ts index 11843be..f6904c5 100644 --- a/src/shared/plugins/doc-plugin/server/chat-strategy/doc-chat-strategy-provider.ts +++ b/src/shared/plugins/doc-plugin/server/chat-strategy/doc-chat-strategy-provider.ts @@ -1,56 +1,77 @@ -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' -import type { - ChatGraphNode, - ChatGraphState -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' import type { StructuredTool } from '@langchain/core/tools' import type { Conversation } from '@shared/entities' +import type { + GetAgentState, + GetMentionState +} from '@shared/plugins/base/base-to-state' import type { ChatStrategyProvider } from '@shared/plugins/base/server/create-provider-manager' -import { PluginId } from '@shared/plugins/base/types' - -import type { DocPluginState } from '../../types' import { - createDocRetrieverNode, - createDocRetrieverTool -} from './doc-retriever-node' + createGraphNodeFromNodes, + createToolsFromNodes +} from '@shared/plugins/base/strategies' +import type { + BaseStrategyOptions, + ChatGraphNode, + ChatGraphState +} from '@shared/plugins/base/strategies' -export class DocChatStrategyProvider implements ChatStrategyProvider { - async buildContextMessagePrompt(conversation: Conversation): Promise { - const state = conversation.pluginStates?.[PluginId.Doc] as - | Partial - | undefined +import { DocToState } from '../../doc-to-state' +import { DocRetrieverNode } from './doc-retriever-node' - if (!state) return '' +interface ConversationWithStateProps { + conversation: Conversation + mentionState: GetMentionState + agentState: GetAgentState +} - const relevantDocsPrompt = this.buildRelevantDocsPrompt(state) +export class DocChatStrategyProvider implements ChatStrategyProvider { + private createConversationWithStateProps( + conversation: Conversation + ): ConversationWithStateProps { + const docToState = new DocToState(conversation) + const mentionState = docToState.toMentionsState() + const agentState = docToState.toAgentsState() + + return { conversation, mentionState, agentState } + } + async buildContextMessagePrompt(conversation: Conversation): Promise { + const props = this.createConversationWithStateProps(conversation) + const relevantDocsPrompt = this.buildRelevantDocsPrompt(props) const prompts = [relevantDocsPrompt].filter(Boolean) return prompts.join('\n\n') } async buildAgentTools( - options: BaseStrategyOptions, + strategyOptions: BaseStrategyOptions, state: ChatGraphState ): Promise { - const tools = await Promise.all([createDocRetrieverTool(options, state)]) - return tools.filter(Boolean) as StructuredTool[] + return await createToolsFromNodes({ + nodeClasses: [DocRetrieverNode], + strategyOptions, + state + }) } async buildLanggraphToolNodes( - options: BaseStrategyOptions + strategyOptions: BaseStrategyOptions ): Promise { - return [createDocRetrieverNode(options)] + return await createGraphNodeFromNodes({ + nodeClasses: [DocRetrieverNode], + strategyOptions + }) } - private buildRelevantDocsPrompt(state: Partial): string { - const { relevantDocsFromAgent: relevantDocsFromDocAgent } = state + private buildRelevantDocsPrompt(props: ConversationWithStateProps): string { + const { agentState } = props + const { relevantDocs } = agentState - if (!relevantDocsFromDocAgent?.length) return '' + if (!relevantDocs?.length) return '' let docsContent = '' - relevantDocsFromDocAgent.forEach(doc => { + relevantDocs.forEach(doc => { docsContent += ` Source Path: ${doc.path} Content: ${doc.content} diff --git a/src/shared/plugins/doc-plugin/server/chat-strategy/doc-retriever-node.ts b/src/shared/plugins/doc-plugin/server/chat-strategy/doc-retriever-node.ts index f6c62af..5705b91 100644 --- a/src/shared/plugins/doc-plugin/server/chat-strategy/doc-retriever-node.ts +++ b/src/shared/plugins/doc-plugin/server/chat-strategy/doc-retriever-node.ts @@ -1,180 +1,54 @@ -import { aidePaths } from '@extension/file-utils/paths' -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' +import { DocRetrieverAgent } from '@shared/plugins/agents/doc-retriever-agent' import { - dispatchChatGraphState, - type ChatGraphState, - type CreateChatGraphNode -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' -import { DocCrawler } from '@extension/webview-api/chat-context-processor/utils/doc-crawler' -import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params' -import { DocIndexer } from '@extension/webview-api/chat-context-processor/vectordb/doc-indexer' -import { docSitesDB } from '@extension/webview-api/lowdb/doc-sites-db' -import type { ToolMessage } from '@langchain/core/messages' -import { DynamicStructuredTool } from '@langchain/core/tools' -import type { ConversationLog, LangchainTool } from '@shared/entities' -import { PluginId } from '@shared/plugins/base/types' -import { removeDuplicates, settledPromiseResults } from '@shared/utils/common' -import { produce } from 'immer' -import { v4 as uuidv4 } from 'uuid' -import { z } from 'zod' - -import type { DocInfo, DocPluginLog, DocPluginState } from '../../types' - -interface DocRetrieverToolResult { - relevantDocs: DocInfo[] -} - -export const createDocRetrieverTool = async ( - options: BaseStrategyOptions, - state: ChatGraphState -) => { - const lastConversation = state.chatContext.conversations.at(-1) - - const docPluginState = lastConversation?.pluginStates?.[PluginId.Doc] as - | Partial - | undefined - - if (!docPluginState?.allowSearchDocSiteNamesFromEditor?.length) return null - - const siteNames = removeDuplicates( - docPluginState.allowSearchDocSiteNamesFromEditor - ) - - const getRelevantDocs = async ( - queryParts: { siteName: string; keywords: string[] }[] - ): Promise => { - const docSites = await docSitesDB.getAll() - - const docPromises = queryParts.map(async ({ siteName, keywords }) => { - const docSite = docSites.find(site => site.name === siteName) - - if (!docSite?.isIndexed || !siteNames.includes(siteName)) { - return [] - } - - const docIndexer = new DocIndexer( - DocCrawler.getDocCrawlerFolderPath(docSite.url), - aidePaths.getGlobalLanceDbPath() - ) - - await docIndexer.initialize() - - const searchResults = await settledPromiseResults( - keywords.map(keyword => docIndexer.searchSimilarRow(keyword)) - ) - - const searchRows = removeDuplicates( - searchResults.flatMap(result => result), - ['fullPath'] - ).slice(0, 3) - - const docInfoResults = await settledPromiseResults( - searchRows.map(async row => ({ - content: await docIndexer.getRowFileContent(row), - path: docSite.url - })) - ) - - return docInfoResults - }) - - const results = await settledPromiseResults(docPromises) - return results.flatMap(result => result) - } - - return new DynamicStructuredTool({ - name: 'docRetriever', - description: - 'Search for relevant information in specified documentation sites. This tool can search across multiple doc sites, with multiple keywords for each site. Use this tool to find documentation on specific topics or understand how certain features are described in the documentation.', - func: async ({ queryParts }): Promise => ({ - relevantDocs: await getRelevantDocs(queryParts) - }), - schema: z.object({ - queryParts: z - .array( - z.object({ - siteName: z - .enum(siteNames as unknown as [string, ...string[]]) - .describe('The name of the documentation site to search'), - keywords: z - .array(z.string()) - .describe( - 'List of keywords to search for in the specified doc site' - ) - }) - ) - .describe( - "The AI should break down the user's query into multiple parts, each targeting a specific doc site with relevant keywords. This allows for a more comprehensive search across multiple documentation sources." - ) - }) - }) -} - -export const createDocRetrieverNode: CreateChatGraphNode = - options => async state => { - const lastConversation = state.chatContext.conversations.at(-1) - const logs: ConversationLog[] = [] - - const docPluginState = lastConversation?.pluginStates?.[PluginId.Doc] as - | Partial - | undefined - - if (!docPluginState?.allowSearchDocSiteNamesFromEditor?.length) return {} - - const docRetrieverTool = await createDocRetrieverTool(options, state) - - if (!docRetrieverTool) return {} - - const tools: LangchainTool[] = [docRetrieverTool] - const lastMessage = state.messages.at(-1) - const toolCalls = findCurrentToolsCallParams(lastMessage, tools) - - if (!toolCalls.length) return {} - - const toolCallsPromises = toolCalls.map(async toolCall => { - const toolMessage = (await docRetrieverTool.invoke( - toolCall - )) as ToolMessage - - const result = JSON.parse( - toolMessage?.lc_kwargs.content - ) as DocRetrieverToolResult - - lastConversation!.pluginStates![PluginId.Doc] = produce( - lastConversation!.pluginStates![ - PluginId.Doc - ] as Partial, - (draft: Partial) => { - if (!draft.relevantDocsFromAgent) { - draft.relevantDocsFromAgent = [] + BaseNode, + ChatGraphState, + dispatchBaseGraphState, + type BaseStrategyOptions +} from '@shared/plugins/base/strategies' + +import { DocToState } from '../../doc-to-state' + +export class DocRetrieverNode extends BaseNode< + ChatGraphState, + BaseStrategyOptions +> { + onInit() { + this.registerAgentConfig(DocRetrieverAgent.name, state => { + const lastConversation = state.chatContext.conversations.at(-1) + const mentionState = new DocToState(lastConversation).toMentionsState() + + return this.createAgentConfig({ + agentClass: DocRetrieverAgent, + agentContext: { + state, + strategyOptions: this.context.strategyOptions, + createToolOptions: { + allowSearchDocSiteNames: mentionState.allowSearchDocSiteNames } - - draft.relevantDocsFromAgent.push(...result.relevantDocs) } - ) + }) + }) + } - logs.push({ - id: uuidv4(), - createdAt: Date.now(), - pluginId: PluginId.Doc, - title: 'Search documentation', - relevantDocsFromAgent: result.relevantDocs - } satisfies DocPluginLog) + async execute(state: ChatGraphState) { + const toolCallsResults = await this.executeAgentTool(state, { + agentClass: DocRetrieverAgent }) - await settledPromiseResults(toolCallsPromises) + if (!toolCallsResults.agents.length) return {} - const newConversations = produce(state.newConversations, draft => { - draft.at(-1)!.logs.push(...logs) - }) + const newConversation = state.newConversations.at(-1)! + this.addAgentsToConversation(newConversation, toolCallsResults.agents) + this.addLogsToConversation(newConversation, toolCallsResults.logs) - dispatchChatGraphState({ - newConversations, - chatContext: state.chatContext + dispatchBaseGraphState({ + chatContext: state.chatContext, + newConversations: state.newConversations }) return { chatContext: state.chatContext, - newConversations + newConversations: state.newConversations } } +} diff --git a/src/shared/plugins/doc-plugin/server/doc-mention-utils-provider.ts b/src/shared/plugins/doc-plugin/server/doc-mention-utils-provider.ts new file mode 100644 index 0000000..5eb57ee --- /dev/null +++ b/src/shared/plugins/doc-plugin/server/doc-mention-utils-provider.ts @@ -0,0 +1,31 @@ +import type { ControllerRegister } from '@extension/registers/controller-register' +import type { Mention } from '@shared/entities' +import type { MentionUtilsProvider } from '@shared/plugins/base/server/create-provider-manager' + +import { DocMentionType } from '../types' + +export class DocMentionUtilsProvider implements MentionUtilsProvider { + async createRefreshMentionFn(controllerRegister: ControllerRegister) { + const docSites = await controllerRegister.api('doc').getDocSites() + + // Create a map of doc site names for quick lookup + const docSiteMap = new Map() + for (const site of docSites) { + docSiteMap.set(site.name, site.name) + } + + return (_mention: Mention) => { + const mention = { ..._mention } as Mention + switch (mention.type) { + case DocMentionType.Doc: + const siteName = docSiteMap.get(mention.data) + if (siteName) mention.data = siteName + break + default: + break + } + + return mention + } + } +} diff --git a/src/shared/plugins/doc-plugin/server/doc-server-plugin.ts b/src/shared/plugins/doc-plugin/server/doc-server-plugin.ts index 7eac628..8ebd41a 100644 --- a/src/shared/plugins/doc-plugin/server/doc-server-plugin.ts +++ b/src/shared/plugins/doc-plugin/server/doc-server-plugin.ts @@ -7,6 +7,7 @@ import { pkg } from '@shared/utils/pkg' import type { DocPluginState } from '../types' import { DocChatStrategyProvider } from './chat-strategy/doc-chat-strategy-provider' +import { DocMentionUtilsProvider } from './doc-mention-utils-provider' export class DocServerPlugin implements ServerPlugin { id = PluginId.Doc @@ -22,6 +23,11 @@ export class DocServerPlugin implements ServerPlugin { 'chatStrategy', () => new DocChatStrategyProvider() ) + + this.context.registerProvider( + 'mentionUtils', + () => new DocMentionUtilsProvider() + ) } deactivate(): void { diff --git a/src/shared/plugins/doc-plugin/types.ts b/src/shared/plugins/doc-plugin/types.ts index 4d21dc5..b5f0adf 100644 --- a/src/shared/plugins/doc-plugin/types.ts +++ b/src/shared/plugins/doc-plugin/types.ts @@ -1,18 +1,18 @@ -import type { BaseConversationLog } from '@shared/entities' +import type { Mention } from '@shared/entities' -import type { PluginId } from '../base/types' +import { PluginId } from '../base/types' + +export enum DocMentionType { + Docs = `${PluginId.Doc}#docs`, + Doc = `${PluginId.Doc}#doc`, + DocSetting = `${PluginId.Doc}#doc-setting` +} + +export type DocMention = Mention export interface DocInfo { content: string path: string // file path or url } -export interface DocPluginState { - allowSearchDocSiteNamesFromEditor: string[] - relevantDocsFromAgent: DocInfo[] -} - -export interface DocPluginLog extends BaseConversationLog { - pluginId: PluginId.Doc - relevantDocsFromAgent?: DocInfo[] -} +export interface DocPluginState {} diff --git a/src/shared/plugins/fs-plugin/client/fs-client-plugin.tsx b/src/shared/plugins/fs-plugin/client/fs-client-plugin.tsx index 1fe85e0..1d9ac26 100644 --- a/src/shared/plugins/fs-plugin/client/fs-client-plugin.tsx +++ b/src/shared/plugins/fs-plugin/client/fs-client-plugin.tsx @@ -4,11 +4,7 @@ import { CubeIcon, ExclamationTriangleIcon } from '@radix-ui/react-icons' -import type { - UseMentionOptionsReturns, - UseSelectedFilesReturns, - UseSelectedImagesReturns -} from '@shared/plugins/base/client/client-plugin-types' +import type { UseMentionOptionsReturns } from '@shared/plugins/base/client/client-plugin-types' import { createClientPlugin, type SetupProps @@ -22,7 +18,7 @@ import { SearchSortStrategy, type MentionOption } from '@webview/types/chat' import { getFileNameFromPath } from '@webview/utils/path' import { ChevronRightIcon, FileIcon, FolderTreeIcon } from 'lucide-react' -import type { FsPluginState, TreeInfo } from '../types' +import { FsMentionType, type FsPluginState, type TreeInfo } from '../types' import { FsLogPreview } from './fs-log-preview' import { MentionFilePreview } from './mention-file-preview' import { MentionFolderPreview } from './mention-folder-preview' @@ -33,59 +29,19 @@ export const FsClientPlugin = createClientPlugin({ version: pkg.version, getInitialState() { - return { - selectedFilesFromFileSelector: [], - selectedFilesFromEditor: [], - selectedFilesFromAgent: [], - currentFilesFromVSCode: [], - selectedFoldersFromEditor: [], - selectedImagesFromOutsideUrl: [], - codeChunksFromEditor: [], - codeSnippetFromAgent: [], - enableCodebaseAgent: false, - editorErrors: [], - selectedTreesFromEditor: [] - } + return {} }, setup(props) { const { registerProvider } = props registerProvider('useMentionOptions', () => createUseMentionOptions(props)) - registerProvider('useSelectedFiles', () => createUseSelectedFiles(props)) - registerProvider('useSelectedImages', () => createUseSelectedImages(props)) registerProvider('CustomRenderLogPreview', () => FsLogPreview) } }) -const createUseSelectedFiles = - (props: SetupProps) => (): UseSelectedFilesReturns => ({ - selectedFiles: props.state.selectedFilesFromFileSelector || [], - setSelectedFiles: files => - props.setState(draft => { - draft.selectedFilesFromFileSelector = files - }) - }) - -const createUseSelectedImages = - (props: SetupProps) => (): UseSelectedImagesReturns => ({ - selectedImages: props.state.selectedImagesFromOutsideUrl || [], - addSelectedImage: image => { - props.setState(draft => { - draft.selectedImagesFromOutsideUrl.push(image) - }) - }, - removeSelectedImage: image => { - props.setState(draft => { - draft.selectedImagesFromOutsideUrl = - draft.selectedImagesFromOutsideUrl.filter(i => i.url !== image.url) - }) - } - }) - const createUseMentionOptions = (props: SetupProps) => (): UseMentionOptionsReturns => { - const { setState } = props const { data: files = [] } = useQuery({ queryKey: ['realtime', 'files'], queryFn: () => api.file.traverseWorkspaceFiles({ filesOrFolders: ['./'] }) @@ -110,16 +66,10 @@ const createUseMentionOptions = const label = getFileNameFromPath(file.fullPath) return { - id: `${PluginId.Fs}#file#${file.fullPath}`, - type: `${PluginId.Fs}#file`, + id: `${FsMentionType.File}#${file.fullPath}`, + type: FsMentionType.File, label, data: file, - onUpdatePluginState: dataArr => { - setState(draft => { - draft.selectedFilesFromEditor = dataArr - }) - }, - searchKeywords: [file.relativePath, label], searchSortStrategy: SearchSortStrategy.EndMatch, itemLayoutProps: { @@ -137,16 +87,10 @@ const createUseMentionOptions = const label = getFileNameFromPath(folder.fullPath) return { - id: `${PluginId.Fs}#folder#${folder.fullPath}`, - type: `${PluginId.Fs}#folder`, + id: `${FsMentionType.Folder}#${folder.fullPath}`, + type: FsMentionType.Folder, label, data: folder, - onUpdatePluginState: dataArr => { - setState(draft => { - draft.selectedFoldersFromEditor = dataArr - }) - }, - searchKeywords: [folder.relativePath, label], searchSortStrategy: SearchSortStrategy.EndMatch, itemLayoutProps: { @@ -172,15 +116,10 @@ const createUseMentionOptions = const label = getFileNameFromPath(treeInfo.fullPath) return { - id: `${PluginId.Fs}#tree#${treeInfo.fullPath}`, - type: `${PluginId.Fs}#tree`, + id: `${FsMentionType.Tree}#${treeInfo.fullPath}`, + type: FsMentionType.Tree, label, data: treeInfo, - onUpdatePluginState: dataArr => { - setState(draft => { - draft.selectedTreesFromEditor = dataArr - }) - }, searchKeywords: [treeInfo.relativePath, label], searchSortStrategy: SearchSortStrategy.EndMatch, itemLayoutProps: { @@ -194,8 +133,8 @@ const createUseMentionOptions = return [ { - id: `${PluginId.Fs}#files`, - type: `${PluginId.Fs}#files`, + id: FsMentionType.Files, + type: FsMentionType.Files, label: 'Files', topLevelSort: 0, searchKeywords: ['files'], @@ -206,8 +145,8 @@ const createUseMentionOptions = } }, { - id: `${PluginId.Fs}#folders`, - type: `${PluginId.Fs}#folders`, + id: FsMentionType.Folders, + type: FsMentionType.Folders, label: 'Folders', topLevelSort: 1, searchKeywords: ['folders'], @@ -218,8 +157,8 @@ const createUseMentionOptions = } }, { - id: `${PluginId.Fs}#tree`, - type: `${PluginId.Fs}#tree`, + id: FsMentionType.Trees, + type: FsMentionType.Trees, label: 'Tree', topLevelSort: 2, searchKeywords: ['tree', 'structure'], @@ -230,8 +169,8 @@ const createUseMentionOptions = } }, // { - // id: `${PluginId.Fs}#code`, - // type: `${PluginId.Fs}#code`, + // id: FsMentionType.Code, + // type: FsMentionType.Code, // label: 'Code', // topLevelSort: 2, // searchKeywords: ['code'], @@ -241,16 +180,10 @@ const createUseMentionOptions = // } // }, { - id: `${PluginId.Fs}#codebase`, - type: `${PluginId.Fs}#codebase`, + id: FsMentionType.Codebase, + type: FsMentionType.Codebase, label: 'Codebase', data: true, - onUpdatePluginState: (dataArr: true[]) => { - setState(draft => { - draft.enableCodebaseAgent = dataArr.length > 0 - draft.codeSnippetFromAgent = [] - }) - }, topLevelSort: 6, searchKeywords: ['codebase'], itemLayoutProps: { @@ -259,15 +192,10 @@ const createUseMentionOptions = } }, { - id: `${PluginId.Fs}#errors`, - type: `${PluginId.Fs}#errors`, + id: FsMentionType.Errors, + type: FsMentionType.Errors, label: 'Errors', data: editorErrors, - onUpdatePluginState: dataArr => { - setState(draft => { - draft.editorErrors = dataArr?.[0] ?? [] - }) - }, topLevelSort: 7, searchKeywords: ['errors', 'warnings', 'diagnostics'], itemLayoutProps: { diff --git a/src/shared/plugins/fs-plugin/client/fs-log-preview.tsx b/src/shared/plugins/fs-plugin/client/fs-log-preview.tsx index 03f5937..f85db32 100644 --- a/src/shared/plugins/fs-plugin/client/fs-log-preview.tsx +++ b/src/shared/plugins/fs-plugin/client/fs-log-preview.tsx @@ -1,32 +1,54 @@ -import { FC } from 'react' -import type { FileInfo } from '@extension/file-utils/traverse-fs' +import { FC, type ReactNode } from 'react' +import { + codebaseSearchAgentName, + fsVisitAgentName +} from '@shared/plugins/agents/agent-names' +import type { CodebaseSearchAgent } from '@shared/plugins/agents/codebase-search-agent' +import type { FsVisitAgent } from '@shared/plugins/agents/fs-visit-agent' import type { CustomRenderLogPreviewProps } from '@shared/plugins/base/client/client-plugin-types' -import { PluginId } from '@shared/plugins/base/types' +import type { GetAgent } from '@shared/plugins/base/strategies' import { ChatLogPreview } from '@webview/components/chat/messages/roles/chat-log-preview' import { FileIcon } from '@webview/components/file-icon' import { TruncateStart } from '@webview/components/truncate-start' import { api } from '@webview/services/api-client' +import type { FileInfo } from '@webview/types/chat' import { cn } from '@webview/utils/common' import { getFileNameFromPath } from '@webview/utils/path' -import type { CodeSnippet, FsPluginLog } from '../types' +import type { CodeSnippet } from '../types' export const FsLogPreview: FC = props => { - if (props.log.pluginId !== PluginId.Fs) return null - const log = props.log as FsPluginLog + const { log } = props + const { agent } = log - return ( + const renderWrapper = (children: ReactNode) => ( -
- {log.codeSnippets?.map((snippet, index) => ( - - ))} - {log.selectedFilesFromAgent?.map((file, index) => ( - - ))} -
+
{children}
) + + if (!agent) return null + + switch (agent.name) { + case codebaseSearchAgentName: + return renderWrapper( +
+ {(agent as GetAgent).output.codeSnippets?.map( + (snippet, index) => + )} +
+ ) + case fsVisitAgentName: + return renderWrapper( +
+ {(agent as GetAgent).output.files?.map( + (file, index) => + )} +
+ ) + default: + return null + } } interface FileSnippetItemProps { diff --git a/src/shared/plugins/fs-plugin/fs-to-state.ts b/src/shared/plugins/fs-plugin/fs-to-state.ts new file mode 100644 index 0000000..81d75e1 --- /dev/null +++ b/src/shared/plugins/fs-plugin/fs-to-state.ts @@ -0,0 +1,30 @@ +import { CodebaseSearchAgent } from '../agents/codebase-search-agent' +import { FsVisitAgent } from '../agents/fs-visit-agent' +import { BaseToState } from '../base/base-to-state' +import { FsMentionType, type FsMention } from './types' + +export class FsToState extends BaseToState { + toMentionsState() { + return { + selectedFiles: this.getMentionDataByType(FsMentionType.File), + selectedFolders: this.getMentionDataByType(FsMentionType.Folder), + selectedTrees: this.getMentionDataByType(FsMentionType.Tree), + codeChunks: this.getMentionDataByType(FsMentionType.Code), + enableCodebaseAgent: this.isMentionExit(FsMentionType.Codebase), + editorErrors: this.getMentionDataByType(FsMentionType.Errors).flat() + } + } + + toAgentsState() { + return { + codeSnippets: this.getAgentOutputsByKey< + CodebaseSearchAgent, + 'codeSnippets' + >('codebaseSearch', 'codeSnippets').flat(), + visitedFiles: this.getAgentOutputsByKey( + 'fsVisit', + 'files' + ).flat() + } + } +} diff --git a/src/shared/plugins/fs-plugin/server/chat-strategy/codebase-search-node.ts b/src/shared/plugins/fs-plugin/server/chat-strategy/codebase-search-node.ts index 5cadf94..535ae12 100644 --- a/src/shared/plugins/fs-plugin/server/chat-strategy/codebase-search-node.ts +++ b/src/shared/plugins/fs-plugin/server/chat-strategy/codebase-search-node.ts @@ -1,160 +1,50 @@ -import { CodebaseWatcherRegister } from '@extension/registers/codebase-watcher-register' -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' +import { CodebaseSearchAgent } from '@shared/plugins/agents/codebase-search-agent' import { - dispatchChatGraphState, - type ChatGraphState, - type CreateChatGraphNode -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' -import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params' -import type { ToolMessage } from '@langchain/core/messages' -import { DynamicStructuredTool } from '@langchain/core/tools' -import type { ConversationLog, LangchainTool } from '@shared/entities' -import { PluginId } from '@shared/plugins/base/types' -import { mergeCodeSnippets } from '@shared/plugins/fs-plugin/server/merge-code-snippets' -import { settledPromiseResults } from '@shared/utils/common' -import { produce } from 'immer' -import { v4 as uuidv4 } from 'uuid' -import { z } from 'zod' - -import type { CodeSnippet, FsPluginLog, FsPluginState } from '../../types' - -interface CodebaseSearchToolResult { - codeSnippets: CodeSnippet[] -} - -export const createCodebaseSearchTool = async ( - options: BaseStrategyOptions, - state: ChatGraphState -) => { - const { conversations } = state.chatContext - const lastConversation = conversations.at(-1) - const fsPluginState = lastConversation?.pluginStates?.[PluginId.Fs] as - | Partial - | undefined - - if (!fsPluginState?.enableCodebaseAgent) return null - - const getSearchResults = async ( - state: ChatGraphState, - queryParts?: string[] - ): Promise => { - const indexer = options.registerManager.getRegister( - CodebaseWatcherRegister - )?.indexer - const searchResults: CodebaseSearchToolResult = { - codeSnippets: [] - } - - if (!indexer) return searchResults - - const searchPromisesResult = await settledPromiseResults( - queryParts?.map(queryPart => indexer.searchSimilarRow(queryPart)) || [] - ) - - const searchCodeSnippets: CodeSnippet[] = searchPromisesResult - .flat() - .slice(0, 8) - .map(row => { - // eslint-disable-next-line unused-imports/no-unused-vars - const { embedding, ...others } = row - return { ...others, code: '' } + BaseNode, + dispatchBaseGraphState, + type ChatGraphState +} from '@shared/plugins/base/strategies' + +import { FsToState } from '../../fs-to-state' + +export class CodebaseSearchNode extends BaseNode { + onInit() { + this.registerAgentConfig(CodebaseSearchAgent.name, state => { + const lastConversation = state.chatContext.conversations.at(-1) + const mentionState = new FsToState(lastConversation).toMentionsState() + + return this.createAgentConfig({ + agentClass: CodebaseSearchAgent, + agentContext: { + state, + strategyOptions: this.context.strategyOptions, + createToolOptions: { + enableCodebaseAgent: mentionState.enableCodebaseAgent + } + } }) - - const mergedCodeSnippets = await mergeCodeSnippets(searchCodeSnippets, { - mode: 'expanded' }) - - return { - ...searchResults, - codeSnippets: mergedCodeSnippets - } } - return new DynamicStructuredTool({ - name: 'codebaseSearch', - description: - 'Search the codebase using vector embeddings. This tool breaks down the query into parts and finds relevant code snippets for each part. Use this when you need to find specific code implementations or understand how certain features are coded in the project.', - func: async ({ queryParts }): Promise => - await getSearchResults(state, queryParts), - schema: z.object({ - queryParts: z - .array( - z - .string() - .describe( - 'A list of code snippets or questions to search for in the codebase. Each item will be used to find relevant code through vector search.' - ) - ) - .describe( - "The AI should break down the user's query into multiple parts, each focusing on a specific aspect or concept. This allows for a more comprehensive search across the codebase." - ) - }) - }) -} - -export const createCodebaseSearchNode: CreateChatGraphNode = - options => async state => { - const { conversations } = state.chatContext - const lastConversation = conversations.at(-1) - const fsPluginState = lastConversation?.pluginStates?.[PluginId.Fs] as - | Partial - | undefined - const logs: ConversationLog[] = [] - - if (!fsPluginState?.enableCodebaseAgent) return {} - - const codebaseSearchTool = await createCodebaseSearchTool(options, state) - - if (!codebaseSearchTool) return {} - - const tools: LangchainTool[] = [codebaseSearchTool] - const lastMessage = state.messages.at(-1) - const toolCalls = findCurrentToolsCallParams(lastMessage, tools) - - if (!toolCalls.length) return {} - - const toolCallsPromises = toolCalls.map(async toolCall => { - const toolMessage = (await codebaseSearchTool.invoke( - toolCall - )) as ToolMessage - - const result = JSON.parse( - toolMessage?.lc_kwargs.content - ) as CodebaseSearchToolResult - - lastConversation!.pluginStates![PluginId.Fs] = produce( - lastConversation!.pluginStates![PluginId.Fs] as Partial, - (draft: Partial) => { - if (!draft.codeSnippetFromAgent) { - draft.codeSnippetFromAgent = [] - } - - draft.codeSnippetFromAgent.push(...result.codeSnippets) - } - ) - - logs.push({ - id: uuidv4(), - createdAt: Date.now(), - pluginId: PluginId.Fs, - title: 'Search codebase', - codeSnippets: result.codeSnippets - } satisfies FsPluginLog) + async execute(state: ChatGraphState) { + const toolCallsResults = await this.executeAgentTool(state, { + agentClass: CodebaseSearchAgent }) - await settledPromiseResults(toolCallsPromises) + if (!toolCallsResults.agents.length) return {} - const newConversations = produce(state.newConversations, draft => { - draft.at(-1)!.logs.push(...logs) - }) + const newConversation = state.newConversations.at(-1)! + this.addAgentsToConversation(newConversation, toolCallsResults.agents) + this.addLogsToConversation(newConversation, toolCallsResults.logs) - dispatchChatGraphState({ - newConversations, - chatContext: state.chatContext + dispatchBaseGraphState({ + chatContext: state.chatContext, + newConversations: state.newConversations }) return { chatContext: state.chatContext, - newConversations + newConversations: state.newConversations } } +} diff --git a/src/shared/plugins/fs-plugin/server/chat-strategy/fs-chat-strategy-provider.ts b/src/shared/plugins/fs-plugin/server/chat-strategy/fs-chat-strategy-provider.ts index d48049a..e2c8c49 100644 --- a/src/shared/plugins/fs-plugin/server/chat-strategy/fs-chat-strategy-provider.ts +++ b/src/shared/plugins/fs-plugin/server/chat-strategy/fs-chat-strategy-provider.ts @@ -10,26 +10,29 @@ import { import { VsCodeFS } from '@extension/file-utils/vscode-fs' import { logger } from '@extension/logger' import { getWorkspaceFolder } from '@extension/utils' -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' -import type { - ChatGraphNode, - ChatGraphState -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' import { formatCodeSnippet } from '@extension/webview-api/chat-context-processor/utils/code-snippet-formatter' import { getFileContent } from '@extension/webview-api/chat-context-processor/utils/get-file-content' import type { StructuredTool } from '@langchain/core/tools' import type { ChatContext, Conversation } from '@shared/entities' +import type { + GetAgentState, + GetMentionState +} from '@shared/plugins/base/base-to-state' import type { ChatStrategyProvider } from '@shared/plugins/base/server/create-provider-manager' -import { PluginId } from '@shared/plugins/base/types' +import { + createGraphNodeFromNodes, + createToolsFromNodes, + type BaseStrategyOptions, + type ChatGraphNode, + type ChatGraphState +} from '@shared/plugins/base/strategies' import { mergeCodeSnippets } from '@shared/plugins/fs-plugin/server/merge-code-snippets' import { removeDuplicates } from '@shared/utils/common' -import type { EditorError, FsPluginState } from '../../types' -import { - createCodebaseSearchNode, - createCodebaseSearchTool -} from './codebase-search-node' -import { createFsVisitNode, createFsVisitTool } from './fs-visit-node' +import { FsToState } from '../../fs-to-state' +import { type EditorError } from '../../types' +import { CodebaseSearchNode } from './codebase-search-node' +import { FsVisitNode } from './fs-visit-node' interface BuildFilePromptsResult { selectedFilesPrompt: string @@ -38,7 +41,23 @@ interface BuildFilePromptsResult { treePrompt: string } +interface ConversationWithStateProps { + conversation: Conversation + mentionState: GetMentionState + agentState: GetAgentState +} + export class FsChatStrategyProvider implements ChatStrategyProvider { + private createConversationWithStateProps( + conversation: Conversation + ): ConversationWithStateProps { + const fsToState = new FsToState(conversation) + const mentionState = fsToState.toMentionsState() + const agentState = fsToState.toAgentsState() + + return { conversation, mentionState, agentState } + } + async buildSystemMessagePrompt(chatContext: ChatContext): Promise { const hasAttachedFiles = this.checkForAttachedFiles(chatContext) @@ -48,21 +67,17 @@ export class FsChatStrategyProvider implements ChatStrategyProvider { } async buildContextMessagePrompt(conversation: Conversation): Promise { - const state = conversation.pluginStates?.[PluginId.Fs] as - | Partial - | undefined + const props = this.createConversationWithStateProps(conversation) - if (!state) return '' - - const codeSnippetsPrompt = await this.buildCodeSnippetsPrompt(state) + const codeSnippetsPrompt = await this.buildCodeSnippetsPrompt(props) codeSnippetsPrompt && logger.dev.verbose('codeSnippetsPrompt', codeSnippetsPrompt) - const editorErrorsPrompt = this.buildEditorErrorsPrompt(state) + const editorErrorsPrompt = this.buildEditorErrorsPrompt(props) editorErrorsPrompt && logger.dev.verbose('editorErrorsPrompt', editorErrorsPrompt) - const { currentFilesPrompt } = await this.buildFilePrompts(state) + const { currentFilesPrompt } = await this.buildFilePrompts(props) const prompts = [ codeSnippetsPrompt, currentFilesPrompt, @@ -73,15 +88,10 @@ export class FsChatStrategyProvider implements ChatStrategyProvider { } async buildHumanMessagePrompt(conversation: Conversation): Promise { - const state = conversation.pluginStates?.[PluginId.Fs] as - | Partial - | undefined - - if (!state) return '' - + const props = this.createConversationWithStateProps(conversation) const { selectedFilesPrompt, treePrompt } = - await this.buildFilePrompts(state) - const codeChunksPrompt = this.buildCodeChunksPrompt(state) + await this.buildFilePrompts(props) + const codeChunksPrompt = this.buildCodeChunksPrompt(props) return ` ${treePrompt} @@ -102,14 +112,12 @@ ${codeChunksPrompt}` async buildHumanMessageImageUrls( conversation: Conversation ): Promise { - const state = conversation.pluginStates?.[PluginId.Fs] as - | Partial - | undefined + const { selectedImagesFromOutsideUrl } = conversation.state - if (!state) return [] + if (!selectedImagesFromOutsideUrl) return [] - const { selectedImagesFromOutsideUrl } = state - const { imageBase64Urls } = await this.buildFilePrompts(state) + const props = this.createConversationWithStateProps(conversation) + const { imageBase64Urls } = await this.buildFilePrompts(props) return removeDuplicates([ ...(selectedImagesFromOutsideUrl?.map(image => image.url) || []), ...imageBase64Urls @@ -117,43 +125,45 @@ ${codeChunksPrompt}` } async buildAgentTools( - options: BaseStrategyOptions, + strategyOptions: BaseStrategyOptions, state: ChatGraphState ): Promise { - const tools = await Promise.all([ - createCodebaseSearchTool(options, state), - createFsVisitTool(options, state) - ]) - return tools.filter(Boolean) as StructuredTool[] + return await createToolsFromNodes({ + nodeClasses: [CodebaseSearchNode, FsVisitNode], + strategyOptions, + state + }) } async buildLanggraphToolNodes( - options: BaseStrategyOptions + strategyOptions: BaseStrategyOptions ): Promise { - return [createCodebaseSearchNode(options), createFsVisitNode(options)] + return await createGraphNodeFromNodes({ + nodeClasses: [CodebaseSearchNode, FsVisitNode], + strategyOptions + }) } private checkForAttachedFiles(chatContext: ChatContext): boolean { return chatContext.conversations.some(conversation => { - const { - selectedFilesFromEditor = [], - selectedFoldersFromEditor = [], - codeSnippetFromAgent = [] - } = (conversation.pluginStates?.[PluginId.Fs] as - | Partial - | undefined) || {} + const { mentionState, agentState } = + this.createConversationWithStateProps(conversation) + const codeSnippet = agentState?.codeSnippets || [] + + const { selectedFiles = [], selectedFolders = [] } = mentionState || {} return ( - selectedFilesFromEditor.length > 0 || - selectedFoldersFromEditor.length > 0 || - codeSnippetFromAgent.length > 0 + selectedFiles.length > 0 || + selectedFolders.length > 0 || + codeSnippet.length > 0 ) }) } private async buildFilePrompts( - state: Partial + props: ConversationWithStateProps ): Promise { + const { conversation, mentionState, agentState } = props const result: BuildFilePromptsResult = { selectedFilesPrompt: '', currentFilesPrompt: '', @@ -163,16 +173,16 @@ ${codeChunksPrompt}` const workspacePath = getWorkspaceFolder().uri.fsPath const currentFilePaths = new Set( - state.currentFilesFromVSCode?.map(file => file.fullPath) + conversation.state.currentFilesFromVSCode?.map(file => file.fullPath) ) const processedFiles = new Set() const filesOrFolders = removeDuplicates( [ - ...(state.selectedFilesFromEditor || []), - ...(state.selectedFilesFromFileSelector || []), - ...(state.selectedFoldersFromEditor || []), - ...(state.selectedFilesFromAgent || []) + ...(mentionState?.selectedFiles || []), + ...(mentionState?.selectedFolders || []), + ...(conversation?.state?.selectedFilesFromFileSelector || []), + ...(agentState?.codeSnippets || []) ], ['fullPath'] ).map(file => file.fullPath) @@ -218,11 +228,11 @@ ${codeChunksPrompt}` } }) - if (state.selectedTreesFromEditor?.length) { + if (mentionState?.selectedTrees?.length) { result.treePrompt = ` ## Some Project Structure -${state.selectedTreesFromEditor.map(tree => tree.listString).join('\n')} +${mentionState?.selectedTrees?.map(tree => tree.listString).join('\n')} ` } @@ -239,16 +249,13 @@ ${result.currentFilesPrompt} } private async buildCodeSnippetsPrompt( - state: Partial + props: ConversationWithStateProps ): Promise { - const { - enableCodebaseAgent, - codeSnippetFromAgent: codeSnippetFromCodebaseAgent - } = state - - if (!enableCodebaseAgent || !codeSnippetFromCodebaseAgent?.length) return '' + const { mentionState, agentState } = props + if (!mentionState?.enableCodebaseAgent || !agentState.codeSnippets?.length) + return '' - const mergedSnippets = await mergeCodeSnippets(codeSnippetFromCodebaseAgent) + const mergedSnippets = await mergeCodeSnippets(agentState.codeSnippets) const snippetsContent = mergedSnippets .map(snippet => @@ -272,12 +279,11 @@ ${CONTENT_SEPARATOR} : '' } - private buildCodeChunksPrompt(state: Partial): string { - const { codeChunksFromEditor } = state + private buildCodeChunksPrompt(props: ConversationWithStateProps): string { + const { mentionState } = props + if (!mentionState?.codeChunks?.length) return '' - if (!codeChunksFromEditor?.length) return '' - - const chunksContent = removeDuplicates(codeChunksFromEditor, [ + const chunksContent = removeDuplicates(mentionState.codeChunks, [ 'relativePath', 'code' ]) @@ -296,13 +302,12 @@ ${CONTENT_SEPARATOR} return chunksContent } - private buildEditorErrorsPrompt(state: Partial): string { - const { editorErrors } = state - - if (!editorErrors?.length) return '' + private buildEditorErrorsPrompt(props: ConversationWithStateProps): string { + const { mentionState } = props + if (!mentionState?.editorErrors?.length) return '' // Group errors by file - const errorsByFile = editorErrors.reduce( + const errorsByFile = mentionState.editorErrors.reduce( (acc, error) => { if (!acc[error.file]) { acc[error.file] = [] diff --git a/src/shared/plugins/fs-plugin/server/chat-strategy/fs-visit-node.ts b/src/shared/plugins/fs-plugin/server/chat-strategy/fs-visit-node.ts index d0ca135..58e2120 100644 --- a/src/shared/plugins/fs-plugin/server/chat-strategy/fs-visit-node.ts +++ b/src/shared/plugins/fs-plugin/server/chat-strategy/fs-visit-node.ts @@ -1,116 +1,43 @@ -/* eslint-disable unused-imports/no-unused-vars */ -import { getValidFiles } from '@extension/file-utils/get-valid-files' -import type { FileInfo } from '@extension/file-utils/traverse-fs' -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' +import { FsVisitAgent } from '@shared/plugins/agents/fs-visit-agent' import { - dispatchChatGraphState, - type ChatGraphState, - type CreateChatGraphNode -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' -import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params' -import type { ToolMessage } from '@langchain/core/messages' -import { DynamicStructuredTool } from '@langchain/core/tools' -import type { ConversationLog, LangchainTool } from '@shared/entities' -import { PluginId } from '@shared/plugins/base/types' -import { settledPromiseResults } from '@shared/utils/common' -import { produce } from 'immer' -import { v4 as uuidv4 } from 'uuid' -import { z } from 'zod' - -import type { FsPluginLog, FsPluginState } from '../../types' - -interface FsVisitToolResult { - files: FileInfo[] -} - -export const createFsVisitTool = async ( - options: BaseStrategyOptions, - state: ChatGraphState -) => { - const getFileContents = async ( - relativePaths: string[] - ): Promise => { - const files = await getValidFiles(relativePaths, { - isGetFileContent: false - }) - return { - files - } - } - - return new DynamicStructuredTool({ - name: 'fsVisit', - description: - 'A tool for directly accessing and reading specific files in the codebase.', - func: async ({ relativePaths }): Promise => { - const result = await getFileContents(relativePaths) - return result - }, - schema: z.object({ - relativePaths: z - .array(z.string()) - .describe( - 'An array of relative file paths to read from the workspace root' - ) - }) - }) -} - -export const createFsVisitNode: CreateChatGraphNode = - options => async state => { - const { conversations } = state.chatContext - const lastConversation = conversations.at(-1) - const logs: ConversationLog[] = [] - const fsVisitTool = await createFsVisitTool(options, state) - - if (!fsVisitTool) return {} - - const tools: LangchainTool[] = [fsVisitTool] - const lastMessage = state.messages.at(-1) - const toolCalls = findCurrentToolsCallParams(lastMessage, tools) - - if (!toolCalls.length) return {} - - const toolCallsPromises = toolCalls.map(async toolCall => { - const toolMessage = (await fsVisitTool.invoke(toolCall)) as ToolMessage - - const result = JSON.parse( - toolMessage?.lc_kwargs.content - ) as FsVisitToolResult - - lastConversation!.pluginStates![PluginId.Fs] = produce( - lastConversation!.pluginStates![PluginId.Fs] as Partial, - (draft: Partial) => { - if (!draft.selectedFilesFromAgent) { - draft.selectedFilesFromAgent = [] - } - - draft.selectedFilesFromAgent.push(...result.files) + BaseNode, + dispatchBaseGraphState, + type ChatGraphState +} from '@shared/plugins/base/strategies' + +export class FsVisitNode extends BaseNode { + onInit() { + this.registerAgentConfig(FsVisitAgent.name, state => + this.createAgentConfig({ + agentClass: FsVisitAgent, + agentContext: { + state, + strategyOptions: this.context.strategyOptions, + createToolOptions: {} } - ) + }) + ) + } - logs.push({ - id: uuidv4(), - createdAt: Date.now(), - pluginId: PluginId.Fs, - title: 'Auto visit files', - selectedFilesFromAgent: result.files - } satisfies FsPluginLog) + async execute(state: ChatGraphState) { + const toolCallsResults = await this.executeAgentTool(state, { + agentClass: FsVisitAgent }) - await settledPromiseResults(toolCallsPromises) + if (!toolCallsResults.agents.length) return {} - const newConversations = produce(state.newConversations, draft => { - draft.at(-1)!.logs.push(...logs) - }) + const newConversation = state.newConversations.at(-1)! + this.addAgentsToConversation(newConversation, toolCallsResults.agents) + this.addLogsToConversation(newConversation, toolCallsResults.logs) - dispatchChatGraphState({ - newConversations, - chatContext: state.chatContext + dispatchBaseGraphState({ + chatContext: state.chatContext, + newConversations: state.newConversations }) return { chatContext: state.chatContext, - newConversations + newConversations: state.newConversations } } +} diff --git a/src/shared/plugins/fs-plugin/server/fs-mention-utils-provider.ts b/src/shared/plugins/fs-plugin/server/fs-mention-utils-provider.ts new file mode 100644 index 0000000..c0b6f79 --- /dev/null +++ b/src/shared/plugins/fs-plugin/server/fs-mention-utils-provider.ts @@ -0,0 +1,73 @@ +import type { FileInfo, FolderInfo } from '@extension/file-utils/traverse-fs' +import type { ControllerRegister } from '@extension/registers/controller-register' +import type { Mention } from '@shared/entities' +import type { MentionUtilsProvider } from '@shared/plugins/base/server/create-provider-manager' + +import { FsMentionType, type TreeInfo } from '../types' + +export class FsMentionUtilsProvider implements MentionUtilsProvider { + async createRefreshMentionFn(controllerRegister: ControllerRegister) { + const files = await controllerRegister + .api('file') + .traverseWorkspaceFiles({ filesOrFolders: ['./'] }) + + const folders = await controllerRegister + .api('file') + .traverseWorkspaceFolders({ folders: ['./'] }) + + const editorErrors = await controllerRegister + .api('file') + .getCurrentEditorErrors() + + const treesInfo = await controllerRegister + .api('file') + .getWorkspaceTreesInfo({ depth: 5 }) + + const filePathMapFile = new Map() + + for (const file of files) { + filePathMapFile.set(file.fullPath, file) + } + + const filePathMapFolder = new Map() + + for (const folder of folders) { + filePathMapFolder.set(folder.fullPath, folder) + } + + const filePathMapTree = new Map() + + for (const tree of treesInfo) { + filePathMapTree.set(tree.fullPath, tree) + } + + return (_mention: Mention) => { + const mention = { ..._mention } as Mention + switch (mention.type) { + case FsMentionType.File: + const file = filePathMapFile.get(mention.data.fullPath) + if (file) mention.data = file + break + + case FsMentionType.Folder: + const folder = filePathMapFolder.get(mention.data.fullPath) + if (folder) mention.data = folder + break + + case FsMentionType.Tree: + const tree = filePathMapTree.get(mention.data.fullPath) + if (tree) mention.data = tree + break + + case FsMentionType.Errors: + mention.data = editorErrors + break + + default: + break + } + + return mention + } + } +} diff --git a/src/shared/plugins/fs-plugin/server/fs-server-plugin.ts b/src/shared/plugins/fs-plugin/server/fs-server-plugin.ts index 7a280aa..6318a4c 100644 --- a/src/shared/plugins/fs-plugin/server/fs-server-plugin.ts +++ b/src/shared/plugins/fs-plugin/server/fs-server-plugin.ts @@ -7,6 +7,7 @@ import { pkg } from '@shared/utils/pkg' import type { FsPluginState } from '../types' import { FsChatStrategyProvider } from './chat-strategy/fs-chat-strategy-provider' +import { FsMentionUtilsProvider } from './fs-mention-utils-provider' export class FsServerPlugin implements ServerPlugin { id = PluginId.Fs @@ -22,6 +23,11 @@ export class FsServerPlugin implements ServerPlugin { 'chatStrategy', () => new FsChatStrategyProvider() ) + + this.context.registerProvider( + 'mentionUtils', + () => new FsMentionUtilsProvider() + ) } deactivate(): void { diff --git a/src/shared/plugins/fs-plugin/types.ts b/src/shared/plugins/fs-plugin/types.ts index 86cc148..c0396b0 100644 --- a/src/shared/plugins/fs-plugin/types.ts +++ b/src/shared/plugins/fs-plugin/types.ts @@ -1,7 +1,34 @@ import type { FileInfo, FolderInfo } from '@extension/file-utils/traverse-fs' -import type { BaseConversationLog } from '@shared/entities' +import type { Mention } from '@shared/entities' -import type { PluginId } from '../base/types' +import { PluginId } from '../base/types' + +export enum FsMentionType { + Files = `${PluginId.Fs}#files`, + File = `${PluginId.Fs}#file`, + Folders = `${PluginId.Fs}#folders`, + Folder = `${PluginId.Fs}#folder`, + Trees = `${PluginId.Fs}#trees`, + Tree = `${PluginId.Fs}#tree`, + Code = `${PluginId.Fs}#code`, + Codebase = `${PluginId.Fs}#codebase`, + Errors = `${PluginId.Fs}#errors` +} + +export type FileMention = Mention +export type FolderMention = Mention +export type TreeMention = Mention +export type CodeMention = Mention +export type CodebaseMention = Mention +export type ErrorMention = Mention + +export type FsMention = + | FileMention + | FolderMention + | TreeMention + | CodeMention + | CodebaseMention + | ErrorMention export interface CodeSnippet { fileHash: string @@ -23,10 +50,6 @@ export interface CodeChunk { endLine?: number } -export interface ImageInfo { - url: string -} - export interface EditorError { message: string code?: string @@ -44,22 +67,4 @@ export interface TreeInfo { listString: string // markdown list string, for ai reading } -export interface FsPluginState { - selectedFilesFromFileSelector: FileInfo[] - selectedFilesFromEditor: FileInfo[] - selectedFilesFromAgent: FileInfo[] - currentFilesFromVSCode: FileInfo[] - selectedFoldersFromEditor: FolderInfo[] - selectedImagesFromOutsideUrl: ImageInfo[] - codeChunksFromEditor: CodeChunk[] - codeSnippetFromAgent: CodeSnippet[] - enableCodebaseAgent: boolean - editorErrors: EditorError[] - selectedTreesFromEditor: TreeInfo[] -} - -export interface FsPluginLog extends BaseConversationLog { - pluginId: PluginId.Fs - codeSnippets?: CodeSnippet[] - selectedFilesFromAgent?: FileInfo[] -} +export interface FsPluginState {} diff --git a/src/shared/plugins/git-plugin/client/git-client-plugin.tsx b/src/shared/plugins/git-plugin/client/git-client-plugin.tsx index fbbca80..8188d02 100644 --- a/src/shared/plugins/git-plugin/client/git-client-plugin.tsx +++ b/src/shared/plugins/git-plugin/client/git-client-plugin.tsx @@ -10,18 +10,14 @@ import { useQuery } from '@tanstack/react-query' import { api } from '@webview/services/api-client' import { type MentionOption } from '@webview/types/chat' -import type { GitCommit, GitPluginState } from '../types' +import { GitCommit, GitMentionType, GitPluginState } from '../types' export const GitClientPlugin = createClientPlugin({ id: PluginId.Git, version: pkg.version, getInitialState() { - return { - gitCommitsFromEditor: [], - gitDiffWithMainBranchFromEditor: null, - gitDiffOfWorkingStateFromEditor: null - } + return {} }, setup(props) { @@ -33,7 +29,6 @@ export const GitClientPlugin = createClientPlugin({ const createUseMentionOptions = (props: SetupProps) => (): UseMentionOptionsReturns => { - const { setState } = props const { data: gitCommits = [] } = useQuery({ queryKey: ['realtime', 'git-commits'], queryFn: () => @@ -45,16 +40,10 @@ const createUseMentionOptions = const gitCommitsMentionOptions: MentionOption[] = gitCommits.map( commit => ({ - id: `${PluginId.Git}#git-commit#${commit.sha}`, - type: `${PluginId.Git}#git-commit`, + id: `${GitMentionType.GitCommit}#${commit.sha}`, + type: GitMentionType.GitCommit, label: commit.message, data: commit, - onUpdatePluginState: dataArr => { - setState(draft => { - draft.gitCommitsFromEditor = dataArr - }) - }, - searchKeywords: [commit.sha, commit.message], itemLayoutProps: { icon: , @@ -66,8 +55,8 @@ const createUseMentionOptions = return [ { - id: `${PluginId.Git}#git`, - type: `${PluginId.Git}#git`, + id: GitMentionType.Git, + type: GitMentionType.Git, label: 'Git', topLevelSort: 5, searchKeywords: ['git'], @@ -77,14 +66,10 @@ const createUseMentionOptions = }, children: [ { - id: `${PluginId.Git}#git-diff`, - type: `${PluginId.Git}#git-diff`, + id: GitMentionType.GitDiff, + type: GitMentionType.GitDiff, label: 'Diff (Diff of Working State)', - onUpdatePluginState: dataArr => { - setState(draft => { - draft.gitDiffOfWorkingStateFromEditor = dataArr.at(-1) - }) - }, + data: null, // TODO: add diff of working state searchKeywords: ['diff'], itemLayoutProps: { icon: , @@ -92,14 +77,10 @@ const createUseMentionOptions = } }, { - id: `${PluginId.Git}#git-pr`, - type: `${PluginId.Git}#git-pr`, + id: GitMentionType.GitPR, + type: GitMentionType.GitPR, label: 'PR (Diff with Main Branch)', - onUpdatePluginState: dataArr => { - setState(draft => { - draft.gitDiffWithMainBranchFromEditor = dataArr.at(-1) - }) - }, + data: null, // TODO: add diff with main branch searchKeywords: ['pull request', 'pr', 'diff'], itemLayoutProps: { icon: , diff --git a/src/shared/plugins/git-plugin/git-to-state.ts b/src/shared/plugins/git-plugin/git-to-state.ts new file mode 100644 index 0000000..e5038a8 --- /dev/null +++ b/src/shared/plugins/git-plugin/git-to-state.ts @@ -0,0 +1,20 @@ +import { BaseToState } from '../base/base-to-state' +import { GitMentionType, type GitMention } from './types' + +export class GitToState extends BaseToState { + toMentionsState() { + return { + gitCommits: this.getMentionDataByType(GitMentionType.GitCommit), + + gitDiffOfWorkingState: this.getMentionDataByType( + GitMentionType.GitDiff + )?.[0], + + gitDiffWithMain: this.getMentionDataByType(GitMentionType.GitDiff)?.[0] + } + } + + toAgentsState() { + return {} + } +} diff --git a/src/shared/plugins/git-plugin/server/chat-strategy/git-chat-strategy-provider.ts b/src/shared/plugins/git-plugin/server/chat-strategy/git-chat-strategy-provider.ts index 5f1bd66..b0b8bda 100644 --- a/src/shared/plugins/git-plugin/server/chat-strategy/git-chat-strategy-provider.ts +++ b/src/shared/plugins/git-plugin/server/chat-strategy/git-chat-strategy-provider.ts @@ -1,23 +1,38 @@ import type { Conversation } from '@shared/entities' +import type { + GetAgentState, + GetMentionState +} from '@shared/plugins/base/base-to-state' import type { ChatStrategyProvider } from '@shared/plugins/base/server/create-provider-manager' -import { PluginId } from '@shared/plugins/base/types' import { removeDuplicates } from '@shared/utils/common' -import type { GitDiff, GitPluginState } from '../../types' +import { GitToState } from '../../git-to-state' +import type { GitDiff } from '../../types' -export class GitChatStrategyProvider implements ChatStrategyProvider { - async buildContextMessagePrompt(conversation: Conversation): Promise { - const state = conversation.pluginStates?.[PluginId.Git] as - | Partial - | undefined +interface ConversationWithStateProps { + conversation: Conversation + mentionState: GetMentionState + agentState: GetAgentState +} - if (!state) return '' +export class GitChatStrategyProvider implements ChatStrategyProvider { + private createConversationWithStateProps( + conversation: Conversation + ): ConversationWithStateProps { + const gitToState = new GitToState(conversation) + const mentionState = gitToState.toMentionsState() + const agentState = gitToState.toAgentsState() + + return { conversation, mentionState, agentState } + } + async buildContextMessagePrompt(conversation: Conversation): Promise { + const props = this.createConversationWithStateProps(conversation) const diffWithMainBranchPrompt = - this.buildGitDiffWithMainBranchPrompt(state) + this.buildGitDiffWithMainBranchPrompt(props) const diffOfWorkingStatePrompt = - this.buildGitDiffOfWorkingStatePrompt(state) - const commitPrompt = this.buildGitCommitPrompt(state) + this.buildGitDiffOfWorkingStatePrompt(props) + const commitPrompt = this.buildGitCommitPrompt(props) const prompts = [ diffWithMainBranchPrompt, @@ -28,16 +43,16 @@ export class GitChatStrategyProvider implements ChatStrategyProvider { return prompts.join('\n\n') } - private buildGitCommitPrompt(state: Partial): string { - const { gitCommitsFromEditor = [] } = state + private buildGitCommitPrompt(props: ConversationWithStateProps): string { + const { mentionState } = props - if (!gitCommitsFromEditor.length) return '' + if (!mentionState?.gitCommits.length) return '' let gitCommitContent = ` ## Git Commits ` - removeDuplicates(gitCommitsFromEditor, ['sha']).forEach(commit => { + removeDuplicates(mentionState.gitCommits, ['sha']).forEach(commit => { gitCommitContent += ` Commit: ${commit.sha} Message: ${commit.message} @@ -52,28 +67,28 @@ ${this.buildGitDiffsPrompt(commit.diff)} } private buildGitDiffWithMainBranchPrompt( - state: Partial + props: ConversationWithStateProps ): string { - const { gitDiffWithMainBranchFromEditor } = state + const { mentionState } = props - if (!gitDiffWithMainBranchFromEditor) return '' + if (!mentionState?.gitDiffWithMain) return '' return ` ## Git Diff with Main Branch -${this.buildGitDiffsPrompt([gitDiffWithMainBranchFromEditor])} +${this.buildGitDiffsPrompt([mentionState.gitDiffWithMain])} ` } private buildGitDiffOfWorkingStatePrompt( - state: Partial + props: ConversationWithStateProps ): string { - const { gitDiffOfWorkingStateFromEditor } = state + const { mentionState } = props - if (!gitDiffOfWorkingStateFromEditor) return '' + if (!mentionState?.gitDiffOfWorkingState) return '' return ` ## Git Diff of Working State -${this.buildGitDiffsPrompt([gitDiffOfWorkingStateFromEditor])} +${this.buildGitDiffsPrompt([mentionState.gitDiffOfWorkingState])} ` } diff --git a/src/shared/plugins/git-plugin/server/git-mention-utils-provider.ts b/src/shared/plugins/git-plugin/server/git-mention-utils-provider.ts new file mode 100644 index 0000000..7731763 --- /dev/null +++ b/src/shared/plugins/git-plugin/server/git-mention-utils-provider.ts @@ -0,0 +1,34 @@ +import type { ControllerRegister } from '@extension/registers/controller-register' +import type { Mention } from '@shared/entities' +import type { MentionUtilsProvider } from '@shared/plugins/base/server/create-provider-manager' + +import { GitMentionType } from '../types' + +export class GitMentionUtilsProvider implements MentionUtilsProvider { + async createRefreshMentionFn(controllerRegister: ControllerRegister) { + const commits = await controllerRegister.api('git').getHistoryCommits({ + maxCount: 50 + }) + + // Create a map of commit SHAs for quick lookup + const commitMap = new Map(commits.map(commit => [commit.sha, commit])) + + return (_mention: Mention) => { + const mention = { ..._mention } as Mention + switch (mention.type) { + case GitMentionType.GitCommit: + const commit = commitMap.get(mention.data.sha) + if (commit) mention.data = commit + break + case GitMentionType.GitDiff: + break + case GitMentionType.GitPR: + break + default: + break + } + + return mention + } + } +} diff --git a/src/shared/plugins/git-plugin/server/git-server-plugin.ts b/src/shared/plugins/git-plugin/server/git-server-plugin.ts index b311054..fc79c80 100644 --- a/src/shared/plugins/git-plugin/server/git-server-plugin.ts +++ b/src/shared/plugins/git-plugin/server/git-server-plugin.ts @@ -7,6 +7,7 @@ import { pkg } from '@shared/utils/pkg' import type { GitPluginState } from '../types' import { GitChatStrategyProvider } from './chat-strategy/git-chat-strategy-provider' +import { GitMentionUtilsProvider } from './git-mention-utils-provider' export class GitServerPlugin implements ServerPlugin { id = PluginId.Git @@ -22,6 +23,11 @@ export class GitServerPlugin implements ServerPlugin { 'chatStrategy', () => new GitChatStrategyProvider() ) + + this.context.registerProvider( + 'mentionUtils', + () => new GitMentionUtilsProvider() + ) } deactivate(): void { diff --git a/src/shared/plugins/git-plugin/types.ts b/src/shared/plugins/git-plugin/types.ts index 262a8b9..642e8a9 100644 --- a/src/shared/plugins/git-plugin/types.ts +++ b/src/shared/plugins/git-plugin/types.ts @@ -1,3 +1,19 @@ +import type { Mention } from '@shared/entities' + +import { PluginId } from '../base/types' + +export enum GitMentionType { + Git = `${PluginId.Git}#git`, + GitCommit = `${PluginId.Git}#git-commit`, + GitDiff = `${PluginId.Git}#git-diff`, + GitPR = `${PluginId.Git}#git-pr` +} + +export type GitCommitMention = Mention +export type GitDiffMention = Mention +export type GitPRMention = Mention +export type GitMention = GitCommitMention | GitDiffMention | GitPRMention + export interface GitDiff { /** * @example '.github/workflows/ci.yml' @@ -45,8 +61,4 @@ export interface GitCommit { date: string } -export interface GitPluginState { - gitCommitsFromEditor: GitCommit[] - gitDiffWithMainBranchFromEditor: GitDiff | null - gitDiffOfWorkingStateFromEditor: GitDiff | null -} +export interface GitPluginState {} diff --git a/src/shared/plugins/terminal-plugin/client/terminal-client-plugin.tsx b/src/shared/plugins/terminal-plugin/client/terminal-client-plugin.tsx index 9a8cf1d..97c5306 100644 --- a/src/shared/plugins/terminal-plugin/client/terminal-client-plugin.tsx +++ b/src/shared/plugins/terminal-plugin/client/terminal-client-plugin.tsx @@ -10,7 +10,11 @@ import { api } from '@webview/services/api-client' import { type MentionOption } from '@webview/types/chat' import { SquareTerminalIcon } from 'lucide-react' -import type { TerminalInfo, TerminalPluginState } from '../types' +import { + TerminalInfo, + TerminalMentionType, + TerminalPluginState +} from '../types' import { MentionTerminalPreview } from './mention-terminal-preview' export const TerminalClientPlugin = createClientPlugin({ @@ -18,10 +22,7 @@ export const TerminalClientPlugin = createClientPlugin({ version: pkg.version, getInitialState() { - return { - selectedTerminalsFromEditor: [], - terminalLogsFromAgent: [] - } + return {} }, setup(props) { @@ -33,22 +34,16 @@ export const TerminalClientPlugin = createClientPlugin({ const createUseMentionOptions = (props: SetupProps) => (): UseMentionOptionsReturns => { - const { setState } = props const { data: terminals = [] } = useQuery({ queryKey: ['realtime', 'terminals'], queryFn: () => api.terminal.getTerminalsForMention({}) }) const terminalMentionOptions: MentionOption[] = terminals.map(terminal => ({ - id: `${PluginId.Terminal}#terminal#${terminal.processId}`, - type: `${PluginId.Terminal}#terminal`, + id: `${TerminalMentionType.Terminal}#${terminal.processId}`, + type: TerminalMentionType.Terminal, label: terminal.name, data: terminal, - onUpdatePluginState: dataArr => { - setState(draft => { - draft.selectedTerminalsFromEditor = dataArr - }) - }, searchKeywords: [terminal.name], itemLayoutProps: { icon: , @@ -60,8 +55,8 @@ const createUseMentionOptions = return [ { - id: `${PluginId.Terminal}#terminals`, - type: `${PluginId.Terminal}#terminals`, + id: TerminalMentionType.Terminals, + type: TerminalMentionType.Terminals, label: 'Terminals', topLevelSort: 6, searchKeywords: ['terminal', 'shell', 'command'], diff --git a/src/shared/plugins/terminal-plugin/server/chat-strategy/terminal-chat-strategy-provider.ts b/src/shared/plugins/terminal-plugin/server/chat-strategy/terminal-chat-strategy-provider.ts index c486907..620ac55 100644 --- a/src/shared/plugins/terminal-plugin/server/chat-strategy/terminal-chat-strategy-provider.ts +++ b/src/shared/plugins/terminal-plugin/server/chat-strategy/terminal-chat-strategy-provider.ts @@ -1,34 +1,50 @@ import type { Conversation } from '@shared/entities' +import type { + GetAgentState, + GetMentionState +} from '@shared/plugins/base/base-to-state' import type { ChatStrategyProvider } from '@shared/plugins/base/server/create-provider-manager' -import { PluginId } from '@shared/plugins/base/types' -import type { TerminalCommand, TerminalPluginState } from '../../types' +import { TerminalToState } from '../../terminal-mentions-to-state' +import type { TerminalCommand } from '../../types' + +interface ConversationWithStateProps { + conversation: Conversation + mentionState: GetMentionState + agentState: GetAgentState +} export class TerminalChatStrategyProvider implements ChatStrategyProvider { - async buildContextMessagePrompt(conversation: Conversation): Promise { - const state = conversation.pluginStates?.[PluginId.Terminal] as - | Partial - | undefined + private createConversationWithStateProps( + conversation: Conversation + ): ConversationWithStateProps { + const terminalToState = new TerminalToState(conversation) + const mentionState = terminalToState.toMentionsState() + const agentState = terminalToState.toAgentsState() - if (!state) return '' + return { conversation, mentionState, agentState } + } + + async buildContextMessagePrompt(conversation: Conversation): Promise { + const props = this.createConversationWithStateProps(conversation) - const terminalLogsPrompt = this.buildTerminalLogsPrompt(state) + const terminalLogsPrompt = this.buildTerminalLogsPrompt(props) const prompts = [terminalLogsPrompt].filter(Boolean) return prompts.join('\n\n') } - private buildTerminalLogsPrompt(state: Partial): string { - const { selectedTerminalsFromEditor = [] } = state + private buildTerminalLogsPrompt(props: ConversationWithStateProps): string { + const { mentionState } = props - if (!selectedTerminalsFromEditor.length) return '' + if (!mentionState?.selectedTerminals.length) return '' let terminalContent = ` ## Terminal Logs ` - selectedTerminalsFromEditor.forEach(terminal => { + mentionState.selectedTerminals.forEach(terminal => { terminalContent += ` Terminal: ${terminal.name} ${terminal.commands.map(cmd => this.buildTerminalCommandPrompt(cmd)).join('\n')} diff --git a/src/shared/plugins/terminal-plugin/server/terminal-mention-utils-provider.ts b/src/shared/plugins/terminal-plugin/server/terminal-mention-utils-provider.ts new file mode 100644 index 0000000..f9b3d07 --- /dev/null +++ b/src/shared/plugins/terminal-plugin/server/terminal-mention-utils-provider.ts @@ -0,0 +1,32 @@ +import type { ControllerRegister } from '@extension/registers/controller-register' +import type { Mention } from '@shared/entities' +import type { MentionUtilsProvider } from '@shared/plugins/base/server/create-provider-manager' + +import { TerminalMentionType } from '../types' + +export class TerminalMentionUtilsProvider implements MentionUtilsProvider { + async createRefreshMentionFn(controllerRegister: ControllerRegister) { + const terminals = await controllerRegister + .api('terminal') + .getTerminalsForMention() + + // Create a map of terminal processIds for quick lookup + const terminalMap = new Map( + terminals.map(terminal => [terminal.processId, terminal]) + ) + + return (_mention: Mention) => { + const mention = { ..._mention } as Mention + switch (mention.type) { + case TerminalMentionType.Terminal: + const terminal = terminalMap.get(mention.data.processId) + if (terminal) mention.data = terminal + break + default: + break + } + + return mention + } + } +} diff --git a/src/shared/plugins/terminal-plugin/server/terminal-server-plugin.ts b/src/shared/plugins/terminal-plugin/server/terminal-server-plugin.ts index 8e1d371..792c2b3 100644 --- a/src/shared/plugins/terminal-plugin/server/terminal-server-plugin.ts +++ b/src/shared/plugins/terminal-plugin/server/terminal-server-plugin.ts @@ -7,6 +7,7 @@ import { pkg } from '@shared/utils/pkg' import type { TerminalPluginState } from '../types' import { TerminalChatStrategyProvider } from './chat-strategy/terminal-chat-strategy-provider' +import { TerminalMentionUtilsProvider } from './terminal-mention-utils-provider' export class TerminalServerPlugin implements ServerPlugin { id = PluginId.Terminal @@ -24,6 +25,11 @@ export class TerminalServerPlugin implements ServerPlugin { 'chatStrategy', () => new TerminalChatStrategyProvider() ) + + this.context.registerProvider( + 'mentionUtils', + () => new TerminalMentionUtilsProvider() + ) } deactivate(): void { diff --git a/src/shared/plugins/terminal-plugin/terminal-mentions-to-state.ts b/src/shared/plugins/terminal-plugin/terminal-mentions-to-state.ts new file mode 100644 index 0000000..a368eea --- /dev/null +++ b/src/shared/plugins/terminal-plugin/terminal-mentions-to-state.ts @@ -0,0 +1,14 @@ +import { BaseToState } from '../base/base-to-state' +import { TerminalMentionType, type TerminalMention } from './types' + +export class TerminalToState extends BaseToState { + toMentionsState() { + return { + selectedTerminals: this.getMentionDataByType(TerminalMentionType.Terminal) + } + } + + toAgentsState() { + return {} + } +} diff --git a/src/shared/plugins/terminal-plugin/types.ts b/src/shared/plugins/terminal-plugin/types.ts index 26a2cb1..ab596de 100644 --- a/src/shared/plugins/terminal-plugin/types.ts +++ b/src/shared/plugins/terminal-plugin/types.ts @@ -1,11 +1,21 @@ import type { TerminalInfo } from '@extension/registers/terminal-watcher-register' +import type { Mention } from '@shared/entities' + +import { PluginId } from '../base/types' export type { TerminalInfo, TerminalCommand } from '@extension/registers/terminal-watcher-register' -export interface TerminalPluginState { - selectedTerminalsFromEditor: TerminalInfo[] - terminalLogsFromAgent: TerminalInfo[] +export enum TerminalMentionType { + Terminals = `${PluginId.Terminal}#terminals`, + Terminal = `${PluginId.Terminal}#terminal` } + +export type TerminalMention = Mention< + TerminalMentionType.Terminal, + TerminalInfo +> + +export interface TerminalPluginState {} diff --git a/src/shared/plugins/web-plugin/client/web-client-plugin.tsx b/src/shared/plugins/web-plugin/client/web-client-plugin.tsx index 59a3945..3a0abac 100644 --- a/src/shared/plugins/web-plugin/client/web-client-plugin.tsx +++ b/src/shared/plugins/web-plugin/client/web-client-plugin.tsx @@ -7,7 +7,7 @@ import { import { PluginId } from '@shared/plugins/base/types' import { pkg } from '@shared/utils/pkg' -import type { WebPluginState } from '../types' +import { WebMentionType, WebPluginState } from '../types' import { WebLogPreview } from './web-log-preview' export const WebClientPlugin = createClientPlugin({ @@ -15,13 +15,7 @@ export const WebClientPlugin = createClientPlugin({ version: pkg.version, getInitialState() { - return { - enableWebSearchAgent: false, - webSearchResultsFromAgent: [], - webSearchAsDocFromAgent: [], - enableWebVisitAgent: false, - webVisitResultsFromAgent: [] - } + return {} }, setup(props) { @@ -33,27 +27,17 @@ export const WebClientPlugin = createClientPlugin({ }) const createUseMentionOptions = - (props: SetupProps) => (): UseMentionOptionsReturns => { - const { setState } = props - - return [ - { - id: `${PluginId.Web}#web`, - type: `${PluginId.Web}#web`, - label: 'Web', - data: true, - onUpdatePluginState: (dataArr: true[]) => { - setState(draft => { - draft.enableWebVisitAgent = dataArr.length > 0 - draft.enableWebSearchAgent = dataArr.length > 0 - }) - }, - topLevelSort: 3, - searchKeywords: ['web', 'search'], - itemLayoutProps: { - icon: , - label: 'Web' - } + (props: SetupProps) => (): UseMentionOptionsReturns => [ + { + id: WebMentionType.Web, + type: WebMentionType.Web, + label: 'Web', + data: true, + topLevelSort: 3, + searchKeywords: ['web', 'search'], + itemLayoutProps: { + icon: , + label: 'Web' } - ] - } + } + ] diff --git a/src/shared/plugins/web-plugin/client/web-log-preview.tsx b/src/shared/plugins/web-plugin/client/web-log-preview.tsx index 56ecb6b..5944ebc 100644 --- a/src/shared/plugins/web-plugin/client/web-log-preview.tsx +++ b/src/shared/plugins/web-plugin/client/web-log-preview.tsx @@ -1,44 +1,60 @@ -import { FC } from 'react' +import { FC, type ReactNode } from 'react' import { GlobeIcon } from '@radix-ui/react-icons' +import { + webSearchAgentName, + webVisitAgentName +} from '@shared/plugins/agents/agent-names' +import type { WebSearchAgent } from '@shared/plugins/agents/web-search-agent' +import type { WebVisitAgent } from '@shared/plugins/agents/web-visit-agent' import type { CustomRenderLogPreviewProps } from '@shared/plugins/base/client/client-plugin-types' -import { PluginId } from '@shared/plugins/base/types' +import type { GetAgent } from '@shared/plugins/base/strategies' import { ChatLogPreview } from '@webview/components/chat/messages/roles/chat-log-preview' import type { PreviewContent } from '@webview/components/content-preview' import { ContentPreviewPopover } from '@webview/components/content-preview-popover' import { cn } from '@webview/utils/common' -import type { WebDocInfo, WebPluginLog } from '../types' +import type { WebDocInfo } from '../types' export const WebLogPreview: FC = props => { - if (props.log.pluginId !== PluginId.Web) return null - const log = props.log as WebPluginLog + const { log } = props + const { agent } = log - return ( + const renderWrapper = (children: ReactNode) => ( -
- {/* Search Results Section */} - {log.webSearchResultsFromAgent?.map((doc, index) => ( - - ))} - - {/* Visit Results Section */} - {log.webVisitResultsFromAgent?.map((doc, index) => ( - - ))} -
+
{children}
) + + if (!agent) return null + + switch (agent.name) { + case webSearchAgentName: + return renderWrapper( + (agent as GetAgent).output.webSearchResults?.map( + (doc, index) => ( + + ) + ) + ) + case webVisitAgentName: + return renderWrapper( + (agent as GetAgent).output.contents?.map( + (doc, index) => ( + + ) + ) + ) + default: + return null + } } interface WebDocItemProps { diff --git a/src/shared/plugins/web-plugin/server/chat-strategy/web-chat-strategy-provider.ts b/src/shared/plugins/web-plugin/server/chat-strategy/web-chat-strategy-provider.ts index adee023..c8c9c3d 100644 --- a/src/shared/plugins/web-plugin/server/chat-strategy/web-chat-strategy-provider.ts +++ b/src/shared/plugins/web-plugin/server/chat-strategy/web-chat-strategy-provider.ts @@ -1,27 +1,46 @@ -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' -import type { - ChatGraphNode, - ChatGraphState -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' import type { StructuredTool } from '@langchain/core/tools' import type { Conversation } from '@shared/entities' +import type { + GetAgentState, + GetMentionState +} from '@shared/plugins/base/base-to-state' import type { ChatStrategyProvider } from '@shared/plugins/base/server/create-provider-manager' -import { PluginId } from '@shared/plugins/base/types' +import { + createGraphNodeFromNodes, + createToolsFromNodes +} from '@shared/plugins/base/strategies' +import type { + BaseStrategyOptions, + ChatGraphNode, + ChatGraphState +} from '@shared/plugins/base/strategies' import { removeDuplicates } from '@shared/utils/common' -import type { WebPluginState } from '../../types' -import { createWebSearchNode, createWebSearchTool } from './web-search-node' -import { createWebVisitNode, createWebVisitTool } from './web-visit-node' +import { WebToState } from '../../web-to-state' +import { WebSearchNode } from './web-search-node' +import { WebVisitNode } from './web-visit-node' + +interface ConversationWithStateProps { + conversation: Conversation + mentionState: GetMentionState + agentState: GetAgentState +} export class WebChatStrategyProvider implements ChatStrategyProvider { - async buildContextMessagePrompt(conversation: Conversation): Promise { - const state = conversation.pluginStates?.[PluginId.Web] as - | Partial - | undefined + private createConversationWithStateProps( + conversation: Conversation + ): ConversationWithStateProps { + const webToState = new WebToState(conversation) + const mentionState = webToState.toMentionsState() + const agentState = webToState.toAgentsState() + + return { conversation, mentionState, agentState } + } - if (!state) return '' + async buildContextMessagePrompt(conversation: Conversation): Promise { + const props = this.createConversationWithStateProps(conversation) - const relevantWebsPrompt = this.buildRelevantWebsPrompt(state) + const relevantWebsPrompt = this.buildRelevantWebsPrompt(props) const prompts = [relevantWebsPrompt].filter(Boolean) @@ -29,29 +48,35 @@ export class WebChatStrategyProvider implements ChatStrategyProvider { } async buildAgentTools( - options: BaseStrategyOptions, + strategyOptions: BaseStrategyOptions, state: ChatGraphState ): Promise { - const tools = await Promise.all([ - createWebSearchTool(options, state), - createWebVisitTool(options, state) - ]) - return tools.filter(Boolean) as StructuredTool[] + return await createToolsFromNodes({ + nodeClasses: [WebSearchNode, WebVisitNode], + strategyOptions, + state + }) } async buildLanggraphToolNodes( - options: BaseStrategyOptions + strategyOptions: BaseStrategyOptions ): Promise { - return [createWebSearchNode(options), createWebVisitNode(options)] + return await createGraphNodeFromNodes({ + nodeClasses: [WebSearchNode, WebVisitNode], + strategyOptions + }) } - private buildRelevantWebsPrompt(state: Partial): string { - const { webSearchAsDocFromAgent = [], webVisitResultsFromAgent = [] } = - state + private buildRelevantWebsPrompt(props: ConversationWithStateProps): string { + const { agentState } = props + const { webSearchRelevantContent = [], webVisitContents = [] } = agentState const webDocs = [ - ...webSearchAsDocFromAgent, - ...removeDuplicates(webVisitResultsFromAgent, ['url']) + ...webSearchRelevantContent.map(content => ({ + url: '', + content + })), + ...removeDuplicates(webVisitContents, ['url']) ] if (!webDocs.length) return '' diff --git a/src/shared/plugins/web-plugin/server/chat-strategy/web-search-node.ts b/src/shared/plugins/web-plugin/server/chat-strategy/web-search-node.ts index 5b3ce9f..3384137 100644 --- a/src/shared/plugins/web-plugin/server/chat-strategy/web-search-node.ts +++ b/src/shared/plugins/web-plugin/server/chat-strategy/web-search-node.ts @@ -1,246 +1,51 @@ -import { ModelProviderFactory } from '@extension/ai/model-providers/helpers/factory' -import { logger } from '@extension/logger' -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' -import { ChatMessagesConstructor } from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/messages-constructors/chat-messages-constructor' +import { WebSearchAgent } from '@shared/plugins/agents/web-search-agent' import { - dispatchChatGraphState, - type ChatGraphState, - type CreateChatGraphNode -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' -import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params' -import { searxngSearch } from '@extension/webview-api/chat-context-processor/utils/searxng-search' -import { CheerioWebBaseLoader } from '@langchain/community/document_loaders/web/cheerio' -import type { Document } from '@langchain/core/documents' -import { HumanMessage, type ToolMessage } from '@langchain/core/messages' -import { DynamicStructuredTool } from '@langchain/core/tools' -import { FeatureModelSettingKey } from '@shared/entities' -import type { ConversationLog, LangchainTool } from '@shared/entities' -import { PluginId } from '@shared/plugins/base/types' -import { settledPromiseResults } from '@shared/utils/common' -import { produce } from 'immer' -import { v4 as uuidv4 } from 'uuid' -import { z } from 'zod' - -import type { WebDocInfo, WebPluginLog, WebPluginState } from '../../types' - -interface WebSearchToolResult { - relevantContent: string - webSearchResults: WebDocInfo[] -} - -const MAX_CONTENT_LENGTH = 16 * 1000 - -export const createWebSearchTool = async ( - options: BaseStrategyOptions, - state: ChatGraphState -) => { - const { conversations } = state.chatContext - const lastConversation = conversations.at(-1) - - const webPluginState = lastConversation?.pluginStates?.[PluginId.Web] as - | Partial - | undefined - - if (!webPluginState?.enableWebSearchAgent) return null - - const getRelevantContentAndSearchResults = async ( - state: ChatGraphState, - keywords: string - ) => { - const searxngSearchResult = await searxngSearch(keywords, { - abortController: state.abortController - }) - const urls = searxngSearchResult.results.map(result => result.url) - - const docsLoadResult = await settledPromiseResults( - urls.map(url => new CheerioWebBaseLoader(url).load()) - ) - - const docs: Document>[] = docsLoadResult.flat() - - const docsContent = docs - .map(doc => doc.pageContent) - .join('\n') - .slice(0, MAX_CONTENT_LENGTH) - - if (!docsContent) { - logger.warn('No content found in web search results', { - keywords, - docs + BaseNode, + dispatchBaseGraphState, + type ChatGraphState +} from '@shared/plugins/base/strategies' + +import { WebToState } from '../../web-to-state' + +export class WebSearchNode extends BaseNode { + onInit() { + this.registerAgentConfig(WebSearchAgent.name, state => { + const lastConversation = state.chatContext.conversations.at(-1) + const mentionState = new WebToState(lastConversation).toMentionsState() + + return this.createAgentConfig({ + agentClass: WebSearchAgent, + agentContext: { + state, + strategyOptions: this.context.strategyOptions, + createToolOptions: { + enableWebSearchAgent: mentionState.enableWebSearchAgent + } + } }) - return { relevantContent: '', webSearchResults: [] } - } - const chatMessagesConstructor = new ChatMessagesConstructor({ - ...options, - chatContext: state.chatContext }) - const messagesFromChatContext = - await chatMessagesConstructor.constructMessages() - - const modelProvider = await ModelProviderFactory.getModelProvider( - FeatureModelSettingKey.Chat - ) - const aiModel = await modelProvider.createLangChainModel() - - const response = await aiModel - .bind({ signal: state.abortController?.signal }) - .invoke([ - ...messagesFromChatContext.slice(-2), - new HumanMessage({ - content: ` -You are an expert information analyst. Your task is to process web search results and create a high-quality, focused summary that will be used in a subsequent AI conversation. Follow these critical guidelines: - -1. RELEVANCE & FOCUS -- Identify and extract ONLY information that directly addresses the user's query -- Eliminate tangential or loosely related content -- Preserve technical details and specific examples when relevant - -2. INFORMATION QUALITY -- Prioritize factual, verifiable information -- Include specific technical details, numbers, or metrics when present -- Maintain technical accuracy in specialized topics - -3. STRUCTURE & CLARITY -- Present information in a logical, well-structured format -- Use clear, precise language -- Preserve important technical terms and concepts - -4. BALANCED PERSPECTIVE -- Include multiple viewpoints when present -- Note any significant disagreements or contradictions -- Indicate if information seems incomplete or uncertain - -5. CONTEXT PRESERVATION -- Maintain crucial context that affects meaning -- Include relevant dates or version information for technical content -- Preserve attribution for significant claims or findings - -Here's the content to analyze: - -""" -${docsContent} -""" - -Provide a focused, technical summary that will serve as high-quality context for the next phase of AI conversation.` - }) - ]) - - return { - relevantContent: - typeof response.content === 'string' - ? response.content - : JSON.stringify(response.content), - webSearchResults: searxngSearchResult.results - } } - return new DynamicStructuredTool({ - name: 'webSearch', - description: - 'IMPORTANT: Proactively use this web search tool whenever you:\n' + - '1. Need to verify or update your knowledge about recent developments, versions, or current facts\n' + - '2. Are unsure about specific technical details or best practices\n' + - '3. Need real-world examples or implementation details\n' + - '4. Encounter questions about:\n' + - ' - Current events or recent updates\n' + - ' - Latest software versions or features\n' + - ' - Modern best practices or trends\n' + - ' - Specific technical implementations\n' + - '5. Want to provide evidence-based recommendations\n\n' + - 'DO NOT rely solely on your training data when users ask about:\n' + - '- Recent technologies or updates\n' + - '- Current best practices\n' + - '- Specific implementation details\n' + - '- Version-specific features or APIs\n' + - 'Instead, use this tool to get up-to-date information.', - func: async ({ keywords }): Promise => { - const { relevantContent, webSearchResults } = - await getRelevantContentAndSearchResults(state, keywords) - - return { - relevantContent, - webSearchResults - } - }, - schema: z.object({ - keywords: z.string().describe('Keywords to search web') + async execute(state: ChatGraphState) { + const toolCallsResults = await this.executeAgentTool(state, { + agentClass: WebSearchAgent }) - }) -} - -export const createWebSearchNode: CreateChatGraphNode = - options => async state => { - const { conversations } = state.chatContext - const lastConversation = conversations.at(-1) - const webPluginState = lastConversation?.pluginStates?.[PluginId.Web] as - | Partial - | undefined - const logs: ConversationLog[] = [] - - if (!webPluginState?.enableWebSearchAgent) return {} - - const webRetrieverTool = await createWebSearchTool(options, state) - - if (!webRetrieverTool) return {} - - const tools: LangchainTool[] = [webRetrieverTool] - const lastMessage = state.messages.at(-1) - const toolCalls = findCurrentToolsCallParams(lastMessage, tools) - - if (!toolCalls.length) return {} - const toolCallsPromises = toolCalls.map(async toolCall => { - const toolMessage = (await webRetrieverTool.invoke( - toolCall - )) as ToolMessage - const result = JSON.parse( - toolMessage?.lc_kwargs.content - ) as WebSearchToolResult + if (!toolCallsResults.agents.length) return {} - lastConversation!.pluginStates![PluginId.Web] = produce( - lastConversation!.pluginStates![ - PluginId.Web - ] as Partial, - (draft: Partial) => { - if (!draft.webSearchAsDocFromAgent) { - draft.webSearchAsDocFromAgent = [] - } - - if (!draft.webSearchResultsFromAgent) { - draft.webSearchResultsFromAgent = [] - } - - draft.webSearchAsDocFromAgent.push({ - url: '', - content: result.relevantContent - }) - - draft.webSearchResultsFromAgent.push(...result.webSearchResults) - } - ) + const newConversation = state.newConversations.at(-1)! - logs.push({ - id: uuidv4(), - createdAt: Date.now(), - pluginId: PluginId.Web, - title: 'Search web', - webSearchResultsFromAgent: result.webSearchResults - } satisfies WebPluginLog) - }) - - await settledPromiseResults(toolCallsPromises) - - const newConversations = produce(state.newConversations, draft => { - draft.at(-1)!.logs.push(...logs) - }) + this.addAgentsToConversation(newConversation, toolCallsResults.agents) + this.addLogsToConversation(newConversation, toolCallsResults.logs) - dispatchChatGraphState({ - newConversations, - chatContext: state.chatContext + dispatchBaseGraphState({ + chatContext: state.chatContext, + newConversations: state.newConversations }) return { chatContext: state.chatContext, - newConversations + newConversations: state.newConversations } } +} diff --git a/src/shared/plugins/web-plugin/server/chat-strategy/web-visit-node.ts b/src/shared/plugins/web-plugin/server/chat-strategy/web-visit-node.ts index b50d9a6..f1062b5 100644 --- a/src/shared/plugins/web-plugin/server/chat-strategy/web-visit-node.ts +++ b/src/shared/plugins/web-plugin/server/chat-strategy/web-visit-node.ts @@ -1,129 +1,51 @@ -/* eslint-disable unused-imports/no-unused-vars */ -import type { BaseStrategyOptions } from '@extension/webview-api/chat-context-processor/strategies/base-strategy' +import { WebVisitAgent } from '@shared/plugins/agents/web-visit-agent' import { - dispatchChatGraphState, - type ChatGraphState, - type CreateChatGraphNode -} from '@extension/webview-api/chat-context-processor/strategies/chat-strategy/nodes/state' -import { DocCrawler } from '@extension/webview-api/chat-context-processor/utils/doc-crawler' -import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params' -import type { ToolMessage } from '@langchain/core/messages' -import { DynamicStructuredTool } from '@langchain/core/tools' -import type { ConversationLog, LangchainTool } from '@shared/entities' -import { PluginId } from '@shared/plugins/base/types' -import { settledPromiseResults } from '@shared/utils/common' -import { produce } from 'immer' -import { v4 as uuidv4 } from 'uuid' -import { z } from 'zod' - -import type { WebDocInfo, WebPluginLog, WebPluginState } from '../../types' - -interface WebVisitToolResult { - contents: WebDocInfo[] -} - -export const createWebVisitTool = async ( - options: BaseStrategyOptions, - state: ChatGraphState -) => { - const getPageContents = async (urls: string[]): Promise => { - const docCrawler = new DocCrawler(urls[0]!) - const contents = await settledPromiseResults( - urls.map(async url => ({ - url, - content: - (await docCrawler.getPageContent(url)) || 'Failed to retrieve content' - })) - ) - return contents - } - - return new DynamicStructuredTool({ - name: 'webVisit', - description: - 'A tool for visiting and extracting content from web pages. Use this tool when you need to:\n' + - '1. Analyze specific webpage content in detail\n' + - '2. Extract information from known URLs\n' + - '3. Compare content across multiple web pages\n' + - '4. Verify or fact-check information from web sources\n' + - 'Note: Only use this for specific URLs you want to analyze, not for general web searches.', - func: async ({ urls }): Promise => { - const contents = await getPageContents(urls) - return { contents } - }, - schema: z.object({ - urls: z - .array(z.string().url()) - .describe( - 'An array of URLs to visit and retrieve content from. Each URL should be a valid web address.' - ) - }) - }) -} - -export const createWebVisitNode: CreateChatGraphNode = - options => async state => { - const { conversations } = state.chatContext - const lastConversation = conversations.at(-1) - const webPluginState = lastConversation?.pluginStates?.[PluginId.Web] as - | Partial - | undefined - const logs: ConversationLog[] = [] - - if (!webPluginState?.enableWebVisitAgent) return {} - - const webVisitTool = await createWebVisitTool(options, state) - - if (!webVisitTool) return {} - - const tools: LangchainTool[] = [webVisitTool] - const lastMessage = state.messages.at(-1) - const toolCalls = findCurrentToolsCallParams(lastMessage, tools) - - if (!toolCalls.length) return {} - - const toolCallsPromises = toolCalls.map(async toolCall => { - const toolMessage = (await webVisitTool.invoke(toolCall)) as ToolMessage - - const result = JSON.parse( - toolMessage?.lc_kwargs.content - ) as WebVisitToolResult - - lastConversation!.pluginStates![PluginId.Web] = produce( - lastConversation!.pluginStates![ - PluginId.Web - ] as Partial, - (draft: Partial) => { - if (!draft.webVisitResultsFromAgent) { - draft.webVisitResultsFromAgent = [] + BaseNode, + dispatchBaseGraphState, + type ChatGraphState +} from '@shared/plugins/base/strategies' + +import { WebToState } from '../../web-to-state' + +export class WebVisitNode extends BaseNode { + onInit() { + this.registerAgentConfig(WebVisitAgent.name, state => { + const lastConversation = state.chatContext.conversations.at(-1) + const mentionState = new WebToState(lastConversation).toMentionsState() + + return this.createAgentConfig({ + agentClass: WebVisitAgent, + agentContext: { + state, + strategyOptions: this.context.strategyOptions, + createToolOptions: { + enableWebVisitAgent: mentionState.enableWebVisitAgent } - - draft.webVisitResultsFromAgent.push(...result.contents) } - ) + }) + }) + } - logs.push({ - id: uuidv4(), - createdAt: Date.now(), - pluginId: PluginId.Web, - title: 'Visit web', - webVisitResultsFromAgent: result.contents - } satisfies WebPluginLog) + async execute(state: ChatGraphState) { + const toolCallsResults = await this.executeAgentTool(state, { + agentClass: WebVisitAgent }) - await settledPromiseResults(toolCallsPromises) + if (!toolCallsResults.agents.length) return {} - const newConversations = produce(state.newConversations, draft => { - draft.at(-1)!.logs.push(...logs) - }) + const newConversation = state.newConversations.at(-1)! - dispatchChatGraphState({ - newConversations, - chatContext: state.chatContext + this.addAgentsToConversation(newConversation, toolCallsResults.agents) + this.addLogsToConversation(newConversation, toolCallsResults.logs) + + dispatchBaseGraphState({ + chatContext: state.chatContext, + newConversations: state.newConversations }) return { chatContext: state.chatContext, - newConversations + newConversations: state.newConversations } } +} diff --git a/src/shared/plugins/web-plugin/server/web-mention-utils-provider.ts b/src/shared/plugins/web-plugin/server/web-mention-utils-provider.ts new file mode 100644 index 0000000..1d4afc5 --- /dev/null +++ b/src/shared/plugins/web-plugin/server/web-mention-utils-provider.ts @@ -0,0 +1,22 @@ +import type { ControllerRegister } from '@extension/registers/controller-register' +import type { Mention } from '@shared/entities' +import type { MentionUtilsProvider } from '@shared/plugins/base/server/create-provider-manager' + +import { WebMentionType } from '../types' + +export class WebMentionUtilsProvider implements MentionUtilsProvider { + // eslint-disable-next-line unused-imports/no-unused-vars + async createRefreshMentionFn(controllerRegister: ControllerRegister) { + return (_mention: Mention) => { + const mention = { ..._mention } as Mention + switch (mention.type) { + case WebMentionType.Web: + // Web mention is just a boolean flag, no need to refresh + break + default: + break + } + return mention + } + } +} diff --git a/src/shared/plugins/web-plugin/server/web-server-plugin.ts b/src/shared/plugins/web-plugin/server/web-server-plugin.ts index 3e55453..b50c8f3 100644 --- a/src/shared/plugins/web-plugin/server/web-server-plugin.ts +++ b/src/shared/plugins/web-plugin/server/web-server-plugin.ts @@ -7,6 +7,7 @@ import { pkg } from '@shared/utils/pkg' import type { WebPluginState } from '../types' import { WebChatStrategyProvider } from './chat-strategy/web-chat-strategy-provider' +import { WebMentionUtilsProvider } from './web-mention-utils-provider' export class WebServerPlugin implements ServerPlugin { id = PluginId.Web @@ -22,6 +23,11 @@ export class WebServerPlugin implements ServerPlugin { 'chatStrategy', () => new WebChatStrategyProvider() ) + + this.context.registerProvider( + 'mentionUtils', + () => new WebMentionUtilsProvider() + ) } deactivate(): void { diff --git a/src/shared/plugins/web-plugin/types.ts b/src/shared/plugins/web-plugin/types.ts index aeffa95..4133c84 100644 --- a/src/shared/plugins/web-plugin/types.ts +++ b/src/shared/plugins/web-plugin/types.ts @@ -1,22 +1,16 @@ -import type { BaseConversationLog } from '@shared/entities' +import type { Mention } from '@shared/entities' -import type { PluginId } from '../base/types' +import { PluginId } from '../base/types' + +export enum WebMentionType { + Web = `${PluginId.Web}#web` +} + +export type WebMention = Mention export interface WebDocInfo { content: string url: string } -export interface WebPluginState { - enableWebSearchAgent: boolean - webSearchResultsFromAgent: WebDocInfo[] - webSearchAsDocFromAgent: WebDocInfo[] - enableWebVisitAgent: boolean - webVisitResultsFromAgent: WebDocInfo[] -} - -export interface WebPluginLog extends BaseConversationLog { - pluginId: PluginId.Web - webSearchResultsFromAgent?: WebDocInfo[] - webVisitResultsFromAgent?: WebDocInfo[] -} +export interface WebPluginState {} diff --git a/src/shared/plugins/web-plugin/web-to-state.ts b/src/shared/plugins/web-plugin/web-to-state.ts new file mode 100644 index 0000000..62fb82e --- /dev/null +++ b/src/shared/plugins/web-plugin/web-to-state.ts @@ -0,0 +1,26 @@ +import { WebSearchAgent } from '../agents/web-search-agent' +import type { WebVisitAgent } from '../agents/web-visit-agent' +import { BaseToState } from '../base/base-to-state' +import { WebMentionType, type WebMention } from './types' + +export class WebToState extends BaseToState { + toMentionsState() { + return { + enableWebSearchAgent: this.isMentionExit(WebMentionType.Web), + enableWebVisitAgent: this.isMentionExit(WebMentionType.Web) + } + } + + toAgentsState() { + return { + webSearchRelevantContent: this.getAgentOutputsByKey< + WebSearchAgent, + 'relevantContent' + >(WebSearchAgent.name, 'relevantContent').flat(), + webVisitContents: this.getAgentOutputsByKey( + 'webVisit', + 'contents' + ).flat() + } + } +} diff --git a/src/shared/utils/convert-to-langchain-message-contents.ts b/src/shared/utils/convert-to-langchain-message-contents.ts index 95edc65..77bb5cc 100644 --- a/src/shared/utils/convert-to-langchain-message-contents.ts +++ b/src/shared/utils/convert-to-langchain-message-contents.ts @@ -40,10 +40,13 @@ export const convertToLangchainMessageContents = ( return [ { type: 'image_url', - image_url: - typeof _content.image_url === 'string' - ? _content.image_url - : _content.image_url?.url + image_url: { + url: + typeof _content.image_url === 'string' + ? _content.image_url + : _content.image_url?.url, + detail: _content?.image_url?.detail || undefined + } } ] } diff --git a/src/shared/utils/merge-langchain-message-contents.ts b/src/shared/utils/merge-langchain-message-contents.ts index f3ab738..57b7881 100644 --- a/src/shared/utils/merge-langchain-message-contents.ts +++ b/src/shared/utils/merge-langchain-message-contents.ts @@ -3,22 +3,23 @@ import type { LangchainMessageContents } from '@shared/entities' export const mergeLangchainMessageContents = ( contents: LangchainMessageContents ): LangchainMessageContents => { - let finalText = '' - const otherContents: LangchainMessageContents = [] + const finalContents: LangchainMessageContents = [] contents.forEach(content => { - if (content.type === 'text') { - finalText += content.text + const lastContent = finalContents.at(-1) + + if (!lastContent) { + finalContents.push(content) + return + } + + if (content.type === 'text' && lastContent.type === 'text') { + lastContent.text += content.text + finalContents[finalContents.length - 1] = lastContent } else { - otherContents.push(content) + finalContents.push(content) } }) - return [ - { - type: 'text', - text: finalText - }, - ...otherContents - ] + return finalContents } diff --git a/src/webview/components/chat/editor/chat-editor.tsx b/src/webview/components/chat/editor/chat-editor.tsx index 4e97ece..10102c0 100644 --- a/src/webview/components/chat/editor/chat-editor.tsx +++ b/src/webview/components/chat/editor/chat-editor.tsx @@ -1,4 +1,5 @@ import { useEffect, useId, useImperativeHandle, type FC, type Ref } from 'react' +import type { FileInfo } from '@extension/file-utils/traverse-fs' import { $generateHtmlFromNodes } from '@lexical/html' import { AutoFocusPlugin } from '@lexical/react/LexicalAutoFocusPlugin' import { @@ -12,7 +13,10 @@ import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin' import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin' import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin' import { TabIndentationPlugin } from '@lexical/react/LexicalTabIndentationPlugin' +import { type ImageInfo } from '@shared/entities' import { useQueryClient } from '@tanstack/react-query' +import { useDropHandler } from '@webview/lexical/hooks/use-drop-handler' +import { usePasteHandler } from '@webview/lexical/hooks/use-paste-handler' import { MentionNode } from '@webview/lexical/nodes/mention-node' import { MentionPlugin, @@ -54,6 +58,8 @@ export interface ChatEditorProps editor: LexicalEditor, tags: Set ) => void + onPasteImage?: (image: ImageInfo) => void + onDropFiles?: (files: FileInfo[]) => void } export interface ChatEditorRef { @@ -72,6 +78,8 @@ export const ChatEditor: FC = ({ autoFocus = false, onComplete, onChange, + onPasteImage, + onDropFiles, ...otherProps }) => { const id = useId() @@ -93,6 +101,8 @@ export const ChatEditor: FC = ({ autoFocus={autoFocus} onComplete={onComplete} onChange={onChange} + onPasteImage={onPasteImage} + onDropFiles={onDropFiles} {...otherProps} /> @@ -108,6 +118,8 @@ const ChatEditorInner: FC = ({ autoFocus, onComplete, onChange, + onPasteImage, + onDropFiles, // div props ...otherProps @@ -240,6 +252,16 @@ const ChatEditorInner: FC = ({ [editor, queryClient] ) + usePasteHandler({ + editor, + onPasteImage + }) + + useDropHandler({ + editor, + onDropFiles + }) + return (
= ({ }) => { const editorRef = useRef(null) const { setState: setPluginState, getState: getPluginState } = usePlugin() - const { selectedFiles, setSelectedFiles } = usePluginSelectedFilesProviders() - const mentionOptions = usePluginMentionOptions() // sync conversation plugin states with plugin registry useEffect(() => { @@ -96,10 +103,6 @@ const _ChatInput: FC = ({ ) ) } - }) - - updatePluginStatesFromEditorState(editorState, mentionOptions) - setConversation(draft => { draft.pluginStates = getPluginState() }) } @@ -144,12 +147,12 @@ const _ChatInput: FC = ({ await handleEditorChange(editorState) } - logger.verbose('send conversation', getConversation()) - onSend(getConversation()) - } + // refresh mentions + const newConversation = + await api.mention.refreshConversationMentions(getConversation()) - const handleSelectedFiles = (files: FileInfo[]) => { - setSelectedFiles?.(files) + logger.verbose('send conversation', newConversation) + onSend(newConversation) } const focusOnEditor = () => editorRef.current?.focusOnEditor() @@ -183,6 +186,24 @@ const _ChatInput: FC = ({ } } + const handlePasteImage = (image: ImageInfo) => { + setConversation(draft => { + draft.state.selectedImagesFromOutsideUrl = removeDuplicates( + [...draft.state.selectedImagesFromOutsideUrl, image], + ['url'] + ) + }) + } + + const handleDropFiles = (files: FileInfo[]) => { + setConversation(draft => { + draft.state.selectedFilesFromFileSelector = removeDuplicates( + [...draft.state.selectedFilesFromFileSelector, ...files], + ['fullPath'] + ) + }) + } + return ( = ({ className )} > - - {![ChatInputMode.MessageReadonly].includes(mode) && ( - - !isOpen && focusOnEditor()} - /> - - )} - + = ({ [ChatInputMode.MessageReadonly].includes(mode) && 'min-h-0 min-w-0 h-auto w-auto' )} + onPasteImage={handlePasteImage} + onDropFiles={handleDropFiles} /> @@ -292,7 +295,7 @@ const _ChatInput: FC = ({ variant="outline" disabled={sendButtonDisabled} size="xs" - className="ml-auto" + className="ml-auto rounded-md" onClick={handleSend} tooltip="You can use ⌘↩ to send message" > @@ -310,3 +313,88 @@ const _ChatInput: FC = ({ } export const ChatInput = WithPluginProvider(_ChatInput) + +interface AnimatedFileAttachmentsProps { + mode: ChatInputMode + conversation: Conversation + setConversation: Updater + onFocusEditor: () => void +} + +export const AnimatedFileAttachments: React.FC< + AnimatedFileAttachmentsProps +> = ({ mode, conversation, setConversation, onFocusEditor }) => { + const selectedFiles = conversation?.state?.selectedFilesFromFileSelector || [] + const setSelectedFiles = (files: FileInfo[]) => { + setConversation(draft => { + if (!draft.state) { + draft.state = new ConversationEntity().entity.state + } + draft.state.selectedFilesFromFileSelector = removeDuplicates(files, [ + 'fullPath' + ]) + }) + } + + const selectedOtherItems: FileAttachmentOtherItem[] = + conversation.state?.selectedImagesFromOutsideUrl?.map( + img => + ({ + id: img.url, + label: img.name || 'image', + type: 'image', + icon: , + item: img, + previewConfig: { + type: 'image', + url: img.url + } + }) satisfies FileAttachmentOtherItem + ) + + const setSelectedOtherItems = (items: FileAttachmentOtherItem[]) => { + setConversation(draft => { + const selectedImages: ImageInfo[] = [] + items.forEach(item => { + if (item.type === 'image') { + selectedImages.push(item.item as ImageInfo) + } + }) + + draft.state.selectedImagesFromOutsideUrl = removeDuplicates( + selectedImages, + ['url'] + ) + }) + } + + if ([ChatInputMode.MessageReadonly].includes(mode)) { + return null + } + + return ( + + + !isOpen && onFocusEditor()} + /> + + + ) +} diff --git a/src/webview/components/chat/editor/file-attachments.tsx b/src/webview/components/chat/editor/file-attachments.tsx index b290199..846e51a 100644 --- a/src/webview/components/chat/editor/file-attachments.tsx +++ b/src/webview/components/chat/editor/file-attachments.tsx @@ -7,15 +7,29 @@ import type { FileInfo } from '@webview/types/chat' import { cn } from '@webview/utils/common' import { getFileNameFromPath } from '@webview/utils/path' -import { ContentPreviewPopover } from '../../content-preview-popover' +import { + ContentPreviewPopover, + type ContentPreviewPopoverProps +} from '../../content-preview-popover' import { Popover, PopoverContent, PopoverTrigger } from '../../ui/popover' import { FileSelector } from '../selectors/file-selector' +export interface FileAttachmentOtherItem { + id: string + label: string + type: string + item: Record + previewConfig: ContentPreviewPopoverProps['content'] + icon?: React.ReactNode +} + interface FileAttachmentsProps { className?: string showFileSelector?: boolean selectedFiles: FileInfo[] + selectedOtherItems?: FileAttachmentOtherItem[] onSelectedFilesChange: (files: FileInfo[]) => void + onSelectedOtherItemsChange?: (items: FileAttachmentOtherItem[]) => void onOpenChange?: (isOpen: boolean) => void } @@ -23,7 +37,9 @@ export const FileAttachments: React.FC = ({ className, showFileSelector = true, selectedFiles, + selectedOtherItems, onSelectedFilesChange, + onSelectedOtherItemsChange, onOpenChange }) => { const [showMore, setShowMore] = useState(false) @@ -37,7 +53,7 @@ export const FileAttachments: React.FC = ({ const checkOverflow = () => { setShowMore(container.scrollHeight > container.clientHeight) - const items = container.querySelectorAll('.file-item') + const items = container.querySelectorAll('.file-attachment-item') let count = 0 for (const item of items) { if (item.getBoundingClientRect().top < container.clientHeight) { @@ -60,14 +76,22 @@ export const FileAttachments: React.FC = ({ ) } + const handleRemoveOtherItem = (item: FileAttachmentOtherItem) => { + onSelectedOtherItemsChange?.( + selectedOtherItems?.filter(i => i.id !== item.id) ?? [] + ) + } + const renderFileItem = (file: FileInfo) => ( -
+
-
{getFileNameFromPath(file.fullPath)}
+
+ {getFileNameFromPath(file.fullPath)} +
{ @@ -79,6 +103,24 @@ export const FileAttachments: React.FC = ({ ) + const renderOtherItem = (item: FileAttachmentOtherItem) => ( + +
+ {item.icon} +
+ {item.label} +
+ { + e.stopPropagation() + handleRemoveOtherItem(item) + }} + /> +
+
+ ) + return (
= ({ size="xsss" className="cursor-pointer mt-2 mr-2 self-start" > - ...{selectedFiles.length - visibleCount} more + ... + {selectedFiles.length + + (selectedOtherItems?.length || 0) - + visibleCount}{' '} + more - +
+ {selectedOtherItems?.map(renderOtherItem)} {selectedFiles.map(renderFileItem)}
)} + {selectedOtherItems?.map(renderOtherItem)} {selectedFiles.map(renderFileItem)}
) diff --git a/src/webview/components/chat/messages/roles/chat-ai-message.tsx b/src/webview/components/chat/messages/roles/chat-ai-message.tsx index 164dad9..68943e0 100644 --- a/src/webview/components/chat/messages/roles/chat-ai-message.tsx +++ b/src/webview/components/chat/messages/roles/chat-ai-message.tsx @@ -31,7 +31,7 @@ const _ChatAIMessage: FC = props => {
= ({ conversation }) => { - const logs = conversation.logs || [] + const logs = toLogWithAgent(conversation) + const customRenderLogPreview = usePluginCustomRenderLogPreview() if (logs.length === 0) return null @@ -63,9 +65,11 @@ export const ChatAIMessageLogPreview: FC<{ conversation: Conversation }> = ({
{logs.map((log, index) => (
- {customRenderLogPreview?.({ - log - }) || } + {log.agent && customRenderLogPreview ? ( + customRenderLogPreview({ log }) + ) : ( + + )}
))}
diff --git a/src/webview/components/chat/messages/toolbars/base-toolbar.tsx b/src/webview/components/chat/messages/toolbars/base-toolbar.tsx index 2fcd275..076fbb8 100644 --- a/src/webview/components/chat/messages/toolbars/base-toolbar.tsx +++ b/src/webview/components/chat/messages/toolbars/base-toolbar.tsx @@ -115,7 +115,7 @@ export const BaseToolbar: FC = ({ > {/* blur overlay */}
= ({ showExitEditModeButton, onExitEditMode }) => { - const { addSelectedImage } = usePluginSelectedImagesProviders() + const addSelectedImage = (image: ImageInfo) => { + setConversation(draft => { + draft.state.selectedImagesFromOutsideUrl = removeDuplicates( + [...draft.state.selectedImagesFromOutsideUrl, image], + ['url'] + ) + }) + } const handleSelectImage = () => { const input = document.createElement('input') @@ -45,7 +53,10 @@ export const ContextSelector: React.FC = ({ const reader = new FileReader() reader.onload = e => { const base64Image = e.target?.result as string - addSelectedImage?.({ url: base64Image }) + addSelectedImage?.({ + url: base64Image, + name: file.name + }) } reader.readAsDataURL(file) } diff --git a/src/webview/components/chat/selectors/mention-selector/mention-selector.tsx b/src/webview/components/chat/selectors/mention-selector/mention-selector.tsx index 7669265..5e24068 100644 --- a/src/webview/components/chat/selectors/mention-selector/mention-selector.tsx +++ b/src/webview/components/chat/selectors/mention-selector/mention-selector.tsx @@ -123,7 +123,7 @@ export const MentionSelector: React.FC = ({ {children} div]:flex-col-reverse', + 'min-w-[200px] max-w-[400px] w-screen p-0 rounded-none bg-transparent shadow-none border-none [&[data-side="bottom"]>div]:flex-col-reverse', !isOpen && 'hidden' )} innerClassName="flex flex-col gap-4 overflow-hidden" @@ -137,7 +137,7 @@ export const MentionSelector: React.FC = ({ >
= ({ )}
-
+
void diff --git a/src/webview/components/global-search/global-search.tsx b/src/webview/components/global-search/global-search.tsx index 6022a43..24970c0 100644 --- a/src/webview/components/global-search/global-search.tsx +++ b/src/webview/components/global-search/global-search.tsx @@ -158,7 +158,7 @@ export const GlobalSearch: React.FC = ({ = ({
{isOpen && focusedItem?.renderPreview ? ( -
+
{focusedItem.renderPreview()}
) : null} diff --git a/src/webview/components/ui/command.tsx b/src/webview/components/ui/command.tsx index 7900099..84f5df5 100644 --- a/src/webview/components/ui/command.tsx +++ b/src/webview/components/ui/command.tsx @@ -18,7 +18,7 @@ const Command: React.FC< = ({ ref, ...props }) => ( ) @@ -158,7 +158,7 @@ const CommandItem: React.FC< = ({ align={align} sideOffset={sideOffset} className={cn( - 'z-50 w-72 overflow-x-hidden rounded-md shadow-md border bg-popover p-4 text-popover-foreground outline-none', + 'z-50 w-72 overflow-x-hidden rounded-2xl shadow-md border bg-popover p-4 text-popover-foreground outline-none', className )} onFocusOutside={e => e.preventDefault()} diff --git a/src/webview/components/ui/tabs.tsx b/src/webview/components/ui/tabs.tsx index 43bc019..4c5c3d7 100644 --- a/src/webview/components/ui/tabs.tsx +++ b/src/webview/components/ui/tabs.tsx @@ -22,13 +22,14 @@ const tabListVariants = cva( ) const tabsTriggerVariants = cva( - 'inline-flex items-center justify-center whitespace-nowrap text-sm font-medium ring-offset-background transition-all focus-visible:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=active]:bg-background data-[state=active]:text-foreground data-[state=active]:shadow px-3 py-1 rounded-md', + 'inline-flex items-center justify-center whitespace-nowrap text-sm font-medium ring-offset-background transition-all focus-visible:outline-none disabled:pointer-events-none disabled:opacity-50 px-3 py-1 rounded-md', { variants: { mode: { - default: '', + default: + 'data-[state=active]:bg-background data-[state=active]:text-foreground data-[state=active]:shadow', underlined: - 'relative text-md h-full rounded-none border-b-2 border-b-transparent bg-transparent px-4 pb-1 pt-1 font-semibold text-foreground/60 shadow-none transition-none data-[state=active]:border-b-primary data-[state=active]:text-foreground data-[state=active]:bg-background data-[state=active]:shadow-none' + 'relative text-md h-full rounded-none border-b-2 border-b-transparent bg-transparent px-4 pb-1 pt-1 font-semibold text-foreground/60 shadow-none transition-none data-[state=active]:border-b-primary' } }, defaultVariants: { diff --git a/src/webview/hooks/chat/use-plugin-providers.tsx b/src/webview/hooks/chat/use-plugin-providers.tsx index 38c5af0..2ff6714 100644 --- a/src/webview/hooks/chat/use-plugin-providers.tsx +++ b/src/webview/hooks/chat/use-plugin-providers.tsx @@ -1,9 +1,7 @@ import { Fragment, type FC } from 'react' import type { CustomRenderLogPreviewProps, - UseMentionOptionsReturns, - UseSelectedFilesReturns, - UseSelectedImagesReturns + UseMentionOptionsReturns } from '@shared/plugins/base/client/client-plugin-types' import { usePlugin } from '@webview/contexts/plugin-context' @@ -17,34 +15,6 @@ export const usePluginCustomRenderLogPreview = () => { return CustomRenderLogPreview } -export const usePluginSelectedFilesProviders = (): UseSelectedFilesReturns => { - const { mergeProviders } = usePlugin() - const useSelectedFiles = mergeProviders('useSelectedFiles') - - return ( - // eslint-disable-next-line react-compiler/react-compiler - useSelectedFiles?.() || { - selectedFiles: [], - setSelectedFiles: () => {} - } - ) -} - -export const usePluginSelectedImagesProviders = - (): UseSelectedImagesReturns => { - const { mergeProviders } = usePlugin() - const useSelectedImages = mergeProviders('useSelectedImages') - - return ( - // eslint-disable-next-line react-compiler/react-compiler - useSelectedImages?.() || { - selectedImages: [], - addSelectedImage: () => {}, - removeSelectedImage: () => {} - } - ) - } - export const usePluginMentionOptions = (): UseMentionOptionsReturns => { const { mergeProviders } = usePlugin() const useMentionOptions = mergeProviders('useMentionOptions') diff --git a/src/webview/lexical/hooks/use-drop-handler.ts b/src/webview/lexical/hooks/use-drop-handler.ts new file mode 100644 index 0000000..a4badc6 --- /dev/null +++ b/src/webview/lexical/hooks/use-drop-handler.ts @@ -0,0 +1,92 @@ +import { useEffect, useRef } from 'react' +import { settledPromiseResults } from '@shared/utils/common' +import { api } from '@webview/services/api-client' +import type { FileInfo } from '@webview/types/chat' +import { noop } from 'es-toolkit' +import { type LexicalEditor } from 'lexical' + +interface UseDropHandlerOptions { + editor: LexicalEditor + onDropFiles?: (files: FileInfo[]) => void +} + +export const useDropHandler = ({ + editor, + onDropFiles +}: UseDropHandlerOptions) => { + // Track if we're currently dragging + const isDraggingRef = useRef(false) + + useEffect(() => { + const rootElement = editor.getRootElement() + if (!rootElement) return + + const handleDragEnter = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + isDraggingRef.current = true + rootElement.classList.add('dragging') + } + + const handleDragLeave = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + isDraggingRef.current = false + rootElement.classList.remove('dragging') + } + + const handleDragOver = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + } + + const handleDrop = async (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + + isDraggingRef.current = false + rootElement.classList.remove('dragging') + + // Handle VSCode file drops + const fileFullPaths = await settledPromiseResults( + Array.from(e.dataTransfer?.items || []) + .filter( + item => item.kind === 'string' && item.type === 'text/uri-list' + ) + .map( + item => + new Promise(resolve => { + item.getAsString(uri => { + // Convert VSCode URI to file path + const fileFullPath = decodeURIComponent( + uri.replace('file://', '') + ) + resolve(fileFullPath) + }) + }) + ) + ) + + const droppedFiles = await api.file.traverseWorkspaceFiles( + { filesOrFolders: fileFullPaths }, + noop + ) + + if (droppedFiles.length > 0) { + onDropFiles?.(droppedFiles) + } + } + + rootElement.addEventListener('dragenter', handleDragEnter) + rootElement.addEventListener('dragleave', handleDragLeave) + rootElement.addEventListener('dragover', handleDragOver) + rootElement.addEventListener('drop', handleDrop) + + return () => { + rootElement.removeEventListener('dragenter', handleDragEnter) + rootElement.removeEventListener('dragleave', handleDragLeave) + rootElement.removeEventListener('dragover', handleDragOver) + rootElement.removeEventListener('drop', handleDrop) + } + }, [editor, onDropFiles]) +} diff --git a/src/webview/lexical/hooks/use-paste-handler.ts b/src/webview/lexical/hooks/use-paste-handler.ts new file mode 100644 index 0000000..942e0b3 --- /dev/null +++ b/src/webview/lexical/hooks/use-paste-handler.ts @@ -0,0 +1,151 @@ +import { useEffect, useRef } from 'react' +import { type ImageInfo } from '@shared/entities' +import { logger } from '@webview/utils/logger' +import { type LexicalEditor } from 'lexical' + +interface UsePasteHandlerOptions { + editor: LexicalEditor + onPasteImage?: (image: ImageInfo) => void +} + +/** + * Hook for handling image paste events in a Lexical editor + * + * Handles three types of paste scenarios: + * 1. HTML content with embedded images + * 2. Direct file/image pastes + * 3. Image URLs in plain text + */ +export const usePasteHandler = ({ + editor, + onPasteImage +}: UsePasteHandlerOptions) => { + // Use ref to track processed images in current paste event + const processedImagesRef = useRef>(new Set()) + + useEffect(() => { + const handlePaste = async (event: ClipboardEvent) => { + const { clipboardData } = event + if (!clipboardData) return + + // Clear processed images from previous paste event + processedImagesRef.current.clear() + + const html = clipboardData.getData('text/html') + const plainText = clipboardData.getData('text/plain') + + const { items } = clipboardData + + // Helper function to prevent duplicate image processing + const processImage = (image: ImageInfo) => { + const imageKey = String(image.url) + if (!processedImagesRef.current.has(imageKey)) { + processedImagesRef.current.add(imageKey) + onPasteImage?.(image) + } + } + + // Process HTML content with embedded images + if (html) { + const images = imageUtils.extractFromHtml(html) + images.forEach(processImage) + } + + // Process clipboard items (direct file/image pastes) + if (items) { + let hasImage = false + + for (const item of items) { + if (item.type.startsWith('image/')) { + hasImage = true + const file = item.getAsFile() + if (!file) continue + + try { + const image = await imageUtils.processFile(file) + processImage(image) + } catch (error) { + logger.error('Failed to process pasted image:', error) + } + } + } + + // Prevent default only for pure image pastes + if (hasImage && !plainText && !html) { + event.preventDefault() + } + } + + // Process image URLs in plain text + if (plainText && imageUtils.isImageUrl(plainText)) { + processImage(imageUtils.createFromUrl(plainText)) + } + } + + const rootElement = editor.getRootElement() + rootElement?.addEventListener('paste', handlePaste) + + return () => { + rootElement?.removeEventListener('paste', handlePaste) + } + }, [editor, onPasteImage]) +} + +// Constants for image handling +const IMAGE_TYPES = { + EMBEDDED: 'data:image/', + REMOTE: 'http' +} as const + +const IMAGE_NAMES = { + EMBEDDED: 'embedded-image', + WEB: 'web-image', + PASTED: 'pasted-image' +} as const + +const IMAGE_URL_PATTERN = /^https?:\/\/.*\.(png|jpe?g|gif|webp|svg)$/i + +/** + * Utility functions for image handling + */ +const imageUtils = { + isImageUrl: (text: string): boolean => IMAGE_URL_PATTERN.test(text), + + createFromUrl: (url: string): ImageInfo => ({ + url, + name: url.split('/').pop() || IMAGE_NAMES.WEB + }), + + extractFromHtml: (html: string): ImageInfo[] => { + const tempDiv = document.createElement('div') + tempDiv.innerHTML = html + + return Array.from(tempDiv.getElementsByTagName('img')).reduce( + (images, img) => { + const { src } = img + if (!src) return images + + if (src.startsWith(IMAGE_TYPES.EMBEDDED)) { + images.push({ url: src, name: IMAGE_NAMES.EMBEDDED }) + } else if (src.startsWith(IMAGE_TYPES.REMOTE)) { + images.push(imageUtils.createFromUrl(src)) + } + return images + }, + [] + ) + }, + + processFile: async (file: File): Promise => + new Promise((resolve, reject) => { + const reader = new FileReader() + reader.onload = e => { + resolve({ + url: e.target?.result as string, + name: file.name || IMAGE_NAMES.PASTED + }) + } + reader.onerror = reject + reader.readAsDataURL(file) + }) +} diff --git a/src/webview/lexical/nodes/mention-node.tsx b/src/webview/lexical/nodes/mention-node.tsx index 99395a5..90a0069 100644 --- a/src/webview/lexical/nodes/mention-node.tsx +++ b/src/webview/lexical/nodes/mention-node.tsx @@ -1,5 +1,6 @@ /* eslint-disable unused-imports/no-unused-vars */ import React, { useState, type FC } from 'react' +import type { Mention } from '@shared/entities' import { Popover, PopoverContent, @@ -26,9 +27,8 @@ import { export type SerializedMentionNode = Spread< { - mentionType: string - mentionData: any text: string + mention: Mention }, SerializedLexicalNode > @@ -36,27 +36,19 @@ export type SerializedMentionNode = Spread< const convertMentionElement = ( domNode: HTMLElement ): DOMConversionOutput | null => { - const mentionType = domNode.getAttribute( - 'data-lexical-mention-type' - ) as string - const mentionData = domNode.getAttribute('data-lexical-mention-data') + const mentionJson = domNode.getAttribute('data-lexical-mention') as string + const mention = JSON.parse(mentionJson) as Mention const text = domNode.textContent - if (mentionType && text) { - const node = $createMentionNode( - mentionType, - JSON.parse(mentionData || '{}'), - text - ) + if (mention && text) { + const node = $createMentionNode(mention, text) return { node } } return null } export class MentionNode extends DecoratorNode { - __mentionType: string - - __mentionData: any + __mention: Mention __text: string @@ -65,34 +57,18 @@ export class MentionNode extends DecoratorNode { } static clone(node: MentionNode): MentionNode { - return new MentionNode( - node.__mentionType, - node.__mentionData, - node.__text, - node.__key - ) + return new MentionNode(node.__mention, node.__text, node.__key) } - constructor( - mentionType: string, - mentionData: any, - text: string, - key?: NodeKey - ) { + constructor(mention: Mention, text: string, key?: NodeKey) { super(key) - this.__mentionType = mentionType - this.__mentionData = mentionData + this.__mention = mention this.__text = text } createDOM(config: EditorConfig): HTMLElement { const dom = document.createElement('span') - dom.setAttribute('data-lexical-mention', 'true') - dom.setAttribute('data-lexical-mention-type', this.__mentionType) - dom.setAttribute( - 'data-lexical-mention-data', - JSON.stringify(this.__mentionData) - ) + dom.setAttribute('data-lexical-mention', JSON.stringify(this.__mention)) return dom } @@ -116,27 +92,21 @@ export class MentionNode extends DecoratorNode { exportDOM(): DOMExportOutput { const element = document.createElement('span') - element.setAttribute('data-lexical-mention', 'true') - element.setAttribute('data-lexical-mention-type', this.__mentionType) - element.setAttribute( - 'data-lexical-mention-data', - JSON.stringify(this.__mentionData) - ) + element.setAttribute('data-lexical-mention', JSON.stringify(this.__mention)) element.textContent = this.__text return { element } } static importJSON(serializedNode: SerializedMentionNode): MentionNode { - const { mentionType, mentionData, text } = serializedNode - const node = $createMentionNode(mentionType, mentionData, text) + const { mention, text } = serializedNode + const node = $createMentionNode(mention, text) return node } exportJSON(): SerializedMentionNode { return { type: 'mention', - mentionType: this.__mentionType, - mentionData: this.__mentionData, + mention: this.__mention, text: this.__text, version: 1 } @@ -148,10 +118,7 @@ export class MentionNode extends DecoratorNode { decorate(editor: LexicalEditor, config: EditorConfig): React.ReactNode { return ( - + { } export const $createMentionNode = ( - mentionType: string, - mentionData: any, + mention: Mention, text: string -): MentionNode => - $applyNodeReplacement(new MentionNode(mentionType, mentionData, text)) +): MentionNode => $applyNodeReplacement(new MentionNode(mention, text)) export const $isMentionNode = ( node: LexicalNode | null | undefined ): node is MentionNode => node instanceof MentionNode const MentionPreview: FC<{ - mentionType: string - mentionData: any + mention: Mention children: React.ReactNode -}> = ({ mentionType, mentionData, children }) => { +}> = ({ mention, children }) => { const mentionOptions = usePluginMentionOptions() - const option = findMentionOptionByMentionType(mentionOptions, mentionType) + const option = findMentionOptionByMentionType(mentionOptions, mention.type) const currentOption = { ...option, - data: mentionData + data: mention.data } as MentionOption const [isOpen, setIsOpen] = useState(false) diff --git a/src/webview/lexical/plugins/mention-plugin.tsx b/src/webview/lexical/plugins/mention-plugin.tsx index c24926f..832c945 100644 --- a/src/webview/lexical/plugins/mention-plugin.tsx +++ b/src/webview/lexical/plugins/mention-plugin.tsx @@ -1,5 +1,6 @@ import React, { useState, type FC } from 'react' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' +import type { Mention } from '@shared/entities' import { MentionSelector } from '@webview/components/chat/selectors/mention-selector/mention-selector' import { usePluginMentionOptions } from '@webview/hooks/chat/use-plugin-providers' import type { MentionOption } from '@webview/types/chat' @@ -99,7 +100,11 @@ const insertMention = ({ // Create and insert the mention node const mentionText = `@${option.label}` - const mentionNode = $createMentionNode(option.type, option.data, mentionText) + const mention: Mention = { + type: option.type, + data: option.data + } + const mentionNode = $createMentionNode(mention, mentionText) selection.insertNodes([mentionNode]) // Insert a space after the mention node diff --git a/src/webview/styles/global.css b/src/webview/styles/global.css index 02ec87e..f877c87 100644 --- a/src/webview/styles/global.css +++ b/src/webview/styles/global.css @@ -40,6 +40,8 @@ body, color: hsl(var(--foreground)); background-color: hsl(var(--background)) !important; + pointer-events: auto !important; + user-select: auto !important; } a:focus, @@ -125,3 +127,8 @@ span.languageLabel { border-radius: 3px; font-family: monospace; } + +.editor-input.dragging { + border: 2px dashed hsl(var(--primary)); + background-color: hsl(var(--primary) / 0.1); +} diff --git a/src/webview/types/chat.ts b/src/webview/types/chat.ts index 1548111..7959a9b 100644 --- a/src/webview/types/chat.ts +++ b/src/webview/types/chat.ts @@ -20,8 +20,6 @@ export interface MentionOption { label: string type?: string onSelect?: (data: T) => void - onUpdatePluginState?: (dataArr: T[]) => void - topLevelSort?: number searchKeywords?: string[] searchSortStrategy?: SearchSortStrategy diff --git a/src/webview/utils/plugin-states.ts b/src/webview/utils/plugin-states.ts index d53e817..9b1e126 100644 --- a/src/webview/utils/plugin-states.ts +++ b/src/webview/utils/plugin-states.ts @@ -1,37 +1,30 @@ -import { $isMentionNode } from '@webview/lexical/nodes/mention-node' import type { MentionOption } from '@webview/types/chat' -import { - $getRoot, - $isElementNode, - type EditorState, - type LexicalNode -} from 'lexical' - -export const updatePluginStatesFromEditorState = ( - editorState: EditorState, - mentionOptions: MentionOption[] -): void => - editorState.read(() => { - const root = $getRoot() - const mentionTypeDataArr: Record = {} - - const traverseNodes = (node: LexicalNode) => { - if ($isMentionNode(node)) { - const { mentionType, mentionData } = node.exportJSON() - mentionTypeDataArr[mentionType] ||= [] - mentionTypeDataArr[mentionType]!.push(mentionData) - } else if ($isElementNode(node)) { - node.getChildren().forEach(traverseNodes) - } - } - traverseNodes(root) +// export const updatePluginStatesFromEditorState = ( +// editorState: EditorState, +// mentionOptions: MentionOption[] +// ): void => +// editorState.read(() => { +// const root = $getRoot() +// const mentionTypeDataArr: Record = {} + +// const traverseNodes = (node: LexicalNode) => { +// if ($isMentionNode(node)) { +// const { mention } = node.exportJSON() +// mentionTypeDataArr[mention.type] ||= [] +// mentionTypeDataArr[mention.type]!.push(mention.data) +// } else if ($isElementNode(node)) { +// node.getChildren().forEach(traverseNodes) +// } +// } - Object.entries(mentionTypeDataArr).forEach(([type, dataArr]) => { - const found = findMentionOptionByMentionType(mentionOptions, type) - found?.onUpdatePluginState?.(dataArr) - }) - }) +// traverseNodes(root) + +// Object.entries(mentionTypeDataArr).forEach(([type, dataArr]) => { +// const found = findMentionOptionByMentionType(mentionOptions, type) +// found?.onUpdatePluginState?.(dataArr) +// }) +// }) export const findMentionOptionByMentionType = ( mentionOptions: MentionOption[], @@ -52,3 +45,26 @@ export const findMentionOptionByMentionType = ( return undefined } + +// export const getMentionsFromEditorState = ( +// editorState: EditorState +// ): Mention[] => { +// const mentions: Mention[] = [] + +// editorState.read(() => { +// const root = $getRoot() + +// const traverseNodes = (node: LexicalNode) => { +// if ($isMentionNode(node)) { +// const { mention } = node.exportJSON() +// mentions.push(mention) +// } else if ($isElementNode(node)) { +// node.getChildren().forEach(traverseNodes) +// } +// } + +// traverseNodes(root) +// }) + +// return mentions +// }