From c98301c5cbba79b70beffa9699ce89c75691c093 Mon Sep 17 00:00:00 2001 From: Tomas Slusny Date: Fri, 15 Nov 2024 19:25:26 +0100 Subject: [PATCH] Add support for copilot extension agents https://docs.github.com/en/copilot/building-copilot-extensions/about-building-copilot-extensions - Change @buffer and @buffers to #buffer and #buffers - Add support for @agent agent selection - Add support for config.agent for specifying default agent - Add :CopilotChatAgents for listing agents (and showing selected agent) - Remove :CopilotChatModel, instead show which model is selected in :CopilotChatModels - Remove early errors from curl so we can actually get response body for the error - Add info to README about models, agents and contexts Closes #466 Signed-off-by: Tomas Slusny --- README.md | 42 ++++- lua/CopilotChat/config.lua | 10 +- lua/CopilotChat/copilot.lua | 116 +++++++++++++- lua/CopilotChat/init.lua | 221 +++++++++++++++++---------- lua/CopilotChat/integrations/cmp.lua | 36 ++--- 5 files changed, 308 insertions(+), 117 deletions(-) diff --git a/README.md b/README.md index 5bd899e4..c1b8e2aa 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ Verify "[Copilot chat in the IDE](https://github.com/settings/copilot)" is enabl - `:CopilotChatLoad ?` - Load chat history from file - `:CopilotChatDebugInfo` - Show debug information - `:CopilotChatModels` - View and select available models. This is reset when a new instance is made. Please set your model in `init.lua` for persistence. -- `:CopilotChatModel` - View the currently selected model. +- `:CopilotChatAgents` - View and select available agents. This is reset when a new instance is made. Please set your agent in `init.lua` for persistence. #### Commands coming from default prompts @@ -122,6 +122,39 @@ Verify "[Copilot chat in the IDE](https://github.com/settings/copilot)" is enabl - `:CopilotChatTests` - Please generate tests for my code - `:CopilotChatCommit` - Write commit message for the change with commitizen convention +### Models, Agents and Contexts + +#### Models + +You can list available models with `:CopilotChatModels` command. Model determines the AI model used for the chat. +Default models are: + +- `gpt-4o` - This is the default Copilot Chat model. It is a versatile, multimodal model that excels in both text and image processing and is designed to provide fast, reliable responses. It also has superior performance in non-English languages. Gpt-4o is hosted on Azure. +- `claude-3.5-sonnet` - This model excels at coding tasks across the entire software development lifecycle, from initial design to bug fixes, maintenance to optimizations. GitHub Copilot uses Claude 3.5 Sonnet hosted on Amazon Web Services. +- `o1-preview` - This model is focused on advanced reasoning and solving complex problems, in particular in math and science. It responds more slowly than the gpt-4o model. You can make 10 requests to this model per day. o1-preview is hosted on Azure. +- `o1-mini` - This is the faster version of the o1-preview model, balancing the use of complex reasoning with the need for faster responses. It is best suited for code generation and small context operations. You can make 50 requests to this model per day. o1-mini is hosted on Azure. + +For more information about models, see [here](https://docs.github.com/en/copilot/using-github-copilot/asking-github-copilot-questions-in-your-ide#ai-models-for-copilot-chat) +You can use more models from [here](https://github.com/marketplace/models) by using `@models` agent from [here](https://github.com/marketplace/models-github) (example: `@models Using Mistral-small, what is 1 + 11`) + +#### Agents + +Agents are used to determine the AI agent used for the chat. You can list available agents with `:CopilotChatAgents` command. +You can set the agent in the prompt by using `@` followed by the agent name. +Default "noop" agent is `copilot`. + +For more information about extension agents, see [here](https://docs.github.com/en/copilot/using-github-copilot/using-extensions-to-integrate-external-tools-with-copilot-chat) +You can install more agents from [here](https://github.com/marketplace?type=apps&copilot_app=true) + +#### Contexts + +Contexts are used to determine the context of the chat. +You can set the context in the prompt by using `#` followed by the context name. +Supported contexts are: + +- `buffers` - Includes all open buffers in chat context +- `buffer` - Includes only the current buffer in chat context + ### API ```lua @@ -202,8 +235,10 @@ Also see [here](/lua/CopilotChat/config.lua): allow_insecure = false, -- Allow insecure server connections system_prompt = prompts.COPILOT_INSTRUCTIONS, -- System prompt to use - model = 'gpt-4o', -- GPT model to use, see ':CopilotChatModels' for available models - temperature = 0.1, -- GPT temperature + model = 'gpt-4o', -- Default model to use, see ':CopilotChatModels' for available models + agent = 'copilot', -- Default agent to use, see ':CopilotChatAgents' for available agents (can be specified manually in prompt via @). + context = nil, -- Default context to use, 'buffers', 'buffer' or none (can be specified manually in prompt via #). + temperature = 0.1, -- GPT result temperature question_header = '## User ', -- Header to use for user questions answer_header = '## Copilot ', -- Header to use for AI answers @@ -218,7 +253,6 @@ Also see [here](/lua/CopilotChat/config.lua): clear_chat_on_new_prompt = false, -- Clears chat on every new prompt highlight_selection = true, -- Highlight selection in the source buffer when in the chat window - context = nil, -- Default context to use, 'buffers', 'buffer' or none (can be specified manually in prompt via @). history_path = vim.fn.stdpath('data') .. '/copilotchat_history', -- Default path to stored history callback = nil, -- Callback to use when ask response is received diff --git a/lua/CopilotChat/config.lua b/lua/CopilotChat/config.lua index aa17fc3b..7d736f10 100644 --- a/lua/CopilotChat/config.lua +++ b/lua/CopilotChat/config.lua @@ -69,6 +69,8 @@ local select = require('CopilotChat.select') ---@field allow_insecure boolean? ---@field system_prompt string? ---@field model string? +---@field agent string? +---@field context string? ---@field temperature number? ---@field question_header string? ---@field answer_header string? @@ -80,7 +82,6 @@ local select = require('CopilotChat.select') ---@field auto_insert_mode boolean? ---@field clear_chat_on_new_prompt boolean? ---@field highlight_selection boolean? ----@field context string? ---@field history_path string? ---@field callback fun(response: string, source: CopilotChat.config.source)? ---@field selection nil|fun(source: CopilotChat.config.source):CopilotChat.config.selection? @@ -94,8 +95,10 @@ return { allow_insecure = false, -- Allow insecure server connections system_prompt = prompts.COPILOT_INSTRUCTIONS, -- System prompt to use - model = 'gpt-4o', -- GPT model to use, see ':CopilotChatModels' for available models - temperature = 0.1, -- GPT temperature + model = 'gpt-4o', -- Default model to use, see ':CopilotChatModels' for available models + agent = 'copilot', -- Default agent to use, see ':CopilotChatAgents' for available agents (can be specified manually in prompt via @). + context = nil, -- Default context to use, 'buffers', 'buffer' or none (can be specified manually in prompt via #). + temperature = 0.1, -- GPT result temperature question_header = '## User ', -- Header to use for user questions answer_header = '## Copilot ', -- Header to use for AI answers @@ -110,7 +113,6 @@ return { clear_chat_on_new_prompt = false, -- Clears chat on every new prompt highlight_selection = true, -- Highlight selection - context = nil, -- Default context to use, 'buffers', 'buffer' or none (can be specified manually in prompt via @). history_path = vim.fn.stdpath('data') .. '/copilotchat_history', -- Default path to stored history callback = nil, -- Callback to use when ask response is received diff --git a/lua/CopilotChat/copilot.lua b/lua/CopilotChat/copilot.lua index 3153604f..733d3119 100644 --- a/lua/CopilotChat/copilot.lua +++ b/lua/CopilotChat/copilot.lua @@ -13,6 +13,7 @@ ---@field end_row number? ---@field system_prompt string? ---@field model string? +---@field agent string? ---@field temperature number? ---@field on_progress nil|fun(response: string):nil @@ -29,6 +30,7 @@ ---@field load fun(self: CopilotChat.Copilot, name: string, path: string):table ---@field running fun(self: CopilotChat.Copilot):boolean ---@field list_models fun(self: CopilotChat.Copilot):table +---@field list_agents fun(self: CopilotChat.Copilot):table local async = require('plenary.async') local log = require('plenary.log') @@ -340,6 +342,7 @@ local Copilot = class(function(self, proxy, allow_insecure) self.sessionid = nil self.machineid = machine_id() self.models = nil + self.agents = nil self.claude_enabled = false self.current_job = nil self.request_args = { @@ -362,9 +365,6 @@ local Copilot = class(function(self, proxy, allow_insecure) '--no-keepalive', -- Don't reuse connections '--tcp-nodelay', -- Disable Nagle's algorithm for faster streaming '--no-buffer', -- Disable output buffering for streaming - '--fail', -- Return error on HTTP errors (4xx, 5xx) - '--silent', -- Don't show progress meter - '--show-error', -- Show errors even when silent }, } end) @@ -461,6 +461,39 @@ function Copilot:fetch_models() return out end +function Copilot:fetch_agents() + if self.agents then + return self.agents + end + + local response, err = curl_get( + 'https://api.githubcopilot.com/agents', + vim.tbl_extend('force', self.request_args, { + headers = self:authenticate(), + }) + ) + + if err then + error(err) + end + + if response.status ~= 200 then + error('Failed to fetch agents: ' .. tostring(response.status)) + end + + local agents = vim.json.decode(response.body)['agents'] + local out = {} + for _, agent in ipairs(agents) do + out[agent['slug']] = agent + end + + out['copilot'] = { name = 'Copilot', default = true } + + log.info('Agents fetched') + self.agents = out + return out +end + function Copilot:enable_claude() if self.claude_enabled then return true @@ -510,6 +543,7 @@ function Copilot:ask(prompt, opts) local selection = opts.selection or {} local system_prompt = opts.system_prompt or prompts.COPILOT_INSTRUCTIONS local model = opts.model or 'gpt-4o-2024-05-13' + local agent = opts.agent or 'copilot' local temperature = opts.temperature or 0.1 local on_progress = opts.on_progress local job_id = uuid() @@ -522,10 +556,21 @@ function Copilot:ask(prompt, opts) log.debug('Filename: ' .. filename) log.debug('Filetype: ' .. filetype) log.debug('Model: ' .. model) + log.debug('Agent: ' .. agent) log.debug('Temperature: ' .. temperature) local models = self:fetch_models() - local capabilities = models[model] and models[model].capabilities + local agents = self:fetch_agents() + local agent_config = agents[agent] + if not agent_config then + error('Agent not found: ' .. agent) + end + local model_config = models[model] + if not model_config then + error('Model not found: ' .. model) + end + + local capabilities = model_config.capabilities local max_tokens = capabilities.limits.max_prompt_tokens -- FIXME: Is max_prompt_tokens the right limit? local max_output_tokens = capabilities.limits.max_output_tokens local tokenizer = capabilities.tokenizer @@ -582,6 +627,7 @@ function Copilot:ask(prompt, opts) local errored = false local finished = false local full_response = '' + local full_references = '' local function finish_stream(err, job) if err then @@ -631,6 +677,22 @@ function Copilot:ask(prompt, opts) return end + if content.copilot_references then + for _, reference in ipairs(content.copilot_references) do + local metadata = reference.metadata + if metadata and metadata.display_name and metadata.display_url then + full_references = full_references + .. '\n' + .. '[' + .. metadata.display_name + .. ']' + .. '(' + .. metadata.display_url + .. ')' + end + end + end + if not content.choices or #content.choices == 0 then return end @@ -668,8 +730,13 @@ function Copilot:ask(prompt, opts) self:enable_claude() end + local url = 'https://api.githubcopilot.com/chat/completions' + if not agent_config.default then + url = 'https://api.githubcopilot.com/agents/' .. agent .. '?chat' + end + local response, err = curl_post( - 'https://api.githubcopilot.com/chat/completions', + url, vim.tbl_extend('force', self.request_args, { headers = self:authenticate(), body = temp_file(body), @@ -694,6 +761,25 @@ function Copilot:ask(prompt, opts) end if response.status ~= 200 then + if response.status == 401 then + local ok, content = pcall(vim.json.decode, response.body, { + luanil = { + object = true, + array = true, + }, + }) + + if ok and content.authorize_url then + error( + 'Failed to authenticate. Visit following url to authorize ' + .. content.slug + .. ':\n' + .. content.authorize_url + ) + return + end + end + error('Failed to get response: ' .. tostring(response.status) .. '\n' .. response.body) return end @@ -708,6 +794,14 @@ function Copilot:ask(prompt, opts) return end + if full_references ~= '' then + full_references = '\n\n**`References:`**' .. full_references + full_response = full_response .. full_references + if on_progress then + on_progress(full_references) + end + end + log.trace('Full response: ' .. full_response) log.debug('Last message: ' .. vim.inspect(last_message)) @@ -727,10 +821,10 @@ function Copilot:ask(prompt, opts) end --- List available models +---@return table function Copilot:list_models() local models = self:fetch_models() - -- Group models by version and shortest ID local version_map = {} for id, model in pairs(models) do local version = model.version @@ -739,10 +833,18 @@ function Copilot:list_models() end end - -- Map to IDs and sort local result = vim.tbl_values(version_map) table.sort(result) + return result +end +--- List available agents +---@return table +function Copilot:list_agents() + local agents = self:fetch_agents() + + local result = vim.tbl_keys(agents) + table.sort(result) return result end diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index 7c3c4cf2..ba381b39 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -60,30 +60,13 @@ local function blend_color_with_neovim_bg(color_name, blend) return string.format('#%02x%02x%02x', r, g, b) end -local function dedupe_strings(str) - if not str then - return str - end - local seen = {} - local result = {} - for s in str:gmatch('[^%s,]+') do - if not seen[s] then - seen[s] = true - table.insert(result, s) - end - end - return table.concat(result, ' ') -end - local function get_error_message(err) if type(err) == 'string' then - -- Match first occurrence of :something: and capture rest local message = err:match('^[^:]+:[^:]+:(.+)') or err - -- Trim whitespace - message = message:match('^%s*(.-)%s*$') - return dedupe_strings(message) + message = message:gsub('^%s*', '') + return message end - return dedupe_strings(vim.inspect(err)) + return vim.inspect(err) end local function find_lines_between_separator( @@ -182,58 +165,6 @@ local function append(str, config) end end -local function complete() - local line = vim.api.nvim_get_current_line() - local col = vim.api.nvim_win_get_cursor(0)[2] - if col == 0 or #line == 0 then - return - end - - local prefix, cmp_start = unpack(vim.fn.matchstrpos(line:sub(1, col), [[\(/\|@\)\k*$]])) - if not prefix then - return - end - - local items = {} - local prompts_to_use = M.prompts() - - for name, prompt in pairs(prompts_to_use) do - items[#items + 1] = { - word = '/' .. name, - kind = prompt.kind, - info = prompt.prompt, - menu = prompt.description or '', - icase = 1, - dup = 0, - empty = 0, - } - end - - items[#items + 1] = { - word = '@buffers', - kind = 'context', - menu = 'Use all loaded buffers as context', - icase = 1, - dup = 0, - empty = 0, - } - - items[#items + 1] = { - word = '@buffer', - kind = 'context', - menu = 'Use the current buffer as context', - icase = 1, - dup = 0, - empty = 0, - } - - items = vim.tbl_filter(function(item) - return vim.startswith(item.word:lower(), prefix:lower()) - end, items) - - vim.fn.complete(cmp_start + 1, items) -end - local function get_selection() local bufnr = state.source.bufnr local winnr = state.source.winnr @@ -302,6 +233,70 @@ local function key_to_info(name, key, surround) return out end +--- Get the completion info for the chat window, for use with custom completion providers +---@return table +function M.complete_info() + return { + triggers = { '@', '/', '#' }, + pattern = [[\%(@\|/\|#\)\k*]], + } +end + +--- Get the completion items for the chat window, for use with custom completion providers +---@param callback function(table) +function M.complete_items(callback) + async.run(function() + local agents = state.copilot:list_agents() + local items = {} + local prompts_to_use = M.prompts() + + for name, prompt in pairs(prompts_to_use) do + items[#items + 1] = { + word = '/' .. name, + kind = prompt.kind, + info = prompt.prompt, + menu = prompt.description or '', + icase = 1, + dup = 0, + empty = 0, + } + end + + for _, agent in ipairs(agents) do + items[#items + 1] = { + word = '@' .. agent, + kind = 'agent', + menu = 'Use the specified agent', + icase = 1, + dup = 0, + empty = 0, + } + end + + items[#items + 1] = { + word = '#buffers', + kind = 'context', + menu = 'Include all loaded buffers in context', + icase = 1, + dup = 0, + empty = 0, + } + + items[#items + 1] = { + word = '#buffer', + kind = 'context', + menu = 'Include the specified buffer in context', + icase = 1, + dup = 0, + empty = 0, + } + + vim.schedule(function() + callback(items) + end) + end) +end + --- Get the prompts to use. ---@param skip_system boolean|nil ---@return table @@ -408,13 +403,44 @@ end function M.select_model() async.run(function() local models = state.copilot:list_models() + models = vim.tbl_map(function(model) + if model == M.config.model then + return model .. ' (selected)' + end + + return model + end, models) vim.schedule(function() vim.ui.select(models, { prompt = 'Select a model', }, function(choice) if choice then - M.config.model = choice + M.config.model = choice:gsub(' %(selected%)', '') + end + end) + end) + end) +end + +--- Select a Copilot agent. +function M.select_agent() + async.run(function() + local agents = state.copilot:list_agents() + agents = vim.tbl_map(function(agent) + if agent == M.config.agent then + return agent .. ' (selected)' + end + + return agent + end, agents) + + vim.schedule(function() + vim.ui.select(agents, { + prompt = 'Select an agent', + }, function(choice) + if choice then + M.config.agent = choice:gsub(' %(selected%)', '') end end) end) @@ -473,14 +499,24 @@ function M.ask(prompt, config, source) append('\n\n' .. config.answer_header .. config.separator .. '\n\n', config) local selected_context = config.context - if string.find(prompt, '@buffers') then + if string.find(prompt, '#buffers') then selected_context = 'buffers' - elseif string.find(prompt, '@buffer') then + elseif string.find(prompt, '#buffer') then selected_context = 'buffer' end - updated_prompt = string.gsub(updated_prompt, '@buffers?%s*', '') + updated_prompt = string.gsub(updated_prompt, '#buffers?%s*', '') async.run(function() + local agents = state.copilot:list_agents() + local current_agent = config.agent + + for agent in updated_prompt:gmatch('@([%w_-]+)') do + if vim.tbl_contains(agents, agent) then + current_agent = agent + updated_prompt = updated_prompt:gsub('@' .. agent .. '%s*', '') + end + end + local query_ok, embeddings = pcall(context.find_for_query, state.copilot, { context = selected_context, prompt = updated_prompt, @@ -503,6 +539,7 @@ function M.ask(prompt, config, source) filetype = filetype, system_prompt = system_prompt, model = config.model, + agent = current_agent, temperature = config.temperature, on_progress = function(token) vim.schedule(function() @@ -808,10 +845,32 @@ function M.setup(config) state.help:show(chat_help, 'markdown', 'markdown', state.chat.winnr) end) - map_key(M.config.mappings.complete, bufnr, complete) map_key(M.config.mappings.reset, bufnr, M.reset) map_key(M.config.mappings.close, bufnr, M.close) + map_key(M.config.mappings.complete, bufnr, function() + local info = M.complete_info() + local line = vim.api.nvim_get_current_line() + local col = vim.api.nvim_win_get_cursor(0)[2] + if col == 0 or #line == 0 then + return + end + + local prefix, cmp_start = unpack(vim.fn.matchstrpos(line:sub(1, col), info.pattern)) + if not prefix then + return + end + + M.complete_items(function(items) + vim.fn.complete( + cmp_start + 1, + vim.tbl_filter(function(item) + return vim.startswith(item.word:lower(), prefix:lower()) + end, items) + ) + end) + end) + map_key(M.config.mappings.submit_prompt, bufnr, function() local chat_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) local lines, start_line, end_line = @@ -983,14 +1042,12 @@ function M.setup(config) range = true, }) - vim.api.nvim_create_user_command('CopilotChatModel', function() - vim.notify('Using model: ' .. M.config.model, vim.log.levels.INFO) - end, { force = true }) - vim.api.nvim_create_user_command('CopilotChatModels', function() M.select_model() end, { force = true }) - + vim.api.nvim_create_user_command('CopilotChatAgents', function() + M.select_agent() + end, { force = true }) vim.api.nvim_create_user_command('CopilotChatOpen', function() M.open() end, { force = true }) diff --git a/lua/CopilotChat/integrations/cmp.lua b/lua/CopilotChat/integrations/cmp.lua index b3027673..47928762 100644 --- a/lua/CopilotChat/integrations/cmp.lua +++ b/lua/CopilotChat/integrations/cmp.lua @@ -4,34 +4,30 @@ local chat = require('CopilotChat') local Source = {} function Source:get_trigger_characters() - return { '@', '/' } + return chat.complete_info().triggers end function Source:get_keyword_pattern() - return [[\%(@\|/\)\k*]] + return chat.complete_info().pattern end function Source:complete(params, callback) - local items = {} - local prompts_to_use = chat.prompts() - - local prefix = string.lower(params.context.cursor_before_line:sub(params.offset)) - local prefix_len = #prefix - local checkAdd = function(word) - if word:lower():sub(1, prefix_len) == prefix then - items[#items + 1] = { - label = word, + chat.complete_items(function(items) + items = vim.tbl_map(function(item) + return { + label = item.word, kind = cmp.lsp.CompletionItemKind.Keyword, } - end - end - for name, _ in pairs(prompts_to_use) do - checkAdd('/' .. name) - end - checkAdd('@buffers') - checkAdd('@buffer') - - callback({ items = items }) + end, items) + + local prefix = string.lower(params.context.cursor_before_line:sub(params.offset)) + + callback({ + items = vim.tbl_filter(function(item) + return vim.startswith(item.label:lower(), prefix:lower()) + end, items), + }) + end) end ---@param completion_item lsp.CompletionItem