Skip to content

Commit

Permalink
feat: replace promise allsettled
Browse files Browse the repository at this point in the history
  • Loading branch information
2214962083 committed Oct 7, 2024
1 parent 096da9f commit 2f459ff
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 78 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { ChatContext } from '@extension/webview-api/chat-context-processor/types/chat-context'
import type { LangchainMessage } from '@extension/webview-api/chat-context-processor/types/langchain-message'
import { HumanMessage, SystemMessage } from '@langchain/core/messages'
import { settledPromiseResults } from '@shared/utils/common'

import { CHAT_WITH_FILE_SYSTEM_PROMPT, COMMON_SYSTEM_PROMPT } from './constants'
import { ConversationMessageConstructor } from './conversation-message-constructor'
Expand Down Expand Up @@ -70,15 +71,7 @@ ${explicitContext}
hasAttachedFiles
).buildMessages()
)
const messageArrays = await Promise.allSettled(messagePromises)
const messages: LangchainMessage[] = []

messageArrays.forEach(result => {
if (result.status === 'fulfilled') {
messages.push(...result.value)
}
})

return messages
const messageArrays = await settledPromiseResults(messagePromises)
return messageArrays.flat()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-
import { mergeCodeSnippets } from '@extension/webview-api/chat-context-processor/utils/merge-code-snippets'
import type { ToolMessage } from '@langchain/core/messages'
import { DynamicStructuredTool } from '@langchain/core/tools'
import { settledPromiseResults } from '@shared/utils/common'
import { z } from 'zod'

import {
Expand Down Expand Up @@ -39,13 +40,12 @@ export const createCodebaseSearchTool = async (state: ChatGraphState) => {

if (!indexer) return searchResults

const searchPromisesResult = await Promise.allSettled(
const searchPromisesResult = await settledPromiseResults(
queryParts?.map(queryPart => indexer.searchSimilarRow(queryPart)) || []
)

const searchCodeSnippets: CodeSnippet[] = searchPromisesResult
.filter(result => result.status === 'fulfilled')
.flatMap(result => (result as any).value)
.flat()
.map(row => {
// eslint-disable-next-line unused-imports/no-unused-vars
const { embedding, ...others } = row
Expand Down Expand Up @@ -117,7 +117,7 @@ export const codebaseSearchNode: ChatGraphNode = async state => {
]
})

await Promise.allSettled(toolCallsPromises)
await settledPromiseResults(toolCallsPromises)

return {
chatContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { DocIndexer } from '@extension/webview-api/chat-context-processor/vector
import { docSitesDB } from '@extension/webview-api/lowdb/doc-sites-db'
import type { ToolMessage } from '@langchain/core/messages'
import { DynamicStructuredTool } from '@langchain/core/tools'
import { removeDuplicates } from '@shared/utils/common'
import { removeDuplicates, settledPromiseResults } from '@shared/utils/common'
import { z } from 'zod'

import {
Expand Down Expand Up @@ -51,44 +51,27 @@ export const createDocRetrieverTool = async (state: ChatGraphState) => {

await docIndexer.initialize()

const searchResults = await Promise.allSettled(
const searchResults = await settledPromiseResults(
keywords.map(keyword => docIndexer.searchSimilarRow(keyword))
)

const searchRows = removeDuplicates(
searchResults
.filter(
(result): result is PromiseFulfilledResult<any> =>
result.status === 'fulfilled'
)
.flatMap(result => result.value),
searchResults.flatMap(result => result),
['fullPath']
).slice(0, 3)

const docInfoResults = await Promise.allSettled(
const docInfoResults = await settledPromiseResults(
searchRows.map(async row => ({
content: await docIndexer.getRowFileContent(row),
path: docSite.url
}))
)

return docInfoResults
.filter(
(result): result is PromiseFulfilledResult<DocInfo> =>
result.status === 'fulfilled'
)
.map(result => result.value)
})

const results = await Promise.allSettled(docPromises)
const relevantDocs = results
.filter(
(result): result is PromiseFulfilledResult<DocInfo[]> =>
result.status === 'fulfilled'
)
.flatMap(result => result.value)

return relevantDocs
const results = await settledPromiseResults(docPromises)
return results.flatMap(result => result)
}

return new DynamicStructuredTool({
Expand Down Expand Up @@ -150,7 +133,7 @@ export const docRetrieverNode: ChatGraphNode = async state => {
]
})

await Promise.allSettled(toolCallsPromises)
await settledPromiseResults(toolCallsPromises)

return {
chatContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { CheerioWebBaseLoader } from '@langchain/community/document_loaders/web/
import type { Document } from '@langchain/core/documents'
import { HumanMessage, type ToolMessage } from '@langchain/core/messages'
import { DynamicStructuredTool } from '@langchain/core/tools'
import { settledPromiseResults } from '@shared/utils/common'
import { z } from 'zod'

import { ChatMessagesConstructor } from '../messages-constructors/chat-messages-constructor'
Expand Down Expand Up @@ -38,17 +39,11 @@ export const createWebSearchTool = async (state: ChatGraphState) => {
const searxngSearchResult = await searxngSearch(keywords)
const urls = searxngSearchResult.results.map(result => result.url)

const docsLoadResult = await Promise.allSettled(
const docsLoadResult = await settledPromiseResults(
urls.map(url => new CheerioWebBaseLoader(url).load())
)

const docs: Document<Record<string, any>>[] = []

docsLoadResult.forEach(result => {
if (result.status === 'fulfilled') {
docs.push(...result.value)
}
})
const docs: Document<Record<string, any>>[] = docsLoadResult.flat()

const docsContent = docs
.map(doc => doc.pageContent)
Expand Down Expand Up @@ -156,7 +151,7 @@ export const webSearchNode: ChatGraphNode = async state => {
]
})

await Promise.allSettled(toolCallsPromises)
await settledPromiseResults(toolCallsPromises)

return {
chatContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { DocCrawler } from '@extension/webview-api/chat-context-processor/utils/
import { findCurrentToolsCallParams } from '@extension/webview-api/chat-context-processor/utils/find-current-tools-call-params'
import type { ToolMessage } from '@langchain/core/messages'
import { DynamicStructuredTool } from '@langchain/core/tools'
import { settledPromiseResults } from '@shared/utils/common'
import { z } from 'zod'

import {
Expand All @@ -20,21 +21,15 @@ export const createWebVisitTool = async (state: ChatGraphState) => {
const getPageContents = async (
urls: string[]
): Promise<{ url: string; content: string }[]> => {
const docCrawler = new DocCrawler(urls![0]!)
const promises = await Promise.allSettled(
const docCrawler = new DocCrawler(urls[0]!)
const contents = await settledPromiseResults(
urls.map(async url => ({
url,
content:
(await docCrawler.getPageContent(url)) || 'Failed to retrieve content'
}))
)
return promises
.filter(promise => promise.status === 'fulfilled')
.map(
promise =>
(promise as PromiseFulfilledResult<{ url: string; content: string }>)
.value
)
return contents
}

return new DynamicStructuredTool({
Expand Down Expand Up @@ -89,7 +84,7 @@ export const webVisitNode: ChatGraphNode = async state => {
]
})

await Promise.allSettled(toolCallsPromises)
await settledPromiseResults(toolCallsPromises)

return {
chatContext
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable prefer-destructuring */
import { logger } from '@extension/logger'
import { BinaryOperatorAggregate } from '@langchain/langgraph'
import { settledPromiseResults } from '@shared/utils/common'

import type { CreateAnnotationRoot } from '../types/langgraph'

Expand All @@ -14,25 +14,18 @@ export const combineNode = <GraphState extends Record<string, any>>(
): GraphNode<GraphState> => {
const combined: GraphNode<GraphState> = async state => {
const promises = nodes.map(async node => await node(state))
const promisesResults = await Promise.allSettled(promises)
const states = await settledPromiseResults(promises)
const keys = new Set<keyof GraphState>()
const states: Partial<GraphState>[] = []

promisesResults.forEach((result, index) => {
if (result.status === 'fulfilled') {
const partialState = result.value as Partial<GraphState>
Object.keys(partialState).forEach(key =>
keys.add(key as keyof GraphState)
)
states.push(partialState)
} else {
logger.warn(`Error in node ${index}:`, result.reason)
}

states.forEach(partialState => {
Object.keys(partialState).forEach(key =>
keys.add(key as keyof GraphState)
)
})

const combinedResult = {} as Partial<GraphState>

for (const _key in keys) {
for (const _key of keys) {
const key = _key as keyof GraphState
const annotation = stateDefinition.spec[key]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ export class DocCrawler {
while (this.queue.length > 0 && this.visited.size < this.options.maxPages) {
const batch = this.queue.splice(0, this.options.concurrency)
const promises = batch.map(item => this.crawlPage(item.url, item.depth))
await Promise.all(promises)
await Promise.allSettled(promises)
await new Promise(resolve => setTimeout(resolve, this.options.delay))
this.progressReporter.setProcessedItems(this.visited.size)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ export abstract class BaseIndexer<T extends IndexRow> {
this.progressReporter.reset()
const filePaths = await this.getAllIndexedFilePaths()
const filePathsNeedReindex: string[] = []
const tasks = filePaths.map(async filePath => {
const tasksPromises = filePaths.map(async filePath => {
try {
const currentHash = await this.generateFileHash(filePath)
const existingRows = await this.getFileRows(filePath)
Expand All @@ -215,7 +215,7 @@ export abstract class BaseIndexer<T extends IndexRow> {
}
})

await Promise.allSettled(tasks)
await Promise.allSettled(tasksPromises)

this.totalFiles = filePathsNeedReindex.length
this.progressReporter.setTotalItems(this.totalFiles)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { getExt, getSemanticHashName } from '@extension/file-utils/paths'
import { traverseFileOrFolders } from '@extension/file-utils/traverse-fs'
import { VsCodeFS } from '@extension/file-utils/vscode-fs'
import { logger } from '@extension/logger'
import { settledPromiseResults } from '@shared/utils/common'
import { languageIdExts } from '@shared/utils/vscode-lang'
import * as vscode from 'vscode'

Expand Down Expand Up @@ -61,7 +62,7 @@ export class CodebaseIndexer extends BaseIndexer<CodeChunkRow> {
}
})

return Promise.all(chunkRowsPromises)
return settledPromiseResults(chunkRowsPromises)
}

private async chunkCodeFile(filePath: string): Promise<TextChunk[]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { getSemanticHashName } from '@extension/file-utils/paths'
import { traverseFileOrFolders } from '@extension/file-utils/traverse-fs'
import { VsCodeFS } from '@extension/file-utils/vscode-fs'
import { logger } from '@extension/logger'
import { settledPromiseResults } from '@shared/utils/common'

import { CodeChunkerManager, type TextChunk } from '../tree-sitter/code-chunker'
import { ProgressReporter } from '../utils/process-reporter'
Expand Down Expand Up @@ -59,7 +60,7 @@ export class DocIndexer extends BaseIndexer<DocChunkRow> {
}
})

return Promise.all(chunkRowsPromises)
return settledPromiseResults(chunkRowsPromises)
}

private async chunkCodeFile(filePath: string): Promise<TextChunk[]> {
Expand Down
3 changes: 2 additions & 1 deletion src/extension/webview-api/controllers/git.controller.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { CommandManager } from '@extension/commands/command-manager'
import type { RegisterManager } from '@extension/registers/register-manager'
import { getWorkspaceFolder } from '@extension/utils'
import { settledPromiseResults } from '@shared/utils/common'
import simpleGit, { SimpleGit } from 'simple-git'

import type {
Expand All @@ -27,7 +28,7 @@ export class GitController extends Controller {
const { maxCount = 50 } = req
const log = await this.git.log({ maxCount })

const commits: GitCommit[] = await Promise.all(
const commits: GitCommit[] = await settledPromiseResults(
log.all.map(async commit => {
const diff = await this.git.diff([`${commit.hash}^`, commit.hash])
return {
Expand Down
11 changes: 11 additions & 0 deletions src/shared/utils/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,14 @@ export const tryStringifyJSON = (obj: any) => {
return null
}
}

export async function settledPromiseResults<T>(
promises: Promise<T>[]
): Promise<T[]> {
const results = await Promise.allSettled(promises)
return results
.map((result, index) => ({ result, index }))
.filter(item => item.result.status === 'fulfilled')
.sort((a, b) => a.index - b.index)
.map(item => (item.result as PromiseFulfilledResult<T>).value)
}

0 comments on commit 2f459ff

Please sign in to comment.