fix(ai-proxy): fix gemini streaming; fix gemini analytics
tysoekong committed Jul 25, 2024
1 parent 1e2a6f0 commit f499d3f
Showing 8 changed files with 109 additions and 100 deletions.
5 changes: 1 addition & 4 deletions kong/llm/drivers/bedrock.lua
@@ -6,10 +6,8 @@ local fmt = string.format
local ai_shared = require("kong.llm.drivers.shared")
local socket_url = require("socket.url")
local string_gsub = string.gsub
local buffer = require("string.buffer")
local table_insert = table.insert
local string_lower = string.lower
local string_sub = string.sub
local signer = require("resty.aws.request.sign")
--

@@ -122,8 +120,7 @@ local function handle_stream_event(event_t, model_info, route_type)

new_event = "[DONE]"

elseif event_type == "contentBlockStop" then
-- placeholder - I don't think this does anything yet
-- "contentBlockStop" is absent because it is not used for anything here
end

if new_event then
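The net effect in bedrock.lua is that "contentBlockStop" no longer gets an empty branch in the stream-event dispatch. A minimal sketch of the resulting shape, assuming the "[DONE]" translation above is keyed off Bedrock's messageStop event (the triggering branch sits outside this hunk):

    if event_type == "messageStop" then
      new_event = "[DONE]"
    -- "contentBlockStop" needs no branch: it carries no text delta and
    -- nothing downstream consumes it.
    end
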
59 changes: 25 additions & 34 deletions kong/llm/drivers/gemini.lua
@@ -41,30 +41,32 @@ local function is_response_content(content)
and content.candidates[1].content.parts[1].text
end

local function is_response_finished(content)
return content
and content.candidates
and #content.candidates > 0
and content.candidates[1].finishReason
end

local function handle_stream_event(event_t, model_info, route_type)
-- discard empty frames, it should either be a random new line, or comment
if (not event_t.data) or (#event_t.data < 1) then
return
end


if event_t.data == "[DONE]" then
return "[DONE]", nil, nil
end

local event, err = cjson.decode(event_t.data)
if err then
ngx.log(ngx.WARN, "failed to decode stream event frame from gemini: " .. err)
return nil, "failed to decode stream event frame from gemini", nil
end

local new_event
local metadata = nil

if is_response_content(event) then
new_event = {
local metadata = {}
metadata.finished_reason = event.candidates
and #event.candidates > 0
and event.candidates[1].finishReason
or "STOP"
metadata.completion_tokens = event.usageMetadata and event.usageMetadata.candidatesTokenCount or 0
metadata.prompt_tokens = event.usageMetadata and event.usageMetadata.promptTokenCount or 0

local new_event = {
choices = {
[1] = {
delta = {
Expand All @@ -75,28 +77,8 @@ local function handle_stream_event(event_t, model_info, route_type)
},
},
}
end

if is_response_finished(event) then
metadata = metadata or {}
metadata.finished_reason = event.candidates[1].finishReason
new_event = "[DONE]"
end

if event.usageMetadata then
metadata = metadata or {}
metadata.completion_tokens = event.usageMetadata.candidatesTokenCount or 0
metadata.prompt_tokens = event.usageMetadata.promptTokenCount or 0
end

if new_event then
if new_event ~= "[DONE]" then
new_event = cjson.encode(new_event)
end

return new_event, nil, metadata
else
return nil, nil, metadata -- caller code will handle "unrecognised" event types
return cjson.encode(new_event), nil, metadata
end
end
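With this rewrite, every content frame carries its own analytics: finished_reason falls back to "STOP" when the chunk has no finishReason, and the token counts come straight from the chunk's usageMetadata. A minimal sketch of one call, with a hypothetical decoded chunk (handle_stream_event is local to the driver, so this is illustrative rather than a public API):

    local frame = {
      data = '{"candidates":[{"content":{"parts":[{"text":"Bonjour"}],"role":"model"}}],'
          .. '"usageMetadata":{"promptTokenCount":6,"candidatesTokenCount":2}}'
    }

    local event, err, metadata = handle_stream_event(frame, { name = "gemini-1.5-flash" }, "llm/v1/chat")
    -- event    : cjson-encoded OpenAI-style delta frame
    -- metadata : { finished_reason = "STOP", prompt_tokens = 6, completion_tokens = 2 }
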

@@ -206,6 +188,15 @@ local function from_gemini_chat_openai(response, model_info, route_type)
messages.object = "chat.completion"
messages.model = model_info.name

-- process analytics
if response.usageMetadata then
messages.usage = {
prompt_tokens = response.usageMetadata.promptTokenCount,
completion_tokens = response.usageMetadata.candidatesTokenCount,
total_tokens = response.usageMetadata.totalTokenCount,
}
end

else -- probably a server fault or other unexpected response
local err = "no generation candidates received from Gemini, or max_tokens too short"
ngx.log(ngx.ERR, err)
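For non-streamed responses, the analytics fix is the usage block just added: Gemini's usageMetadata is copied onto the OpenAI-style usage object. The same mapping as a stand-alone sketch (map_gemini_usage is a hypothetical name, not part of the driver):

    local function map_gemini_usage(usage_metadata)
      -- mirrors the assignment added to from_gemini_chat_openai
      return usage_metadata and {
        prompt_tokens     = usage_metadata.promptTokenCount,
        completion_tokens = usage_metadata.candidatesTokenCount,
        total_tokens      = usage_metadata.totalTokenCount,
      } or nil
    end

Applied to the fixture at the bottom of this commit, { promptTokenCount = 14, candidatesTokenCount = 128, totalTokenCount = 142 } becomes { prompt_tokens = 14, completion_tokens = 128, total_tokens = 142 }.
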
21 changes: 18 additions & 3 deletions kong/llm/drivers/shared.lua
@@ -19,6 +19,10 @@ local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy
local function str_ltrim(s) -- remove leading whitespace from string.
return type(s) == "string" and s:gsub("^%s*", "")
end

local function str_rtrim(s) -- remove trailing whitespace from string.
return type(s) == "string" and s:match('^(.*%S)%s*$')
end
--

local log_entry_keys = {
@@ -249,20 +253,31 @@ function _M.frame_to_events(frame, provider)
-- some new LLMs return the JSON object-by-object,
-- because that totally makes sense to parse?!
if provider == "gemini" then
local done = false

-- if this is the first frame, it will begin with array opener '['
frame = (string.sub(str_ltrim(frame), 1, 1) == "[" and string.sub(str_ltrim(frame), 2)) or frame

-- it may start with ',' which is the start of the new frame
frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame

-- finally, it may end with the array terminator ']' indicating the finished stream
frame = (string.sub(str_ltrim(frame), -1) == "]" and string.sub(str_ltrim(frame), 1, -2)) or frame
-- it may end with the array terminator ']' indicating the finished stream
if string.sub(str_rtrim(frame), -1) == "]" then
frame = string.sub(str_rtrim(frame), 1, -2)
done = true
end

-- for multiple events that arrive in the same frame, split by top-level comma
for _, v in ipairs(split(frame, "\n,")) do
events[#events+1] = { data = v }
end

if done then
-- add the done signal here
-- but we have to retrieve the metadata from a previous filter run
events[#events+1] = { data = "[DONE]" }
end

elseif provider == "bedrock" then
local parser = aws_stream:new(frame)
while true do
@@ -609,7 +624,7 @@ function _M.post_request(conf, response_object)
end

if response_object.usage.prompt_tokens and response_object.usage.completion_tokens
and conf.model.options.input_cost and conf.model.options.output_cost then
and conf.model.options and conf.model.options.input_cost and conf.model.options.output_cost then
request_analytics_plugin[log_entry_keys.USAGE_CONTAINER][log_entry_keys.COST] =
(response_object.usage.prompt_tokens * conf.model.options.input_cost
+ response_object.usage.completion_tokens * conf.model.options.output_cost) / 1000000 -- 1 million
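The added `conf.model.options` check guards against indexing a nil options table when no pricing is configured. A worked example of the formula, with hypothetical per-million-token prices:

    local usage   = { prompt_tokens = 14, completion_tokens = 128 }
    local options = { input_cost = 0.35, output_cost = 1.05 } -- per 1M tokens (hypothetical)

    local cost
    if usage.prompt_tokens and usage.completion_tokens
       and options and options.input_cost and options.output_cost then
      cost = (usage.prompt_tokens * options.input_cost
              + usage.completion_tokens * options.output_cost) / 1000000
    end
    -- cost == (14 * 0.35 + 128 * 1.05) / 1e6 == 0.0001393
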
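Stepping back to the frame_to_events hunk above: Gemini streams one JSON array for the entire response, so an individual HTTP chunk may open with '[', start with a leading ',', or close with ']'. The new str_rtrim fixes the terminator check — the old code left-trimmed before testing the last character, so a chunk ending in ']' plus trailing whitespace was never recognised — and the old code never emitted an explicit end-of-stream event at all; the new done flag adds a synthetic "[DONE]" so downstream filters see a conventional marker. A self-contained approximation of the new logic (the plugin itself uses Kong's split helper; all names here are local to the sketch, and this rtrim returns an empty string for all-whitespace input):

    local function rtrim(s) return (s:gsub("%s+$", "")) end

    local function gemini_frame_to_events(frame)
      local events, done = {}, false

      frame = frame:gsub("^%s*%[", "")       -- first chunk: drop the array opener
      frame = frame:gsub("^%s*,", "")        -- later chunks: drop the leading comma
      if rtrim(frame):sub(-1) == "]" then    -- last chunk: drop the terminator
        frame = rtrim(frame):sub(1, -2)
        done = true
      end

      for obj in (frame .. "\n,"):gmatch("(.-)\n,") do  -- split objects on the "\n," boundary
        events[#events + 1] = { data = obj }
      end

      if done then
        events[#events + 1] = { data = "[DONE]" }  -- synthetic end-of-stream marker
      end

      return events
    end
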
3 changes: 0 additions & 3 deletions kong/llm/init.lua
@@ -137,9 +137,6 @@ do
}
end

-- needed for some drivers later
self.conf.model.source = "transformer-plugins"

-- convert it to the specified driver format
ai_request, _, err = self.driver.to_format(ai_request, self.conf.model, "llm/v1/chat")
if err then
24 changes: 15 additions & 9 deletions kong/plugins/ai-proxy/handler.lua
@@ -5,7 +5,6 @@ local kong_utils = require("kong.tools.gzip")
local kong_meta = require("kong.meta")
local buffer = require "string.buffer"
local strip = require("kong.tools.utils").strip
local to_hex = require("resty.string").to_hex

-- cloud auth/sdk providers
local GCP_SERVICE_ACCOUNT do
@@ -230,14 +229,21 @@ local function handle_streaming_frame(conf)
end

if conf.logging and conf.logging.log_statistics and metadata then
kong_ctx_plugin.ai_stream_completion_tokens =
(kong_ctx_plugin.ai_stream_completion_tokens or 0) +
(metadata.completion_tokens or 0)
or kong_ctx_plugin.ai_stream_completion_tokens
kong_ctx_plugin.ai_stream_prompt_tokens =
(kong_ctx_plugin.ai_stream_prompt_tokens or 0) +
(metadata.prompt_tokens or 0)
or kong_ctx_plugin.ai_stream_prompt_tokens
-- gemini metadata specifically, works differently
if conf.model.provider == "gemini" then
kong_ctx_plugin.ai_stream_completion_tokens = metadata.completion_tokens or 0
kong_ctx_plugin.ai_stream_prompt_tokens = metadata.prompt_tokens or 0
else
kong_ctx_plugin.ai_stream_completion_tokens =
(kong_ctx_plugin.ai_stream_completion_tokens or 0) +
(metadata.completion_tokens or 0)
or kong_ctx_plugin.ai_stream_completion_tokens
kong_ctx_plugin.ai_stream_prompt_tokens =
(kong_ctx_plugin.ai_stream_prompt_tokens or 0) +
(metadata.prompt_tokens or 0)
or kong_ctx_plugin.ai_stream_prompt_tokens
end
end
end
end
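The provider branch exists because the Gemini driver now attaches totals taken from each chunk's usageMetadata, so the handler keeps the most recent value instead of summing; the other drivers keep reporting per-frame deltas that must be accumulated. A sketch of the difference (frames and provider stand in for the plugin's per-request state; values hypothetical):

    -- completion-token metadata across three frames:
    --   gemini             : 2, 5, 9  -- running totals; keep the last -> 9
    --   delta-style driver : 2, 3, 4  -- per-frame deltas; sum them    -> 9
    local total = 0
    for _, m in ipairs(frames) do
      if provider == "gemini" then
        total = m.completion_tokens or 0             -- overwrite
      else
        total = total + (m.completion_tokens or 0)   -- accumulate
      end
    end
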
27 changes: 11 additions & 16 deletions kong/tools/aws_stream.lua
@@ -98,7 +98,7 @@ function Stream:next_int(size)
return nil, nil, "cannot work on integers smaller than 8 bits long"
end

local int, err = self:next_bytes(size / 8, trim)
local int, err = self:next_bytes(size / 8)
if err then
return nil, nil, err
end
@@ -123,22 +123,20 @@ function Stream:next_message()
-- this is a chicken and egg problem, because we need to
-- read the message to get the length, to then re-read the
-- whole message at correct offset
local msg_len, orig_len, err = self:next_int(32)
local msg_len, _, err = self:next_int(32)
if err then
return err
end

-- get the headers length
local headers_len, orig_headers_len, err = self:next_int(32)
local headers_len, _, err = self:next_int(32)
if err then
return err
end

-- get the preamble checksum
local preamble_checksum, orig_preamble_checksum, err = self:next_int(32)

-- TODO: calculate checksum
-- local result = crc32(orig_len .. origin_headers_len, preamble_checksum)
-- if not result then
-- return nil, "preamble checksum failed - message is corrupted"
-- end
-- skip it because we're not using UDP
self:next_int(32)

-- pull the headers from the buf
local headers = {}
Expand Down Expand Up @@ -167,12 +165,9 @@ function Stream:next_message()
local body = self:next_utf_8(msg_len - self.read_count - 4)

-- last 4 bytes is a body checksum
local msg_checksum = self:next_int(32)
-- TODO CHECK FULL MESSAGE CHECKSUM
-- local result = crc32(original_full_msg, msg_checksum)
-- if not result then
-- return nil, "preamble checksum failed - message is corrupted"
-- end
-- skip it because we're not using UDP
self:next_int(32)


-- rewind the tape
self.read_count = 0
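For reference, the byte layout that next_message() walks, as implied by the reads in this file (the framing follows AWS's binary event-stream encoding; both CRC32 fields are now read and discarded rather than verified):

    -- [ total length   : 4 bytes ]  self:next_int(32)
    -- [ headers length : 4 bytes ]  self:next_int(32)
    -- [ prelude CRC32  : 4 bytes ]  read and skipped
    -- [ headers        : headers-length bytes ]
    -- [ payload        : total - bytes-read-so-far - 4 bytes ]
    -- [ message CRC32  : 4 bytes ]  read and skipped
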
67 changes: 36 additions & 31 deletions spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json
@@ -1,34 +1,39 @@
{
  "candidates": [
    {
      "content": {
        "parts": [
          {
            "text": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n"
          }
        ],
        "role": "model"
      },
      "finishReason": "STOP",
      "index": 0,
      "safetyRatings": [
        {
          "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_HATE_SPEECH",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_HARASSMENT",
          "probability": "NEGLIGIBLE"
        },
        {
          "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
          "probability": "NEGLIGIBLE"
        }
      ]
    }
  ],
  "usageMetadata": {
    "promptTokenCount": 14,
    "candidatesTokenCount": 128,
    "totalTokenCount": 142
  }
}
@@ -4,5 +4,8 @@
},
{
"data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" not a limit.\\n\\nIf you're interested in learning more about relativity, I encourage you to explore further resources online or in books. There are many excellent introductory materials available. \\n\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 547,\n \"totalTokenCount\": 553\n }\n}\n"
},
{
"data": "[DONE]"
}
]
