Skip to content

Commit

Permalink
feat: add doc crawler and indexer
Browse files Browse the repository at this point in the history
  • Loading branch information
2214962083 committed Oct 7, 2024
1 parent 382726a commit 096da9f
Show file tree
Hide file tree
Showing 28 changed files with 382 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
type ChatGraphState
} from './nodes/state'
import { webSearchNode } from './nodes/web-search-node'
import { webVisitNode } from './nodes/web-visit-node'

const createSmartRoute =
(nextNodeName: ChatGraphNodeName) => (state: ChatGraphState) =>
Expand All @@ -21,7 +22,7 @@ const chatWorkflow = new StateGraph(chatGraphState)
.addNode(
ChatGraphNodeName.Tools,
combineNode(
[codebaseSearchNode, docRetrieverNode, webSearchNode],
[codebaseSearchNode, docRetrieverNode, webSearchNode, webVisitNode],
chatGraphState
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { createCodebaseSearchTool } from './codebase-search-node'
import { createDocRetrieverTool } from './doc-retriever-node'
import type { ChatGraphNode } from './state'
import { createWebSearchTool } from './web-search-node'
import { createWebVisitTool } from './web-visit-node'

export const agentNode: ChatGraphNode = async state => {
const modelProvider = await createModelProvider()
Expand All @@ -21,7 +22,9 @@ export const agentNode: ChatGraphNode = async state => {
// doc
await createDocRetrieverTool(state),
// web search
await createWebSearchTool(state)
await createWebSearchTool(state),
// web visit
await createWebVisitTool(state)
].filter(Boolean) as LangchainTool[]

const chatMessagesConstructor = new ChatMessagesConstructor(state.chatContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export const createCodebaseSearchTool = async (state: ChatGraphState) => {
if (!indexer) return searchResults

const searchPromisesResult = await Promise.allSettled(
queryParts?.map(queryPart => indexer.searchSimilarCode(queryPart)) || []
queryParts?.map(queryPart => indexer.searchSimilarRow(queryPart)) || []
)

const searchCodeSnippets: CodeSnippet[] = searchPromisesResult
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { aidePaths } from '@extension/file-utils/paths'
import { DocInfo } from '@extension/webview-api/chat-context-processor/types/chat-context/doc-context'
import type { LangchainTool } from '@extension/webview-api/chat-context-processor/types/langchain-message'
import { DocCrawler } from '@extension/webview-api/chat-context-processor/utils/doc-crawler'
import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params'
import { CheerioWebBaseLoader } from '@langchain/community/document_loaders/web/cheerio'
import type { DocumentInterface } from '@langchain/core/documents'
import { DocIndexer } from '@extension/webview-api/chat-context-processor/vectordb/doc-indexer'
import { docSitesDB } from '@extension/webview-api/lowdb/doc-sites-db'
import type { ToolMessage } from '@langchain/core/messages'
import { DynamicStructuredTool } from '@langchain/core/tools'
import type { VectorStoreRetriever } from '@langchain/core/vectorstores'
import { OpenAIEmbeddings } from '@langchain/openai'
import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { removeDuplicates } from '@shared/utils/common'
import { z } from 'zod'

import {
Expand All @@ -17,7 +17,7 @@ import {
} from './state'

interface DocRetrieverToolResult {
relevantDocs: DocumentInterface<Record<string, any>>[]
relevantDocs: DocInfo[]
}

export const createDocRetrieverTool = async (state: ChatGraphState) => {
Expand All @@ -28,52 +28,93 @@ export const createDocRetrieverTool = async (state: ChatGraphState) => {

if (!docContext) return null

const { allowSearchDocSiteUrls } = docContext
const { allowSearchDocSiteNames } = docContext

if (!allowSearchDocSiteUrls.length) return null
if (!allowSearchDocSiteNames.length) return null

let _retriever: VectorStoreRetriever<MemoryVectorStore>
const getRelevantDocs = async (
queryParts: { siteName: string; keywords: string[] }[]
): Promise<DocInfo[]> => {
const docSites = await docSitesDB.getAll()

const getRetriever = async () => {
if (_retriever) return _retriever
const docPromises = queryParts.map(async ({ siteName, keywords }) => {
const docSite = docSites.find(site => site.name === siteName)

// TODO: Deep search
const docs = await Promise.all(
allowSearchDocSiteUrls.map(url => new CheerioWebBaseLoader(url).load())
)
const docsList = docs.flat()
if (!docSite?.isIndexed || !allowSearchDocSiteNames.includes(siteName)) {
return []
}

const textSplitter = new RecursiveCharacterTextSplitter({
chunkSize: 500,
chunkOverlap: 50
const docIndexer = new DocIndexer(
DocCrawler.getDocCrawlerFolderPath(docSite.url),
aidePaths.getGlobalLanceDbPath()
)

await docIndexer.initialize()

const searchResults = await Promise.allSettled(
keywords.map(keyword => docIndexer.searchSimilarRow(keyword))
)

const searchRows = removeDuplicates(
searchResults
.filter(
(result): result is PromiseFulfilledResult<any> =>
result.status === 'fulfilled'
)
.flatMap(result => result.value),
['fullPath']
).slice(0, 3)

const docInfoResults = await Promise.allSettled(
searchRows.map(async row => ({
content: await docIndexer.getRowFileContent(row),
path: docSite.url
}))
)

return docInfoResults
.filter(
(result): result is PromiseFulfilledResult<DocInfo> =>
result.status === 'fulfilled'
)
.map(result => result.value)
})
const docSplits = await textSplitter.splitDocuments(docsList)

const vectorStore = await MemoryVectorStore.fromDocuments(
docSplits,
new OpenAIEmbeddings()
)
const results = await Promise.allSettled(docPromises)
const relevantDocs = results
.filter(
(result): result is PromiseFulfilledResult<DocInfo[]> =>
result.status === 'fulfilled'
)
.flatMap(result => result.value)

_retriever = vectorStore.asRetriever()

return _retriever
return relevantDocs
}

return new DynamicStructuredTool({
name: ChatGraphToolName.DocRetriever,
description: 'Search and return information about question in Docs.',
func: async ({ query }, runManager): Promise<DocRetrieverToolResult> => {
const retriever = await getRetriever()

return {
relevantDocs: await retriever.invoke(
query,
runManager?.getChild('retriever')
)
}
},
description:
'Search for relevant information in specified documentation sites. This tool can search across multiple doc sites, with multiple keywords for each site. Use this tool to find documentation on specific topics or understand how certain features are described in the documentation.',
func: async ({ queryParts }): Promise<DocRetrieverToolResult> => ({
relevantDocs: await getRelevantDocs(queryParts)
}),
schema: z.object({
query: z.string().describe('query to look up in retriever')
queryParts: z
.array(
z.object({
siteName: z
.enum(allowSearchDocSiteNames as unknown as [string, ...string[]])
.describe('The name of the documentation site to search'),
keywords: z
.array(z.string())
.describe(
'List of keywords to search for in the specified doc site'
)
})
)
.describe(
"The AI should break down the user's query into multiple parts, each targeting a specific doc site with relevant keywords. This allows for a more comprehensive search across multiple documentation sources."
)
})
})
}
Expand Down Expand Up @@ -105,10 +146,7 @@ export const docRetrieverNode: ChatGraphNode = async state => {

lastConversation.attachments!.docContext.relevantDocs = [
...lastConversation.attachments!.docContext.relevantDocs,
...result.relevantDocs.map(doc => ({
path: doc.metadata?.filePath,
content: doc.pageContent
}))
...result.relevantDocs
]
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import { baseState } from '../../base-state'
export enum ChatGraphToolName {
DocRetriever = 'docRetriever',
WebSearch = 'webSearch',
CodebaseSearch = 'codebaseSearch'
CodebaseSearch = 'codebaseSearch',
WebVisit = 'webVisit'
}

export enum ChatGraphNodeName {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import type { LangchainTool } from '@extension/webview-api/chat-context-processor/types/langchain-message'
import { DocCrawler } from '@extension/webview-api/chat-context-processor/utils/doc-crawler'
import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params'
import type { ToolMessage } from '@langchain/core/messages'
import { DynamicStructuredTool } from '@langchain/core/tools'
import { z } from 'zod'

import {
ChatGraphToolName,
type ChatGraphNode,
type ChatGraphState
} from './state'

interface WebVisitToolResult {
contents: { url: string; content: string }[]
}

// eslint-disable-next-line unused-imports/no-unused-vars
export const createWebVisitTool = async (state: ChatGraphState) => {
const getPageContents = async (
urls: string[]
): Promise<{ url: string; content: string }[]> => {
const docCrawler = new DocCrawler(urls![0]!)
const promises = await Promise.allSettled(
urls.map(async url => ({
url,
content:
(await docCrawler.getPageContent(url)) || 'Failed to retrieve content'
}))
)
return promises
.filter(promise => promise.status === 'fulfilled')
.map(
promise =>
(promise as PromiseFulfilledResult<{ url: string; content: string }>)
.value
)
}

return new DynamicStructuredTool({
name: ChatGraphToolName.WebVisit,
description:
'Visit specific web pages and retrieve their content. Use this tool when you need to access and analyze the content of one or more web pages.',
func: async ({ urls }): Promise<WebVisitToolResult> => {
const contents = await getPageContents(urls)
return { contents }
},
schema: z.object({
urls: z
.array(z.string().url())
.describe(
'An array of URLs to visit and retrieve content from. Each URL should be a valid web address.'
)
})
})
}

export const webVisitNode: ChatGraphNode = async state => {
const { messages, chatContext } = state
const { conversations } = chatContext
const lastConversation = conversations.at(-1)
const docContext = lastConversation?.attachments?.docContext

if (!docContext) return {}

const webVisitTool = await createWebVisitTool(state)

if (!webVisitTool) return {}

const tools: LangchainTool[] = [webVisitTool]
const lastMessage = messages.at(-1)
const toolCalls = findCurrentToolsCallParams(lastMessage, tools)

if (!toolCalls.length) return {}

const toolCallsPromises = toolCalls.map(async toolCall => {
const toolMessage = (await webVisitTool.invoke(toolCall)) as ToolMessage

const result = JSON.parse(
toolMessage?.lc_kwargs.content
) as WebVisitToolResult

lastConversation.attachments!.docContext.relevantDocs = [
...lastConversation.attachments!.docContext.relevantDocs,
...result.contents.map(item => ({
path: item.url,
content: item.content
}))
]
})

await Promise.allSettled(toolCallsPromises)

return {
chatContext
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import type { BaseToolContext } from './base-tool-context'

export interface DocInfo {
content: string
path: string // file path or url
}

export interface DocContext extends BaseToolContext {
allowSearchDocSiteUrls: string[]
export interface DocContext {
allowSearchDocSiteNames: string[]
relevantDocs: DocInfo[]
}
Loading

0 comments on commit 096da9f

Please sign in to comment.