Skip to content

Commit

Permalink
refactor: move utility functions to utils module
Browse files Browse the repository at this point in the history
Moves uuid, machine_id and quick_hash functions from copilot.lua to utils.lua
module. Also improves type annotations and documentation for the moved
functions. Removes unused table_equals function and renames
blend_color_with_neovim_bg to blend_color for better clarity.
  • Loading branch information
deathbeam committed Nov 20, 2024
1 parent f69ee54 commit dfb8846
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 80 deletions.
58 changes: 18 additions & 40 deletions lua/CopilotChat/copilot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ local tiktoken = require('CopilotChat.tiktoken')
local utils = require('CopilotChat.utils')
local class = utils.class
local temp_file = utils.temp_file

--- Constants
local context_format = '[#file:%s](#file:%s-context)\n'
local big_file_threshold = 2000
local timeout = 30000
Expand Down Expand Up @@ -81,31 +83,8 @@ local tiktoken_load = async.wrap(function(tokenizer, callback)
tiktoken.load(tokenizer, callback)
end, 2)

local function uuid()
local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
return (
string.gsub(template, '[xy]', function(c)
local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb)
return string.format('%x', v)
end)
)
end

local function machine_id()
local length = 65
local hex_chars = '0123456789abcdef'
local hex = ''
for _ = 1, length do
local index = math.random(1, #hex_chars)
hex = hex .. hex_chars:sub(index, index)
end
return hex
end

local function quick_hash(str)
return #str .. str:sub(1, 32) .. str:sub(-32)
end

--- Get the github oauth cached token
---@return string|nil
local function get_cached_token()
-- loading token from the environment only in GitHub Codespaces
local token = os.getenv('GITHUB_TOKEN')
Expand Down Expand Up @@ -140,6 +119,10 @@ local function get_cached_token()
return nil
end

--- Generate line numbers for the given content
---@param content string: The content to generate line numbers for
---@param start_line number|nil: The starting line number
---@return string
local function generate_line_numbers(content, start_line)
local lines = vim.split(content, '\n')
local truncated = false
Expand Down Expand Up @@ -348,21 +331,13 @@ local function generate_embedding_request(inputs, model)
}
end

local function count_history_tokens(history)
local count = 0
for _, msg in ipairs(history) do
count = count + tiktoken.count(msg.content)
end
return count
end

local Copilot = class(function(self, proxy, allow_insecure)
self.history = {}
self.embedding_cache = {}
self.github_token = nil
self.token = nil
self.sessionid = nil
self.machineid = machine_id()
self.machineid = utils.machine_id()
self.models = nil
self.agents = nil
self.claude_enabled = false
Expand Down Expand Up @@ -404,7 +379,7 @@ function Copilot:authenticate()
if
not self.token or (self.token.expires_at and self.token.expires_at <= math.floor(os.time()))
then
local sessionid = uuid() .. tostring(math.floor(os.time() * 1000))
local sessionid = utils.uuid() .. tostring(math.floor(os.time() * 1000))
local headers = {
['authorization'] = 'token ' .. self.github_token,
['accept'] = 'application/json',
Expand Down Expand Up @@ -434,7 +409,7 @@ function Copilot:authenticate()

local headers = {
['authorization'] = 'Bearer ' .. self.token.token,
['x-request-id'] = uuid(),
['x-request-id'] = utils.uuid(),
['vscode-sessionid'] = self.sessionid,
['vscode-machineid'] = self.machineid,
['copilot-integration-id'] = 'vscode-chat',
Expand Down Expand Up @@ -570,7 +545,7 @@ function Copilot:ask(prompt, opts)
local temperature = opts.temperature or 0.1
local no_history = opts.no_history or false
local on_progress = opts.on_progress
local job_id = uuid()
local job_id = utils.uuid()
self.current_job = job_id

log.trace('System prompt: ' .. system_prompt)
Expand Down Expand Up @@ -622,7 +597,10 @@ function Copilot:ask(prompt, opts)

-- Calculate how many tokens we can use for history
local history_limit = max_tokens - required_tokens - reserved_tokens
local history_tokens = count_history_tokens(history)
local history_tokens = 0
for _, msg in ipairs(history) do
history_tokens = history_tokens + tiktoken.count(msg.content)
end

-- If we're over history limit, truncate history from the beginning
while history_tokens > history_limit and #history > 0 do
Expand Down Expand Up @@ -914,7 +892,7 @@ function Copilot:embed(inputs, opts)
embed.filetype = embed.filetype or 'text'

if embed.content then
local key = embed.filename .. quick_hash(embed.content)
local key = embed.filename .. utils.quick_hash(embed.content)
if self.embedding_cache[key] then
table.insert(cached_embeddings, self.embedding_cache[key])
else
Expand Down Expand Up @@ -977,7 +955,7 @@ function Copilot:embed(inputs, opts)
-- Cache embeddings
for _, embedding in ipairs(out) do
if embedding.content then
local key = embedding.filename .. quick_hash(embedding.content)
local key = embedding.filename .. utils.quick_hash(embedding.content)
self.embedding_cache[key] = embedding
end
end
Expand Down
18 changes: 3 additions & 15 deletions lua/CopilotChat/diff.lua
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,9 @@ local class = utils.class

local Diff = class(function(self, help, on_buf_create)
self.hl_ns = vim.api.nvim_create_namespace('copilot-chat-highlights')
vim.api.nvim_set_hl(
self.hl_ns,
'@diff.plus',
{ bg = utils.blend_color_with_neovim_bg('DiffAdd', 20) }
)
vim.api.nvim_set_hl(
self.hl_ns,
'@diff.minus',
{ bg = utils.blend_color_with_neovim_bg('DiffDelete', 20) }
)
vim.api.nvim_set_hl(
self.hl_ns,
'@diff.delta',
{ bg = utils.blend_color_with_neovim_bg('DiffChange', 20) }
)
vim.api.nvim_set_hl(self.hl_ns, '@diff.plus', { bg = utils.blend_color('DiffAdd', 20) })
vim.api.nvim_set_hl(self.hl_ns, '@diff.minus', { bg = utils.blend_color('DiffDelete', 20) })
vim.api.nvim_set_hl(self.hl_ns, '@diff.delta', { bg = utils.blend_color('DiffChange', 20) })

self.name = 'copilot-diff'
self.help = help
Expand Down
61 changes: 36 additions & 25 deletions lua/CopilotChat/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,32 +69,11 @@ function M.config_path()
end
end

--- Check if a table is equal to another table
---@param a table The first table
---@param b table The second table
---@return boolean
function M.table_equals(a, b)
if type(a) ~= type(b) then
return false
end
if type(a) ~= 'table' then
return a == b
end
for k, v in pairs(a) do
if not M.table_equals(v, b[k]) then
return false
end
end
for k, v in pairs(b) do
if not M.table_equals(v, a[k]) then
return false
end
end
return true
end

--- Blend a color with the neovim background
function M.blend_color_with_neovim_bg(color_name, blend)
---@param color_name string The color name
---@param blend number The blend percentage
---@return string?
function M.blend_color(color_name, blend)
local color_int = vim.api.nvim_get_hl(0, { name = color_name }).fg
local bg_int = vim.api.nvim_get_hl(0, { name = 'Normal' }).bg

Expand Down Expand Up @@ -207,4 +186,36 @@ function M.filename_same(file1, file2)
return vim.fn.fnamemodify(file1, ':p') == vim.fn.fnamemodify(file2, ':p')
end

--- Generate a UUID
---@return string
function M.uuid()
local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
return (
string.gsub(template, '[xy]', function(c)
local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb)
return string.format('%x', v)
end)
)
end

--- Generate machine id
---@return string
function M.machine_id()
local length = 65
local hex_chars = '0123456789abcdef'
local hex = ''
for _ = 1, length do
local index = math.random(1, #hex_chars)
hex = hex .. hex_chars:sub(index, index)
end
return hex
end

--- Generate a quick hash
---@param str string The string to hash
---@return string
function M.quick_hash(str)
return #str .. str:sub(1, 32) .. str:sub(-32)
end

return M

0 comments on commit dfb8846

Please sign in to comment.