Skip to content

Commit

Permalink
refactor(context): optimize file loading with caching
Browse files Browse the repository at this point in the history
Refactor context.lua to introduce file caching mechanism and improve code
organization around file handling. The main changes include:

- Add file caching to avoid reprocessing unchanged files
- Move outline generation logic into separate build_outline function
- Consolidate embed type definitions into context.lua
- Remove duplicate type definitions from copilot.lua
- Optimize file loading with new get_file helper function

Signed-off-by: Tomas Slusny <[email protected]>
  • Loading branch information
deathbeam committed Nov 28, 2024
1 parent a0b89f0 commit e0c4ca0
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 109 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ require('CopilotChat').setup({

## Roadmap (Wishlist)

- Caching for contexts
- Improved caching for context (persistence through restarts/smarter caching)
- General QOL improvements

## Development
Expand Down
19 changes: 2 additions & 17 deletions lua/CopilotChat/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,10 @@ local utils = require('CopilotChat.utils')
--- @field bufnr number
--- @field winnr number

---@class CopilotChat.config.selection.diagnostic
---@field content string
---@field start_line number
---@field end_line number
---@field severity string

---@class CopilotChat.config.selection
---@field content string
---@field start_line number
---@field end_line number
---@field filename string
---@field filetype string
---@field bufnr number
---@field diagnostics table<CopilotChat.config.selection.diagnostic>?

---@class CopilotChat.config.context
---@field description string?
---@field input fun(callback: fun(input: string?), source: CopilotChat.config.source)?
---@field resolve fun(input: string?, source: CopilotChat.config.source):table<CopilotChat.copilot.embed>
---@field resolve fun(input: string?, source: CopilotChat.config.source):table<CopilotChat.context.embed>

---@class CopilotChat.config.prompt : CopilotChat.config.shared
---@field prompt string?
Expand Down Expand Up @@ -76,7 +61,7 @@ local utils = require('CopilotChat.utils')
---@field temperature number?
---@field headless boolean?
---@field callback fun(response: string, source: CopilotChat.config.source)?
---@field selection nil|fun(source: CopilotChat.config.source):CopilotChat.config.selection?
---@field selection nil|fun(source: CopilotChat.config.source):CopilotChat.select.selection?
---@field window CopilotChat.config.window?
---@field show_help boolean?
---@field show_folds boolean?
Expand Down
137 changes: 76 additions & 61 deletions lua/CopilotChat/context.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
---@field end_row number
---@field end_col number

---@class CopilotChat.context.outline : CopilotChat.copilot.embed
---@field symbols table<string, CopilotChat.context.symbol>
---@class CopilotChat.context.embed
---@field content string
---@field filename string
---@field filetype string
---@field original string?
---@field symbols table<string, CopilotChat.context.symbol>?
---@field embedding table<number>?

local async = require('plenary.async')
local log = require('plenary.log')
local utils = require('CopilotChat.utils')
local file_cache = {}

local M = {}

Expand Down Expand Up @@ -79,10 +85,10 @@ local function spatial_distance_cosine(a, b)
end

--- Rank data by relatedness to the query
---@param query CopilotChat.copilot.embed
---@param data table<CopilotChat.copilot.embed>
---@param query CopilotChat.context.embed
---@param data table<CopilotChat.context.embed>
---@param top_n number
---@return table<CopilotChat.copilot.embed>
---@return table<CopilotChat.context.embed>
local function data_ranked_by_relatedness(query, data, top_n)
data = vim.tbl_map(function(item)
return vim.tbl_extend(
Expand All @@ -101,7 +107,7 @@ end

--- Rank data by symbols
---@param query string
---@param data table<CopilotChat.context.outline>
---@param data table<CopilotChat.context.embed>
---@param top_n number
local function data_ranked_by_symbols(query, data, top_n)
local query_terms = {}
Expand Down Expand Up @@ -193,22 +199,15 @@ end
--- Build an outline and symbols from a string
---@param content string
---@param filename string
---@param ft string?
---@return CopilotChat.context.outline
function M.outline(content, filename, ft)
ft = ft or 'text'

---@param ft string
---@return CopilotChat.context.embed
local function build_outline(content, filename, ft)
local output = {
filename = filename,
filetype = ft,
content = content,
symbols = {},
}

if ft == 'raw' then
return output
end

local lang = vim.treesitter.language.get_lang(ft)
local ok, parser = false, nil
if lang then
Expand All @@ -224,6 +223,7 @@ function M.outline(content, filename, ft)

local root = parser:parse()[1]:root()
local lines = vim.split(content, '\n')
local symbols = {}
local outline_lines = {}
local depth = 0

Expand All @@ -239,7 +239,7 @@ function M.outline(content, filename, ft)
table.insert(outline_lines, string.rep(' ', depth) .. signature_start)

-- Store symbol information
table.insert(output.symbols, {
table.insert(symbols, {
name = name,
signature = signature_start,
type = type,
Expand Down Expand Up @@ -269,15 +269,45 @@ function M.outline(content, filename, ft)
if #outline_lines > 0 then
output.original = content
output.content = table.concat(outline_lines, '\n')
output.symbols = symbols
end

return output
end

--- Load the embed for a file, reusing the cached result while the file is unchanged
--- Returns nil when the modification time or the file content cannot be obtained.
---@param filename string
---@param filetype string
---@return CopilotChat.context.embed?
local function get_file(filename, filetype)
  local mtime = utils.file_mtime(filename)
  if not mtime then
    return nil
  end

  -- Cache hit: the stored entry is at least as new as the file's current mtime
  local entry = file_cache[filename]
  if entry and entry.modified >= mtime then
    return entry.outline
  end

  local text = utils.read_file(filename)
  if not text then
    return nil
  end

  -- Cache miss or stale entry: rebuild the outline and remember it with the mtime
  local result = build_outline(text, filename, filetype)
  file_cache[filename] = {
    outline = result,
    modified = mtime,
  }

  return result
end

--- Get list of all files in workspace
---@param winnr number?
---@param with_content boolean?
---@return table<CopilotChat.copilot.embed>
---@return table<CopilotChat.context.embed>
function M.files(winnr, with_content)
local cwd = utils.win_cwd(winnr)
local files = utils.scan_dir(cwd, {
Expand All @@ -291,24 +321,22 @@ function M.files(winnr, with_content)
if with_content then
async.util.scheduler()

files = vim.tbl_map(function(file)
return {
name = utils.filepath(file),
ft = utils.filetype(file),
}
end, files)
files = vim.tbl_filter(function(file)
return file.ft ~= nil
end, files)
files = vim.tbl_filter(
function(file)
return file.ft ~= nil
end,
vim.tbl_map(function(file)
return {
name = utils.filepath(file),
ft = utils.filetype(file),
}
end, files)
)

for _, file in ipairs(files) do
local content = utils.read_file(file.name)
if content then
table.insert(out, {
content = content,
filename = file.name,
filetype = file.ft,
})
local file_data = get_file(file.name, file.ft)
if file_data then
table.insert(out, file_data)
end
end

Expand Down Expand Up @@ -338,28 +366,20 @@ end

--- Get the content of a file
---@param filename string
---@return CopilotChat.copilot.embed?
---@return CopilotChat.context.embed?
function M.file(filename)
local content = utils.read_file(filename)
if not content then
return nil
end

async.util.scheduler()
if not utils.filetype(filename) then
local ft = utils.filetype(filename)
if not ft then
return nil
end

return {
content = content,
filename = utils.filepath(filename),
filetype = utils.filetype(filename),
}
return get_file(utils.filepath(filename), ft)
end

--- Get the content of a buffer
---@param bufnr? number
---@return CopilotChat.copilot.embed?
---@return CopilotChat.context.embed?
function M.buffer(bufnr)
async.util.scheduler()
bufnr = bufnr or vim.api.nvim_get_current_buf()
Expand All @@ -373,17 +393,17 @@ function M.buffer(bufnr)
return nil
end

return {
content = table.concat(content, '\n'),
filename = utils.filepath(vim.api.nvim_buf_get_name(bufnr)),
filetype = vim.bo[bufnr].filetype,
}
return build_outline(
table.concat(content, '\n'),
utils.filepath(vim.api.nvim_buf_get_name(bufnr)),
vim.bo[bufnr].filetype
)
end

--- Get current git diff
---@param type string?
---@param winnr number
---@return CopilotChat.copilot.embed?
---@return CopilotChat.context.embed?
function M.gitdiff(type, winnr)
type = type or 'unstaged'
local cwd = utils.win_cwd(winnr)
Expand Down Expand Up @@ -411,7 +431,7 @@ end

--- Return contents of specified register
---@param register string?
---@return CopilotChat.copilot.embed?
---@return CopilotChat.context.embed?
function M.register(register)
register = register or '+'
local lines = vim.fn.getreg(register)
Expand All @@ -429,19 +449,14 @@ end
--- Filter embeddings based on the query
---@param copilot CopilotChat.Copilot
---@param prompt string
---@param embeddings table<CopilotChat.copilot.embed>
---@return table<CopilotChat.copilot.embed>
---@param embeddings table<CopilotChat.context.embed>
---@return table<CopilotChat.context.embed>
function M.filter_embeddings(copilot, prompt, embeddings)
-- If we don't need to embed anything, just return directly
if #embeddings < MULTI_FILE_THRESHOLD then
return embeddings
end

-- Map embeddings to outlines
embeddings = vim.tbl_map(function(embed)
return M.outline(embed.content, embed.filename, embed.filetype)
end, embeddings)

-- Rank embeddings by symbols
embeddings = data_ranked_by_symbols(prompt, embeddings, TOP_SYMBOLS)
log.debug('Ranked data:', #embeddings)
Expand Down
24 changes: 7 additions & 17 deletions lua/CopilotChat/copilot.lua
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
---@class CopilotChat.copilot.embed
---@field content string
---@field filename string
---@field filetype string
---@field embedding table<number>

---@class CopilotChat.copilot.ask.opts
---@field selection CopilotChat.config.selection?
---@field embeddings table<CopilotChat.copilot.embed>?
---@field selection CopilotChat.select.selection?
---@field embeddings table<CopilotChat.context.embed>?
---@field system_prompt string?
---@field model string?
---@field agent string?
---@field temperature number?
---@field no_history boolean?
---@field on_progress nil|fun(response: string):nil

---@class CopilotChat.copilot.embed.opts
---@field model string?
---@field chunk_size number?

local log = require('plenary.log')
local prompts = require('CopilotChat.prompts')
local tiktoken = require('CopilotChat.tiktoken')
Expand Down Expand Up @@ -107,7 +97,7 @@ local function generate_line_numbers(content, start_line)
end

--- Generate messages for the given selection
--- @param selection CopilotChat.config.selection
--- @param selection CopilotChat.select.selection
local function generate_selection_messages(selection)
local filename = selection.filename or 'unknown'
local filetype = selection.filetype or 'text'
Expand Down Expand Up @@ -167,7 +157,7 @@ local function generate_selection_messages(selection)
end

--- Generate messages for the given embeddings
--- @param embeddings table<CopilotChat.copilot.embed>
--- @param embeddings table<CopilotChat.context.embed>
local function generate_embeddings_messages(embeddings)
local files = {}
for _, embedding in ipairs(embeddings) do
Expand Down Expand Up @@ -295,7 +285,7 @@ end

---@class CopilotChat.Copilot : Class
---@field history table
---@field embedding_cache table<CopilotChat.copilot.embed>
---@field embedding_cache table<CopilotChat.context.embed>
---@field policies table<string, boolean>
---@field models table<string, table>?
---@field agents table<string, table>?
Expand Down Expand Up @@ -863,8 +853,8 @@ function Copilot:list_agents()
end

--- Generate embeddings for the given inputs
---@param inputs table<CopilotChat.copilot.embed>: The inputs to embed
---@return table<CopilotChat.copilot.embed>
---@param inputs table<CopilotChat.context.embed>: The inputs to embed
---@return table<CopilotChat.context.embed>
function Copilot:embed(inputs)
if not inputs or #inputs == 0 then
return {}
Expand Down
4 changes: 2 additions & 2 deletions lua/CopilotChat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ local state = {
}

---@param config CopilotChat.config.shared
---@return CopilotChat.config.selection?
---@return CopilotChat.select.selection?
local function get_selection(config)
local bufnr = state.source and state.source.bufnr
local winnr = state.source and state.source.winnr
Expand Down Expand Up @@ -231,7 +231,7 @@ end

---@param prompt string
---@param config CopilotChat.config.shared
---@return table<CopilotChat.copilot.embed>, string
---@return table<CopilotChat.context.embed>, string
local function resolve_embeddings(prompt, config)
local contexts = {}
local function parse_context(prompt_context)
Expand Down
Loading

0 comments on commit e0c4ca0

Please sign in to comment.