From dfb88466fd93ccd2556518eb2e085cf38d554697 Mon Sep 17 00:00:00 2001 From: Tomas Slusny Date: Wed, 20 Nov 2024 15:24:43 +0100 Subject: [PATCH] refactor: move utility functions to utils module Moves uuid, machine_id and quick_hash functions from copilot.lua to utils.lua module. Also improves type annotations and documentation for the moved functions. Removes unused table_equals function and renames blend_color_with_neovim_bg to blend_color for better clarity. --- lua/CopilotChat/copilot.lua | 58 +++++++++++------------------------ lua/CopilotChat/diff.lua | 18 ++--------- lua/CopilotChat/utils.lua | 61 ++++++++++++++++++++++--------------- 3 files changed, 57 insertions(+), 80 deletions(-) diff --git a/lua/CopilotChat/copilot.lua b/lua/CopilotChat/copilot.lua index e4858fb2..cf64849f 100644 --- a/lua/CopilotChat/copilot.lua +++ b/lua/CopilotChat/copilot.lua @@ -36,6 +36,8 @@ local tiktoken = require('CopilotChat.tiktoken') local utils = require('CopilotChat.utils') local class = utils.class local temp_file = utils.temp_file + +--- Constants local context_format = '[#file:%s](#file:%s-context)\n' local big_file_threshold = 2000 local timeout = 30000 @@ -81,31 +83,8 @@ local tiktoken_load = async.wrap(function(tokenizer, callback) tiktoken.load(tokenizer, callback) end, 2) -local function uuid() - local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx' - return ( - string.gsub(template, '[xy]', function(c) - local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb) - return string.format('%x', v) - end) - ) -end - -local function machine_id() - local length = 65 - local hex_chars = '0123456789abcdef' - local hex = '' - for _ = 1, length do - local index = math.random(1, #hex_chars) - hex = hex .. hex_chars:sub(index, index) - end - return hex -end - -local function quick_hash(str) - return #str .. str:sub(1, 32) .. str:sub(-32) -end - +--- Get the github oauth cached token +---@return string|nil local function get_cached_token() -- loading token from the environment only in GitHub Codespaces local token = os.getenv('GITHUB_TOKEN') @@ -140,6 +119,10 @@ local function get_cached_token() return nil end +--- Generate line numbers for the given content +---@param content string: The content to generate line numbers for +---@param start_line number|nil: The starting line number +---@return string local function generate_line_numbers(content, start_line) local lines = vim.split(content, '\n') local truncated = false @@ -348,21 +331,13 @@ local function generate_embedding_request(inputs, model) } end -local function count_history_tokens(history) - local count = 0 - for _, msg in ipairs(history) do - count = count + tiktoken.count(msg.content) - end - return count -end - local Copilot = class(function(self, proxy, allow_insecure) self.history = {} self.embedding_cache = {} self.github_token = nil self.token = nil self.sessionid = nil - self.machineid = machine_id() + self.machineid = utils.machine_id() self.models = nil self.agents = nil self.claude_enabled = false @@ -404,7 +379,7 @@ function Copilot:authenticate() if not self.token or (self.token.expires_at and self.token.expires_at <= math.floor(os.time())) then - local sessionid = uuid() .. tostring(math.floor(os.time() * 1000)) + local sessionid = utils.uuid() .. tostring(math.floor(os.time() * 1000)) local headers = { ['authorization'] = 'token ' .. self.github_token, ['accept'] = 'application/json', @@ -434,7 +409,7 @@ function Copilot:authenticate() local headers = { ['authorization'] = 'Bearer ' .. self.token.token, - ['x-request-id'] = uuid(), + ['x-request-id'] = utils.uuid(), ['vscode-sessionid'] = self.sessionid, ['vscode-machineid'] = self.machineid, ['copilot-integration-id'] = 'vscode-chat', @@ -570,7 +545,7 @@ function Copilot:ask(prompt, opts) local temperature = opts.temperature or 0.1 local no_history = opts.no_history or false local on_progress = opts.on_progress - local job_id = uuid() + local job_id = utils.uuid() self.current_job = job_id log.trace('System prompt: ' .. system_prompt) @@ -622,7 +597,10 @@ function Copilot:ask(prompt, opts) -- Calculate how many tokens we can use for history local history_limit = max_tokens - required_tokens - reserved_tokens - local history_tokens = count_history_tokens(history) + local history_tokens = 0 + for _, msg in ipairs(history) do + history_tokens = history_tokens + tiktoken.count(msg.content) + end -- If we're over history limit, truncate history from the beginning while history_tokens > history_limit and #history > 0 do @@ -914,7 +892,7 @@ function Copilot:embed(inputs, opts) embed.filetype = embed.filetype or 'text' if embed.content then - local key = embed.filename .. quick_hash(embed.content) + local key = embed.filename .. utils.quick_hash(embed.content) if self.embedding_cache[key] then table.insert(cached_embeddings, self.embedding_cache[key]) else @@ -977,7 +955,7 @@ function Copilot:embed(inputs, opts) -- Cache embeddings for _, embedding in ipairs(out) do if embedding.content then - local key = embedding.filename .. quick_hash(embedding.content) + local key = embedding.filename .. utils.quick_hash(embedding.content) self.embedding_cache[key] = embedding end end diff --git a/lua/CopilotChat/diff.lua b/lua/CopilotChat/diff.lua index 6a39204e..f8ed47fe 100644 --- a/lua/CopilotChat/diff.lua +++ b/lua/CopilotChat/diff.lua @@ -20,21 +20,9 @@ local class = utils.class local Diff = class(function(self, help, on_buf_create) self.hl_ns = vim.api.nvim_create_namespace('copilot-chat-highlights') - vim.api.nvim_set_hl( - self.hl_ns, - '@diff.plus', - { bg = utils.blend_color_with_neovim_bg('DiffAdd', 20) } - ) - vim.api.nvim_set_hl( - self.hl_ns, - '@diff.minus', - { bg = utils.blend_color_with_neovim_bg('DiffDelete', 20) } - ) - vim.api.nvim_set_hl( - self.hl_ns, - '@diff.delta', - { bg = utils.blend_color_with_neovim_bg('DiffChange', 20) } - ) + vim.api.nvim_set_hl(self.hl_ns, '@diff.plus', { bg = utils.blend_color('DiffAdd', 20) }) + vim.api.nvim_set_hl(self.hl_ns, '@diff.minus', { bg = utils.blend_color('DiffDelete', 20) }) + vim.api.nvim_set_hl(self.hl_ns, '@diff.delta', { bg = utils.blend_color('DiffChange', 20) }) self.name = 'copilot-diff' self.help = help diff --git a/lua/CopilotChat/utils.lua b/lua/CopilotChat/utils.lua index e2379df0..b1e921de 100644 --- a/lua/CopilotChat/utils.lua +++ b/lua/CopilotChat/utils.lua @@ -69,32 +69,11 @@ function M.config_path() end end ---- Check if a table is equal to another table ----@param a table The first table ----@param b table The second table ----@return boolean -function M.table_equals(a, b) - if type(a) ~= type(b) then - return false - end - if type(a) ~= 'table' then - return a == b - end - for k, v in pairs(a) do - if not M.table_equals(v, b[k]) then - return false - end - end - for k, v in pairs(b) do - if not M.table_equals(v, a[k]) then - return false - end - end - return true -end - --- Blend a color with the neovim background -function M.blend_color_with_neovim_bg(color_name, blend) +---@param color_name string The color name +---@param blend number The blend percentage +---@return string? +function M.blend_color(color_name, blend) local color_int = vim.api.nvim_get_hl(0, { name = color_name }).fg local bg_int = vim.api.nvim_get_hl(0, { name = 'Normal' }).bg @@ -207,4 +186,36 @@ function M.filename_same(file1, file2) return vim.fn.fnamemodify(file1, ':p') == vim.fn.fnamemodify(file2, ':p') end +--- Generate a UUID +---@return string +function M.uuid() + local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx' + return ( + string.gsub(template, '[xy]', function(c) + local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb) + return string.format('%x', v) + end) + ) +end + +--- Generate machine id +---@return string +function M.machine_id() + local length = 65 + local hex_chars = '0123456789abcdef' + local hex = '' + for _ = 1, length do + local index = math.random(1, #hex_chars) + hex = hex .. hex_chars:sub(index, index) + end + return hex +end + +--- Generate a quick hash +---@param str string The string to hash +---@return string +function M.quick_hash(str) + return #str .. str:sub(1, 32) .. str:sub(-32) +end + return M