
Commit

fix: make file operations and contexts async
The changes make various file operations and context resolution properly
async by:
- Converting synchronous UV calls to async versions
- Moving context resolution inside async block
- Making tiktoken data loading fully async
- Converting directory scanning to async version
- Adding proper error handling for async UV operations

This improves the overall responsiveness of the plugin by preventing blocking
operations from freezing the editor.

Signed-off-by: Tomas Slusny <[email protected]>
deathbeam committed Nov 27, 2024
1 parent 3968c25 commit 3c6b463
Showing 5 changed files with 89 additions and 84 deletions.
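
The recurring pattern in the diffs below, sketched with plenary.nvim's async API (illustrative only; exists_sync and exists_async are hypothetical helpers, not code from this commit): blocking vim.uv calls are replaced with async.uv calls that yield to the event loop, and their callers run inside async.run.

local async = require('plenary.async')

-- Before: a synchronous libuv call that blocks the editor until it returns.
local function exists_sync(path)
  return vim.uv.fs_stat(path) ~= nil
end

-- After: yields while libuv works; must be called from an async context,
-- for example inside async.run(function() ... end).
local function exists_async(path)
  local err, stat = async.uv.fs_stat(path)
  return err == nil and stat ~= nil
end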
3 changes: 2 additions & 1 deletion .luarc.json
@@ -1,3 +1,4 @@
 {
-  "diagnostics.globals": ["describe", "it"]
+  "diagnostics.globals": ["describe", "it"],
+  "diagnostics.disable": ["redefined-local"]
 }
4 changes: 4 additions & 0 deletions lua/CopilotChat/context.lua
@@ -10,6 +10,7 @@
 ---@class CopilotChat.context.outline : CopilotChat.copilot.embed
 ---@field symbols table<string, CopilotChat.context.symbol>
 
+local async = require('plenary.async')
 local log = require('plenary.log')
 local utils = require('CopilotChat.utils')
 
@@ -315,6 +316,7 @@ function M.file(filename)
     return nil
   end
 
+  async.util.scheduler()
   return {
     content = content,
     filename = vim.fn.fnamemodify(filename, ':p:.'),
@@ -326,6 +328,8 @@ end
 ---@param bufnr number
 ---@return CopilotChat.copilot.embed?
 function M.buffer(bufnr)
+  async.util.scheduler()
+
   if not utils.buf_valid(bufnr) then
     return nil
   end
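
A note on the async.util.scheduler() calls added above: M.file and M.buffer now resume on plenary's async loop, and scheduler() hands control back to Neovim's main loop before any vim.fn/vim.api function is used. A minimal sketch of that ordering (the path and print call are illustrative, not plugin code):

local async = require('plenary.async')

async.run(function()
  -- Resumes from a libuv callback once the stat completes.
  local err, stat = async.uv.fs_stat('lua/CopilotChat/context.lua')

  -- Hop back to the main event loop before touching vim APIs.
  async.util.scheduler()

  if not err and stat then
    print(vim.fn.fnamemodify('lua/CopilotChat/context.lua', ':p:.'), stat.size)
  end
end)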
71 changes: 38 additions & 33 deletions lua/CopilotChat/init.lua
@@ -233,24 +233,16 @@ end
 ---@param config CopilotChat.config.shared
 ---@return table<CopilotChat.copilot.embed>, string
 local function resolve_embeddings(prompt, config)
-  local embeddings = utils.ordered_map()
-
+  local contexts = {}
   local function parse_context(prompt_context)
     local split = vim.split(prompt_context, ':')
     local context_name = table.remove(split, 1)
     local context_input = vim.trim(table.concat(split, ':'))
-    local context_value = M.config.contexts[context_name]
-    if context_input == '' then
-      ---@diagnostic disable-next-line: cast-local-type
-      context_input = nil
-    end
-
-    if context_value then
-      for _, embedding in ipairs(context_value.resolve(context_input, state.source)) do
-        if embedding then
-          embeddings:set(embedding.filename, embedding)
-        end
-      end
+    if M.config.contexts[context_name] then
+      table.insert(contexts, {
+        name = context_name,
+        input = (context_input ~= '' and context_input or nil),
+      })
 
       return true
     end
@@ -276,6 +268,16 @@ local function resolve_embeddings(prompt, config)
     end
   end
 
+  local embeddings = utils.ordered_map()
+  for _, context_data in ipairs(contexts) do
+    local context_value = M.config.contexts[context_data.name]
+    for _, embedding in ipairs(context_value.resolve(context_data.input, state.source)) do
+      if embedding then
+        embeddings:set(embedding.filename, embedding)
+      end
+    end
+  end
+
   return embeddings:values(), prompt
 end
 
@@ -678,14 +680,13 @@ function M.ask(prompt, config)
     '\n'
   ))
 
-  -- Retrieve embeddings
-  local embeddings, embedded_prompt = resolve_embeddings(prompt, config)
-  prompt = embedded_prompt
-
   -- Retrieve the selection
   local selection = get_selection(config)
 
   async.run(function()
+    local embeddings, embedded_prompt = resolve_embeddings(prompt, config)
+    prompt = embedded_prompt
+
     local agents = vim.tbl_keys(state.copilot:list_agents())
     local selected_agent = config.agent
     prompt = prompt:gsub('@' .. WORD, function(match)
@@ -1162,25 +1163,29 @@ function M.setup(config)
 
   map_key('show_user_context', bufnr, function()
     local section = state.chat:get_closest_section()
-    local embeddings = {}
-    if section and not section.answer then
-      embeddings = resolve_embeddings(section.content, state.chat.config)
-    end
 
-    local text = ''
-    for _, embedding in ipairs(embeddings) do
-      local lines = vim.split(embedding.content, '\n')
-      local preview = table.concat(vim.list_slice(lines, 1, math.min(10, #lines)), '\n')
-      local header = string.format('**`%s`** (%s lines)', embedding.filename, #lines)
-      if #lines > 10 then
-        header = header .. ' (truncated)'
+    async.run(function()
+      local embeddings = {}
+      if section and not section.answer then
+        embeddings = resolve_embeddings(section.content, state.chat.config)
      end
 
-      text = text
-        .. string.format('%s\n```%s\n%s\n```\n\n', header, embedding.filetype, preview)
-    end
+      local text = ''
+      for _, embedding in ipairs(embeddings) do
+        local lines = vim.split(embedding.content, '\n')
+        local preview = table.concat(vim.list_slice(lines, 1, math.min(10, #lines)), '\n')
+        local header = string.format('**`%s`** (%s lines)', embedding.filename, #lines)
+        if #lines > 10 then
+          header = header .. ' (truncated)'
+        end
 
-    state.overlay:show(vim.trim(text) .. '\n', state.chat.winnr, 'markdown')
+        text = text
+          .. string.format('%s\n```%s\n%s\n```\n\n', header, embedding.filetype, preview)
+      end
+
+      async.util.scheduler()
+      state.overlay:show(vim.trim(text) .. '\n', state.chat.winnr, 'markdown')
+    end)
   end)
 
   vim.api.nvim_create_autocmd({ 'BufEnter', 'BufLeave' }, {
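
Because contexts now yield on async.uv operations, resolve_embeddings can only run inside a plenary coroutine, which is why it moved into the async.run blocks above. The shape of that wrapper, reduced to a hypothetical handler (on_user_action, do_resolve and show are made-up names):

local async = require('plenary.async')

local function on_user_action(do_resolve, show)
  async.run(function()
    local result = do_resolve() -- may yield on async.uv.* calls
    async.util.scheduler()      -- return to the main loop
    show(result)                -- safe to call vim APIs from here
  end)
end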
46 changes: 19 additions & 27 deletions lua/CopilotChat/tiktoken.lua
@@ -1,5 +1,4 @@
 local async = require('plenary.async')
-local curl = require('plenary.curl')
 local log = require('plenary.log')
 local utils = require('CopilotChat.utils')
 local tiktoken_core = nil
@@ -9,58 +8,51 @@ vim.fn.mkdir(tostring(cache_dir), 'p')
 
 --- Load tiktoken data from cache or download it
 ---@param tokenizer string The tokenizer to load
----@param on_done fun(path: string) The callback to call when the data is loaded
-local function load_tiktoken_data(tokenizer, on_done)
+local function load_tiktoken_data(tokenizer)
   local tiktoken_url = 'https://openaipublic.blob.core.windows.net/encodings/'
     .. tokenizer
     .. '.tiktoken'
   local cache_path = cache_dir .. '/' .. tiktoken_url:match('.+/(.+)')
 
   if utils.file_exists(cache_path) then
-    on_done(cache_path)
-    return
+    return cache_path
   end
 
   log.info('Downloading tiktoken data from ' .. tiktoken_url)
-  curl.get(tiktoken_url, {
+  utils.curl_get(tiktoken_url, {
     output = cache_path,
-    callback = function()
-      on_done(cache_path)
-    end,
   })
+
+  return cache_path
 end
 
 local M = {}
 
 --- Load the tiktoken module
 ---@param tokenizer string The tokenizer to load
-M.load = async.wrap(function(tokenizer, callback)
+M.load = function(tokenizer)
   if tokenizer == current_tokenizer then
-    callback()
     return
   end
 
   local ok, core = pcall(require, 'tiktoken_core')
   if not ok then
-    callback()
     return
   end
 
-  load_tiktoken_data(tokenizer, function(path)
-    local special_tokens = {}
-    special_tokens['<|endoftext|>'] = 100257
-    special_tokens['<|fim_prefix|>'] = 100258
-    special_tokens['<|fim_middle|>'] = 100259
-    special_tokens['<|fim_suffix|>'] = 100260
-    special_tokens['<|endofprompt|>'] = 100276
-    local pat_str =
-      "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
-    core.new(path, special_tokens, pat_str)
-    tiktoken_core = core
-    current_tokenizer = tokenizer
-    callback()
-  end)
-end, 2)
+  local path = load_tiktoken_data(tokenizer)
+  local special_tokens = {}
+  special_tokens['<|endoftext|>'] = 100257
+  special_tokens['<|fim_prefix|>'] = 100258
+  special_tokens['<|fim_middle|>'] = 100259
+  special_tokens['<|fim_suffix|>'] = 100260
+  special_tokens['<|endofprompt|>'] = 100276
+  local pat_str =
+    "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+  core.new(path, special_tokens, pat_str)
+  tiktoken_core = core
+  current_tokenizer = tokenizer
+end
 
 --- Encode a prompt
 ---@param prompt string The prompt to encode
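
With async.wrap removed, M.load no longer takes a callback; it is simply called from an async context, where utils.curl_get and the file helpers are free to yield. A hypothetical call site ('cl100k_base' is only an example tokenizer name):

local async = require('plenary.async')
local tiktoken = require('CopilotChat.tiktoken')

async.run(function()
  -- Downloads the .tiktoken data on first use, then initializes tiktoken_core.
  tiktoken.load('cl100k_base')
  -- From here on the tokenizer is ready for encode/count calls.
end)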
49 changes: 26 additions & 23 deletions lua/CopilotChat/utils.lua
@@ -261,44 +261,47 @@ M.curl_post = async.wrap(function(url, opts, callback)
 end, 3)
 
 --- Scan a directory
---- FIXME: Make async
-M.scan_dir = scandir.scan_dir
-
--- M.scan_dir = async.wrap(function(path, opts, callback)
---   scandir.scan_dir_async(path, vim.tbl_deep_extend('force', opts, {
---     on_exit = callback,
---     on_error = function(err)
---       err = err and err.stderr or vim.inspect(err)
---       callback(nil, err)
---     end,
---   }))
--- end, 3)
+---@param path string The directory path
+---@param opts table The options
+M.scan_dir = async.wrap(function(path, opts, callback)
+  scandir.scan_dir_async(
+    path,
+    vim.tbl_deep_extend('force', opts, {
+      on_exit = callback,
+      on_error = function(err)
+        err = err and err.stderr or vim.inspect(err)
+        callback(nil, err)
+      end,
+    })
+  )
+end, 3)
 
 --- Check if a file exists
---- FIXME: Make async
 ---@param path string The file path
 M.file_exists = function(path)
-  local stat = vim.uv.fs_stat(path)
-  return stat ~= nil
+  local err, stat = async.uv.fs_stat(path)
+  return err == nil and stat ~= nil
 end
 
 --- Read a file
---- FIXME: Make async
 ---@param path string The file path
 M.read_file = function(path)
-  local fd = vim.uv.fs_open(path, 'r', 438)
-  if not fd then
+  local err, fd = async.uv.fs_open(path, 'r', 438)
+  if err or not fd then
    return nil
  end
 
-  local stat = vim.uv.fs_fstat(fd)
-  if not stat then
-    vim.uv.fs_close(fd)
+  local err, stat = async.uv.fs_fstat(fd)
+  if err or not stat then
+    async.uv.fs_close(fd)
    return nil
  end
 
-  local data = vim.uv.fs_read(fd, stat.size, 0)
-  vim.uv.fs_close(fd)
+  local err, data = async.uv.fs_read(fd, stat.size, 0)
+  async.uv.fs_close(fd)
+  if err or not data then
+    return nil
+  end
  return data
 end
 
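
Since M.scan_dir, M.file_exists and M.read_file now yield instead of blocking, their callers must also run inside an async context. A usage sketch (the README.md path and notification text are illustrative):

local async = require('plenary.async')
local utils = require('CopilotChat.utils')

async.run(function()
  if utils.file_exists('README.md') then
    local data = utils.read_file('README.md')
    async.util.scheduler()
    vim.notify(string.format('README.md is %d bytes', data and #data or 0))
  end
end)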
