Skip to content

Commit

Permalink
fix: fix batch processor do not create file on windows and smart-past…
Browse files Browse the repository at this point in the history
…e incorrect position bug
  • Loading branch information
2214962083 committed Aug 6, 2024
1 parent 7340dc5 commit 8bb7949
Show file tree
Hide file tree
Showing 14 changed files with 299 additions and 356 deletions.
481 changes: 166 additions & 315 deletions pnpm-lock.yaml

Large diffs are not rendered by default.

18 changes: 14 additions & 4 deletions src/ai/get-reference-file-paths.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { AbortError } from '@/constants'
import { traverseFileOrFolders } from '@/file-utils/traverse-fs'
import { getCurrentWorkspaceFolderEditor } from '@/utils'
import { getCurrentWorkspaceFolderEditor, toPlatformPath } from '@/utils'
import * as vscode from 'vscode'
import { z } from 'zod'

Expand All @@ -11,9 +12,11 @@ export interface ReferenceFilePaths {
}

export const getReferenceFilePaths = async ({
currentFilePath
currentFilePath,
abortController
}: {
currentFilePath: string
abortController?: AbortController
}): Promise<ReferenceFilePaths> => {
const { workspaceFolder } = await getCurrentWorkspaceFolderEditor()
const allRelativePaths: string[] = []
Expand All @@ -31,10 +34,11 @@ export const getReferenceFilePaths = async ({

const modelProvider = await createModelProvider()
const aiRunnable = await modelProvider.createStructuredOutputRunnable({
signal: abortController?.signal,
useHistory: false,
zodSchema: z.object({
referenceFileRelativePaths: z.array(z.string()).min(0).max(3).describe(`
Required! The relative paths of the up to three most useful files related to the currently edited file. This can include 0 to 3 files.
Required! The relative paths array of the up to three most useful files related to the currently edited file. This can include 0 to 3 files.
`),
dependenceFileRelativePath: z.string().describe(`
Required! The relative path of the dependency file for the current file. If the dependency file is not found, return an empty string.
Expand Down Expand Up @@ -69,5 +73,11 @@ Please find and return the dependency file path for the current file and the thr
`
})

return aiRes
if (abortController?.signal.aborted) throw AbortError

return {
referenceFileRelativePaths:
aiRes.referenceFileRelativePaths.map(toPlatformPath),
dependenceFileRelativePath: toPlatformPath(aiRes.dependenceFileRelativePath)
}
}
28 changes: 18 additions & 10 deletions src/ai/model-providers/base.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { MaybePromise } from '@/types'
import { normalizeLineEndings } from '@/utils'
import { InMemoryChatMessageHistory } from '@langchain/core/chat_history'
import type { BaseChatModel } from '@langchain/core/language_models/chat_models'
import {
Expand Down Expand Up @@ -29,21 +30,24 @@ export interface BaseModelProviderCreateStructuredOutputRunnableOptions<
> {
useHistory?: boolean
historyMessages?: BaseMessage[]
signal?: AbortSignal
zodSchema: ZSchema
}

export abstract class BaseModelProvider<Model extends BaseChatModel> {
static sessionIdHistoriesMap: Record<string, InMemoryChatMessageHistory> = {}

static answerContentToText(content: MessageContent): string {
if (typeof content === 'string') return content

return content
.map(c => {
if (c.type === 'text') return c.text
return ''
})
.join('')
if (typeof content === 'string') return normalizeLineEndings(content)

return normalizeLineEndings(
content
.map(c => {
if (c.type === 'text') return c.text
return ''
})
.join('')
)
}

model?: Model
Expand Down Expand Up @@ -118,10 +122,14 @@ export abstract class BaseModelProvider<Model extends BaseChatModel> {
async createStructuredOutputRunnable<
ZSchema extends z.ZodType<any> = z.ZodType<any>
>(options: BaseModelProviderCreateStructuredOutputRunnableOptions<ZSchema>) {
const { useHistory = true, historyMessages, zodSchema } = options
const { useHistory = true, historyMessages, zodSchema, signal } = options
const model = await this.getModel()
const prompt = await this.createPrompt({ useHistory })
const chain = prompt.pipe(model.withStructuredOutput(zodSchema))
const chain = prompt.pipe(
model.withStructuredOutput(zodSchema).bind({
signal
})
)

return useHistory
? await this.createRunnableWithMessageHistory(
Expand Down
47 changes: 38 additions & 9 deletions src/commands/batch-processor/get-pre-process-info.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import path from 'path'
import { createModelProvider } from '@/ai/helpers'
import { AbortError } from '@/constants'
import { traverseFileOrFolders } from '@/file-utils/traverse-fs'
import { getCurrentWorkspaceFolderEditor } from '@/utils'
import { getCurrentWorkspaceFolderEditor, toPlatformPath } from '@/utils'
import { z } from 'zod'

export interface PreProcessInfo {
Expand All @@ -16,10 +17,12 @@ export interface PreProcessInfo {

export const getPreProcessInfo = async ({
prompt,
fileRelativePathsForProcess
fileRelativePathsForProcess,
abortController
}: {
prompt: string
fileRelativePathsForProcess: string[]
abortController?: AbortController
}): Promise<
PreProcessInfo & {
allFileRelativePaths: string[]
Expand All @@ -38,6 +41,7 @@ export const getPreProcessInfo = async ({

const modelProvider = await createModelProvider()
const aiRunnable = await modelProvider.createStructuredOutputRunnable({
signal: abortController?.signal,
useHistory: false,
zodSchema: z.object({
processFilePathInfo: z
Expand Down Expand Up @@ -115,33 +119,58 @@ Please analyze these files and provide the requested information to help streaml
`
})

if (abortController?.signal.aborted) throw AbortError

aiRes.dependenceFileRelativePath = toPlatformPath(
aiRes.dependenceFileRelativePath || ''
)
aiRes.ignoreFileRelativePaths =
aiRes.ignoreFileRelativePaths?.map(toPlatformPath)
aiRes.processFilePathInfo = aiRes.processFilePathInfo.map(info => ({
sourceFileRelativePath: toPlatformPath(info.sourceFileRelativePath),
processedFileRelativePath: toPlatformPath(info.processedFileRelativePath),
referenceFileRelativePaths:
info.referenceFileRelativePaths.map(toPlatformPath)
}))

// data cleaning
// Process and filter the file path information
const finalProcessFilePathInfo: PreProcessInfo['processFilePathInfo'] =
aiRes.processFilePathInfo
.map(info => {
const {
sourceFileRelativePath,
processedFileRelativePath,
referenceFileRelativePaths
} = info
const { ignoreFileRelativePaths } = aiRes

// Extract the base name and extension from the source file path
const sourceBaseName = path.basename(
info.sourceFileRelativePath,
path.extname(info.sourceFileRelativePath)
sourceFileRelativePath,
path.extname(sourceFileRelativePath)
)
// Get the extension from the processed file path
const processedExtName = path.extname(info.processedFileRelativePath)
const processedExtName = path.extname(processedFileRelativePath)
// Construct the full processed file path
const fullProcessedPath = path.join(
path.dirname(info.sourceFileRelativePath),
path.dirname(sourceFileRelativePath),
sourceBaseName + processedExtName
)

// Check if the processed file path should be ignored
const shouldIgnore =
fullProcessedPath === info.sourceFileRelativePath &&
aiRes.ignoreFileRelativePaths?.includes(info.sourceFileRelativePath)
fullProcessedPath === sourceFileRelativePath &&
ignoreFileRelativePaths?.includes(sourceFileRelativePath)

// Return the new info object or null if it should be ignored
return shouldIgnore
? null
: { ...info, processedFileRelativePath: fullProcessedPath }
: {
sourceFileRelativePath,
processedFileRelativePath: toPlatformPath(fullProcessedPath),
referenceFileRelativePaths
}
})
// Filter out any null entries
.filter(
Expand Down
8 changes: 5 additions & 3 deletions src/commands/batch-processor/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { getConfigKey } from '@/config'
import { AbortError } from '@/constants'
import { isTmpFileUri } from '@/file-utils/create-tmp-file'
import { traverseFileOrFolders } from '@/file-utils/traverse-fs'
import { t } from '@/i18n'
Expand Down Expand Up @@ -56,12 +57,13 @@ export const handleBatchProcessor = async (

const preProcessInfo = await getPreProcessInfo({
prompt,
fileRelativePathsForProcess
fileRelativePathsForProcess,
abortController
})

logger.log('handleBatchProcessor', preProcessInfo)

if (abortController.signal.aborted) return
if (abortController?.signal.aborted) throw AbortError

const apiConcurrency = (await getConfigKey('apiConcurrency')) || 1
const limit = pLimit(apiConcurrency)
Expand All @@ -75,7 +77,7 @@ export const handleBatchProcessor = async (
processedFileRelativePath: info.processedFileRelativePath,
dependenceFileRelativePath: preProcessInfo.dependenceFileRelativePath,
abortController
})
}).catch(err => logger.warn('writeAndSaveTmpFile error', err))
)
)

Expand Down
5 changes: 3 additions & 2 deletions src/commands/batch-processor/write-and-save-tmp-file.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import path from 'path'
import { createModelProvider } from '@/ai/helpers'
import { AbortError } from '@/constants'
import { getTmpFileUri } from '@/file-utils/create-tmp-file'
import { tmpFileWriter } from '@/file-utils/tmp-file-writer'
import { VsCodeFS } from '@/file-utils/vscode-fs'
Expand Down Expand Up @@ -32,7 +33,7 @@ export const writeAndSaveTmpFile = async ({
signal: abortController?.signal
})

if (abortController?.signal.aborted) return
if (abortController?.signal.aborted) throw AbortError

const getContentFromRelativePath = async (relativePath: string) => {
if (!relativePath) return ''
Expand Down Expand Up @@ -65,7 +66,7 @@ export const writeAndSaveTmpFile = async ({
dependenceFileRelativePath || ''
)

if (abortController?.signal.aborted) return
if (abortController?.signal.aborted) throw AbortError

await tmpFileWriter({
stopWriteWhenClosed: true,
Expand Down
11 changes: 10 additions & 1 deletion src/commands/rename-variable/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
createModelProvider,
getCurrentSessionIdHistoriesMap
} from '@/ai/helpers'
import { AbortError } from '@/constants'
import { t } from '@/i18n'
import { createLoading } from '@/loading'
import { getCurrentWorkspaceFolderEditor } from '@/utils'
Expand Down Expand Up @@ -37,7 +38,9 @@ export const handleRenameVariable = async () => {
const modelProvider = await createModelProvider()
const { showProcessLoading, hideProcessLoading } = createLoading()

const abortController = new AbortController()
const aiRunnable = await modelProvider.createStructuredOutputRunnable({
signal: abortController.signal,
zodSchema: renameSuggestionZodSchema
})
const sessionId = `renameVariable:${variableName}`
Expand All @@ -53,7 +56,11 @@ export const handleRenameVariable = async () => {

let aiRes: any
try {
showProcessLoading()
showProcessLoading({
onCancel: () => {
abortController.abort()
}
})
const prompt = await buildRenameSuggestionPrompt({
contextCode: activeEditor.document.getText(),
variableName,
Expand All @@ -74,6 +81,8 @@ export const handleRenameVariable = async () => {
hideProcessLoading()
}

if (abortController?.signal.aborted) throw AbortError

const suggestionVariableNameOptions = Array.from(
aiRes?.suggestionVariableNameOptions || []
) as RenameSuggestionZodSchema['suggestionVariableNameOptions']
Expand Down
6 changes: 4 additions & 2 deletions src/commands/smart-paste/build-convert-chat-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ const getClipboardContent = async () => {
export const buildConvertChatMessages = async ({
workspaceFolder,
currentFilePath,
selection
selection,
abortController
}: {
workspaceFolder: vscode.WorkspaceFolder
currentFilePath: string
selection: vscode.Selection
abortController?: AbortController
}): Promise<BaseMessage[]> => {
const { clipboardImg, clipboardContent } = await getClipboardContent()

Expand All @@ -49,7 +51,7 @@ export const buildConvertChatMessages = async ({

// reference file content
const { referenceFileRelativePaths, dependenceFileRelativePath } =
await cacheGetReferenceFilePaths({ currentFilePath })
await cacheGetReferenceFilePaths({ currentFilePath, abortController })
const referencePaths = [
...new Set([dependenceFileRelativePath, ...referenceFileRelativePaths])
]
Expand Down
3 changes: 2 additions & 1 deletion src/commands/smart-paste/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ export const handleSmartPaste = async () => {
const convertMessages = await buildConvertChatMessages({
workspaceFolder,
currentFilePath,
selection: activeEditor.selection
selection: activeEditor.selection,
abortController: aiModelAbortController
})

const history = await modelProvider.getHistory(sessionId)
Expand Down
2 changes: 2 additions & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,5 @@ export const languageExtIdMap = Object.fromEntries(
exts.map(ext => [ext, id])
)
)

export const AbortError = new Error('AbortError')
11 changes: 9 additions & 2 deletions src/file-utils/create-tmp-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,16 @@ export const getTmpFileUri = ({
const originalFileExt = path.parse(filePath).ext
const languageExt = getLanguageIdExt(languageId) || languageId

return vscode.Uri.parse(
`${untitled ? 'untitled:' : ''}${path.join(originalFileDir, `${originalFileName}${originalFileExt}.aide${languageExt ? `.${languageExt}` : ''}`)}`
const finalPath = path.join(
originalFileDir,
`${originalFileName}${originalFileExt}.aide${languageExt ? `.${languageExt}` : ''}`
)

if (!untitled) {
return vscode.Uri.file(finalPath)
}

return vscode.Uri.parse(`untitled:${finalPath}`)
}

const aideTmpFileRegExp = /\.aide(\.[^.]+)?$/
Expand Down
8 changes: 7 additions & 1 deletion src/file-utils/stream-completion-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ import { createLoading } from '@/loading'
import {
removeCodeBlockEndSyntax,
removeCodeBlockStartSyntax,
removeCodeBlockSyntax
removeCodeBlockSyntax,
sleep
} from '@/utils'
import type { IterableReadableStream } from '@langchain/core/dist/utils/stream'
import type { AIMessageChunk } from '@langchain/core/messages'
Expand Down Expand Up @@ -43,6 +44,8 @@ export const streamingCompletionWriter = async (
{ undoStopBefore: false, undoStopAfter: false }
)

await sleep(10)

// update current position to the end of the inserted text
currentPosition = editor.document.positionAt(
editor.document.offsetAt(currentPosition) + text.length
Expand Down Expand Up @@ -70,6 +73,8 @@ export const streamingCompletionWriter = async (
{ undoStopBefore: false, undoStopAfter: false }
)

await sleep(10)

// update current position to the end of the inserted text
currentPosition = editor.document.positionAt(
editor.document.offsetAt(originPosition) + text.length
Expand All @@ -95,6 +100,7 @@ export const streamingCompletionWriter = async (

// convert openai answer content to text
const text = ModelProvider.answerContentToText(chunk.content)

if (!text) continue

fullText += text
Expand Down
Loading

0 comments on commit 8bb7949

Please sign in to comment.