Skip to content

Commit

Permalink
refactor: improve embedding stability and token handling
Browse files Browse the repository at this point in the history
This commit improves the stability and efficiency of the embedding system:
- Add token count validation before processing embeddings
- Implement smarter batching based on token limits (8191 max)
- Reduce line threshold for big embeds from 600 to 500
- Improve error handling and logging for embedding failures
- Optimize tiktoken loading and caching
  • Loading branch information
deathbeam committed Nov 28, 2024
1 parent dee0090 commit 8584c53
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 75 deletions.
150 changes: 81 additions & 69 deletions lua/CopilotChat/copilot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ local temp_file = utils.temp_file
--- Constants
-- Format string taking a file reference twice (link text and link target).
local CONTEXT_FORMAT = '[#file:%s](#file:%s-context)'
-- NOTE(review): presumably a line-count limit for whole files; the usage is
-- not visible in this section -- confirm against the caller.
local BIG_FILE_THRESHOLD = 2000
-- Maximum number of content lines kept per embedding input; anything beyond
-- this is cut and replaced by the TRUNCATED marker line.
-- Declared exactly once: an earlier duplicate declaration (600) was shadowed
-- by this one and therefore never had any effect.
local BIG_EMBED_THRESHOLD = 500
-- Model name sent in the embeddings request.
local EMBED_MODEL = 'text-embedding-3-small'
-- Upper bound on tokens for a single embedding input and for a whole batch.
local EMBED_MAX_TOKENS = 8191
-- Tokenizer loaded through tiktoken before counting embedding tokens.
local EMBED_TOKENIZER = 'cl100k_base'
-- Marker appended as the final line when content has been truncated.
local TRUNCATED = '... (truncated)'
-- NOTE(review): presumably a request timeout in milliseconds -- confirm.
local TIMEOUT = 30000
local VERSION_HEADERS = {
Expand Down Expand Up @@ -266,23 +269,30 @@ local function generate_ask_request(
return out
end

--- Render a single embedding input as the text that is sent for embedding.
--- Content longer than BIG_EMBED_THRESHOLD lines is cut to that many lines and
--- a TRUNCATED marker line is appended. Inputs whose filetype is 'raw' are
--- returned as plain text; every other input is wrapped in a "File:" header
--- followed by a fenced code block tagged with the filetype.
---@param embedding CopilotChat.copilot.embed: Input providing content, filename and filetype
---@return string
local function generate_embeddings_request_message(embedding)
  local kept = vim.split(embedding.content, '\n')
  local over_limit = #kept > BIG_EMBED_THRESHOLD
  if over_limit then
    -- Keep only the leading lines and make the cut visible in the output
    kept = vim.list_slice(kept, 1, BIG_EMBED_THRESHOLD)
    kept[#kept + 1] = TRUNCATED
  end
  local text = table.concat(kept, '\n')

  -- Raw inputs are sent without the file header / code fence
  if embedding.filetype == 'raw' then
    return text
  end

  local template = 'File: `%s`\n```%s\n%s\n```'
  return template:format(embedding.filename, embedding.filetype, text)
end

--- Build the request body for the embeddings endpoint.
--- Every input is rendered with generate_embeddings_request_message, so the
--- text sent to the API is exactly the text used when counting tokens.
--- The `input` field is assigned only once: a second, inline copy of the same
--- mapping was previously evaluated and then overwritten by this assignment,
--- doing the split/concat work twice for every input.
---@param inputs table<CopilotChat.copilot.embed>: The inputs to embed
---@param model string: Name of the embedding model
---@return table
local function generate_embedding_request(inputs, model)
  return {
    dimensions = 512,
    input = vim.tbl_map(generate_embeddings_request_message, inputs),
    model = model,
  }
end
Expand Down Expand Up @@ -856,41 +866,65 @@ end

--- Generate embeddings for the given inputs
---@param inputs table<CopilotChat.copilot.embed>: The inputs to embed
---@param opts CopilotChat.copilot.embed.opts?: Options for the request
---@return table<CopilotChat.copilot.embed>
function Copilot:embed(inputs, opts)
if not inputs or #inputs == 0 then
return {}
end

-- Check which embeddings need to be fetched
local cached_embeddings = {}
local uncached_embeddings = {}
for _, embed in ipairs(inputs) do
embed.filename = embed.filename or 'unknown'
embed.filetype = embed.filetype or 'text'

if embed.content then
local key = embed.filename .. utils.quick_hash(embed.content)
if self.embedding_cache[key] then
table.insert(cached_embeddings, self.embedding_cache[key])
-- Initialize essentials
local model = EMBED_MODEL
tiktoken.load(EMBED_TOKENIZER)
local to_process = {}
local results = {}

-- Process each input, using cache when possible
for _, input in ipairs(inputs) do
input.filename = input.filename or 'unknown'
input.filetype = input.filetype or 'text'

if input.content then
local cache_key = input.filename .. utils.quick_hash(input.content)
if self.embedding_cache[cache_key] then
table.insert(results, self.embedding_cache[cache_key])
else
table.insert(uncached_embeddings, embed)
local message = generate_embeddings_request_message(input)
local tokens = tiktoken.count(message)

if tokens <= EMBED_MAX_TOKENS then
input.tokens = tokens
table.insert(to_process, input)
else
log.warn(
string.format(
'Embedding for %s exceeds token limit (%d > %d), skipping',
input.filename,
tokens,
EMBED_MAX_TOKENS
)
)
end
end
else
table.insert(uncached_embeddings, embed)
end
end

opts = opts or {}
local model = opts.model or 'text-embedding-3-small'
local chunk_size = opts.chunk_size or 15
-- Process inputs in batches
while #to_process > 0 do
local batch = {}
local batch_tokens = 0

local out = {}
-- Build batch within token limit
while #to_process > 0 do
local next_input = to_process[1]
if batch_tokens + next_input.tokens > EMBED_MAX_TOKENS then
break
end
table.insert(batch, table.remove(to_process, 1))
batch_tokens = batch_tokens + next_input.tokens
end

for i = 1, #uncached_embeddings, chunk_size do
local chunk = vim.list_slice(uncached_embeddings, i, i + chunk_size - 1)
local body = vim.json.encode(generate_embedding_request(chunk, model))
-- Get embeddings for batch
local body = vim.json.encode(generate_embedding_request(batch, model))
local response, err = utils.curl_post(
'https://api.githubcopilot.com/embeddings',
vim.tbl_extend('force', self.request_args, {
Expand All @@ -899,48 +933,26 @@ function Copilot:embed(inputs, opts)
})
)

if err then
error(err)
return {}
end

if not response then
error('Failed to get response')
return {}
end

if response.status ~= 200 then
error('Failed to get response: ' .. tostring(response.status) .. '\n' .. response.body)
return {}
if err or not response or response.status ~= 200 then
error(err or ('Failed to get embeddings: ' .. (response and response.body or 'no response')))
end

local ok, content = pcall(vim.json.decode, response.body, {
luanil = {
object = true,
array = true,
},
})

local ok, content = pcall(vim.json.decode, response.body)
if not ok then
error('Failed to parse response: ' .. vim.inspect(content) .. '\n' .. response.body)
return {}
error('Failed to parse embedding response: ' .. response.body)
end

-- Process and cache results
for _, embedding in ipairs(content.data) do
table.insert(out, vim.tbl_extend('keep', chunk[embedding.index + 1], embedding))
end
end
local result = vim.tbl_extend('keep', batch[embedding.index + 1], embedding)
table.insert(results, result)

-- Cache embeddings
for _, embedding in ipairs(out) do
if embedding.content then
local key = embedding.filename .. utils.quick_hash(embedding.content)
self.embedding_cache[key] = embedding
local cache_key = result.filename .. utils.quick_hash(result.content)
self.embedding_cache[cache_key] = result
end
end

-- Merge cached embeddings and newly fetched embeddings and return
return vim.list_extend(out, cached_embeddings)
return results
end

--- Stop the running job
Expand Down
11 changes: 5 additions & 6 deletions lua/CopilotChat/tiktoken.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
local async = require('plenary.async')
local log = require('plenary.log')
local utils = require('CopilotChat.utils')
-- tiktoken_core is loaded through pcall so a missing module does not raise
-- at require time.
-- On failure pcall returns `false, <error message>`, so the second value must
-- not be kept as the module: a leftover error string is truthy and would
-- defeat every `if not tiktoken_core` guard, then fail on `tiktoken_core.new`.
local core_ok, tiktoken_core = pcall(require, 'tiktoken_core')
if not core_ok then
  tiktoken_core = nil
end
-- Name of the tokenizer most recently loaded; nil until a load succeeds.
local current_tokenizer = nil
local cache_dir = vim.fn.stdpath('cache')
vim.fn.mkdir(tostring(cache_dir), 'p')
Expand Down Expand Up @@ -31,16 +31,16 @@ local M = {}
--- Load the tiktoken module
---@param tokenizer string The tokenizer to load
M.load = function(tokenizer)
if tokenizer == current_tokenizer then
if not tiktoken_core then
return
end

local ok, core = pcall(require, 'tiktoken_core')
if not ok then
if tokenizer == current_tokenizer then
return
end

local path = load_tiktoken_data(tokenizer)
async.util.scheduler()
local special_tokens = {}
special_tokens['<|endoftext|>'] = 100257
special_tokens['<|fim_prefix|>'] = 100258
Expand All @@ -49,8 +49,7 @@ M.load = function(tokenizer)
special_tokens['<|endofprompt|>'] = 100276
local pat_str =
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
core.new(path, special_tokens, pat_str)
tiktoken_core = core
tiktoken_core.new(path, special_tokens, pat_str)
current_tokenizer = tokenizer
end

Expand Down

0 comments on commit 8584c53

Please sign in to comment.