Skip to content

Commit

Permalink
refactor: move embed truncation logic to proper places
Browse files Browse the repository at this point in the history
This commit reorganizes the code to handle truncation of large files and
embeddings in a more consistent way. The changes include:

- Moving ordered_map implementation to top of utils.lua
- Centralizing truncation logic for embeddings with BIG_EMBED_THRESHOLD
- Simplifying outline truncation by moving it to embedding generation
- Updating class type definitions to be more consistent
- Removing duplicate code related to truncation handlers

The main goal is to make the codebase more maintainable by having
truncation logic in appropriate locations rather than scattered across
different files.

Signed-off-by: Tomas Slusny <[email protected]>
  • Loading branch information
deathbeam committed Nov 24, 2024
1 parent 273f43a commit ac7edc4
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 75 deletions.
14 changes: 1 addition & 13 deletions lua/CopilotChat/context.lua
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ local OFF_SIDE_RULE_LANGUAGES = {
'fsharp',
}

local OUTLINE_THRESHOLD = 600
local MULTI_FILE_THRESHOLD = 3

local function spatial_distance_cosine(a, b)
Expand Down Expand Up @@ -339,19 +338,8 @@ function M.filter_embeddings(copilot, prompt, embeddings)
-- Map embeddings by filename
for _, embed in ipairs(embeddings) do
original_map:set(embed.filename, embed)

if embed.filetype ~= 'raw' then
local outline = M.outline(embed.content, embed.filename, embed.filetype)
local outline_lines = vim.split(outline.content, '\n')

-- If outline is too big, truncate it
if #outline_lines > 0 and #outline_lines > OUTLINE_THRESHOLD then
outline_lines = vim.list_slice(outline_lines, 1, OUTLINE_THRESHOLD)
table.insert(outline_lines, '... (truncated)')
end

outline.content = table.concat(outline_lines, '\n')
embedded_map:set(embed.filename, outline)
embedded_map:set(embed.filename, M.outline(embed.content, embed.filename, embed.filetype))
else
embedded_map:set(embed.filename, embed)
end
Expand Down
35 changes: 14 additions & 21 deletions lua/CopilotChat/copilot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ local temp_file = utils.temp_file
--- Constants
local CONTEXT_FORMAT = '[#file:%s](#file:%s-context)'
local BIG_FILE_THRESHOLD = 2000
local BIG_EMBED_THRESHOLD = 600
local TRUNCATED = '... (truncated)'
local TIMEOUT = 30000
local VERSION_HEADERS = {
['editor-version'] = 'Neovim/'
Expand Down Expand Up @@ -114,12 +116,9 @@ end
---@return string
local function generate_line_numbers(content, start_line)
local lines = vim.split(content, '\n')
local truncated = false

-- If the file is too big, truncate it
if #lines > BIG_FILE_THRESHOLD then
lines = vim.list_slice(lines, 1, BIG_FILE_THRESHOLD)
truncated = true
table.insert(lines, TRUNCATED)
end

local total_lines = #lines
Expand All @@ -129,10 +128,6 @@ local function generate_line_numbers(content, start_line)
lines[i] = formatted_line_number .. ': ' .. line
end

if truncated then
table.insert(lines, '... (truncated)')
end

content = table.concat(lines, '\n')
return content
end
Expand Down Expand Up @@ -301,26 +296,24 @@ local function generate_embedding_request(inputs, model)
return {
dimensions = 512,
input = vim.tbl_map(function(input)
local out = ''
local lines = vim.split(input.content, '\n')
if #lines > BIG_EMBED_THRESHOLD then
lines = vim.list_slice(lines, 1, BIG_EMBED_THRESHOLD)
table.insert(lines, TRUNCATED)
end
local content = table.concat(lines, '\n')

if input.filetype == 'raw' then
out = input.content .. '\n'
return content
else
out = out
.. string.format(
'File: `%s`\n```%s\n%s\n```',
input.filename,
input.filetype,
input.content
)
return string.format('File: `%s`\n```%s\n%s\n```', input.filename, input.filetype, content)
end

return out
end, inputs),
model = model,
}
end

---@class CopilotChat.Copilot : CopilotChat.utils.Class
---@class CopilotChat.Copilot : Class
---@field history table
---@field embedding_cache table<CopilotChat.copilot.embed>
---@field policies table<string, boolean>
Expand Down Expand Up @@ -873,7 +866,7 @@ end

--- Generate embeddings for the given inputs
---@param inputs table<CopilotChat.copilot.embed>: The inputs to embed
---@param opts CopilotChat.copilot.embed.opts: Options for the request
---@param opts CopilotChat.copilot.embed.opts?: Options for the request
---@return table<CopilotChat.copilot.embed>
function Copilot:embed(inputs, opts)
if not inputs or #inputs == 0 then
Expand Down
2 changes: 1 addition & 1 deletion lua/CopilotChat/ui/overlay.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
local utils = require('CopilotChat.utils')
local class = utils.class

---@class CopilotChat.ui.Overlay : CopilotChat.utils.Class
---@class CopilotChat.ui.Overlay : Class
---@field name string
---@field help string
---@field help_ns number
Expand Down
2 changes: 1 addition & 1 deletion lua/CopilotChat/ui/spinner.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ local spinner_frames = {
'',
}

---@class CopilotChat.ui.Spinner : CopilotChat.utils.Class
---@class CopilotChat.ui.Spinner : Class
---@field ns number
---@field bufnr number
---@field timer table
Expand Down
78 changes: 39 additions & 39 deletions lua/CopilotChat/utils.lua
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
local M = {}
M.timers = {}

---@class CopilotChat.utils.Class
---@class Class
---@field new fun(...):table
---@field init fun(self, ...)

--- Create class
---@param fn function The class constructor
---@param parent table? The parent class
---@return CopilotChat.utils.Class
---@return Class
function M.class(fn, parent)
local out = {}
out.__index = out
Expand Down Expand Up @@ -38,6 +38,43 @@ function M.class(fn, parent)
return out
end

---@class OrderedMap
---@field set fun(self:OrderedMap, key:any, value:any)
---@field get fun(self:OrderedMap, key:any):any
---@field keys fun(self:OrderedMap):table
---@field values fun(self:OrderedMap):table

--- Create an ordered map
---@return OrderedMap
function M.ordered_map()
return {
_keys = {},
_data = {},
set = function(self, key, value)
if not self._data[key] then
table.insert(self._keys, key)
end
self._data[key] = value
end,

get = function(self, key)
return self._data[key]
end,

keys = function(self)
return self._keys
end,

values = function(self)
local result = {}
for _, key in ipairs(self._keys) do
table.insert(result, self._data[key])
end
return result
end,
}
end

--- Check if the current version of neovim is stable
---@return boolean
function M.is_stable()
Expand Down Expand Up @@ -187,41 +224,4 @@ function M.win_cwd(winnr)
return dir
end

---@class OrderedMap
---@field set fun(self:OrderedMap, key:any, value:any)
---@field get fun(self:OrderedMap, key:any):any
---@field keys fun(self:OrderedMap):table
---@field values fun(self:OrderedMap):table

--- Create an ordered map
---@return OrderedMap
function M.ordered_map()
return {
_keys = {},
_data = {},
set = function(self, key, value)
if not self._data[key] then
table.insert(self._keys, key)
end
self._data[key] = value
end,

get = function(self, key)
return self._data[key]
end,

keys = function(self)
return self._keys
end,

values = function(self)
local result = {}
for _, key in ipairs(self._keys) do
table.insert(result, self._data[key])
end
return result
end,
}
end

return M

0 comments on commit ac7edc4

Please sign in to comment.