From f499d3fe32f5bc82c2d2ac4a5b08c4e2e7cb536a Mon Sep 17 00:00:00 2001
From: Jack Tysoe
Date: Wed, 24 Jul 2024 23:52:38 +0100
Subject: [PATCH] fix(ai-proxy): fix gemini streaming; fix gemini analytics

---
 kong/llm/drivers/bedrock.lua                  |  5 +-
 kong/llm/drivers/gemini.lua                   | 59 +++++++---------
 kong/llm/drivers/shared.lua                   | 21 +++++-
 kong/llm/init.lua                             |  3 -
 kong/plugins/ai-proxy/handler.lua             | 24 ++++---
 kong/tools/aws_stream.lua                     | 27 +++----
 .../real-responses/gemini/llm-v1-chat.json    | 67 ++++++++++---------
 .../partial-json-end/expected-output.json     |  3 +
 8 files changed, 109 insertions(+), 100 deletions(-)

diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua
index ae981f87442d..21690fa32f54 100644
--- a/kong/llm/drivers/bedrock.lua
+++ b/kong/llm/drivers/bedrock.lua
@@ -6,10 +6,8 @@ local fmt = string.format
 local ai_shared = require("kong.llm.drivers.shared")
 local socket_url = require("socket.url")
 local string_gsub = string.gsub
-local buffer = require("string.buffer")
 local table_insert = table.insert
 local string_lower = string.lower
-local string_sub = string.sub
 local signer = require("resty.aws.request.sign")
 
 --
@@ -122,8 +120,7 @@ local function handle_stream_event(event_t, model_info, route_type)
     new_event = "[DONE]"
 
-  elseif event_type == "contentBlockStop" then
-    -- placeholder - I don't think this does anything yet
+  -- "contentBlockStop" is intentionally not handled; it is not used for anything here
   end
 
   if new_event then
diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua
index 59296ee9160b..f76488dcb19c 100644
--- a/kong/llm/drivers/gemini.lua
+++ b/kong/llm/drivers/gemini.lua
@@ -41,30 +41,32 @@ local function is_response_content(content)
     and content.candidates[1].content.parts[1].text
 end
 
-local function is_response_finished(content)
-  return content
-      and content.candidates
-      and #content.candidates > 0
-      and content.candidates[1].finishReason
-end
-
-local function handle_stream_event(event_t, model_info, route_type)
+local function handle_stream_event(event_t, model_info, route_type)
   -- discard empty frames, it should either be a random new line, or comment
   if (not event_t.data) or (#event_t.data < 1) then
     return
   end
-
+
+  if event_t.data == "[DONE]" then
+    return "[DONE]", nil, nil
+  end
+
   local event, err = cjson.decode(event_t.data)
   if err then
     ngx.log(ngx.WARN, "failed to decode stream event frame from gemini: " .. err)
     return nil, "failed to decode stream event frame from gemini", nil
   end
 
-  local new_event
-  local metadata = nil
-
   if is_response_content(event) then
-    new_event = {
+    local metadata = {}
+    metadata.finished_reason = event.candidates
+        and #event.candidates > 0
+        and event.candidates[1].finishReason
+        or "STOP"
+    metadata.completion_tokens = event.usageMetadata and event.usageMetadata.candidatesTokenCount or 0
+    metadata.prompt_tokens = event.usageMetadata and event.usageMetadata.promptTokenCount or 0
+
+    local new_event = {
       choices = {
         [1] = {
           delta = {
@@ -75,28 +77,8 @@ local function handle_stream_event(event_t, model_info, route_type)
         },
       },
     }
-  end
 
-  if is_response_finished(event) then
-    metadata = metadata or {}
-    metadata.finished_reason = event.candidates[1].finishReason
-    new_event = "[DONE]"
-  end
-
-  if event.usageMetadata then
-    metadata = metadata or {}
-    metadata.completion_tokens = event.usageMetadata.candidatesTokenCount or 0
-    metadata.prompt_tokens = event.usageMetadata.promptTokenCount or 0
-  end
-
-  if new_event then
-    if new_event ~= "[DONE]" then
-      new_event = cjson.encode(new_event)
-    end
-
-    return new_event, nil, metadata
-  else
-    return nil, nil, metadata  -- caller code will handle "unrecognised" event types
+    return cjson.encode(new_event), nil, metadata
   end
 end
@@ -206,6 +188,15 @@ local function from_gemini_chat_openai(response, model_info, route_type)
     messages.object = "chat.completion"
     messages.model = model_info.name
 
+    -- process analytics
+    if response.usageMetadata then
+      messages.usage = {
+        prompt_tokens = response.usageMetadata.promptTokenCount,
+        completion_tokens = response.usageMetadata.candidatesTokenCount,
+        total_tokens = response.usageMetadata.totalTokenCount,
+      }
+    end
+
   else -- probably a server fault or other unexpected response
     local err = "no generation candidates received from Gemini, or max_tokens too short"
     ngx.log(ngx.ERR, err)
diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index 2336ffbee0c9..6b3714ca72b0 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -19,6 +19,10 @@ local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy
 local function str_ltrim(s) -- remove leading whitespace from string.
   return type(s) == "string" and s:gsub("^%s*", "")
 end
+
+local function str_rtrim(s) -- remove trailing whitespace from string.
+  return type(s) == "string" and s:match('^(.*%S)%s*$')
+end
 --
 
 local log_entry_keys = {
@@ -249,20 +253,31 @@ function _M.frame_to_events(frame, provider)
   -- some new LLMs return the JSON object-by-object,
   -- because that totally makes sense to parse?!
   if provider == "gemini" then
+    local done = false
+
     -- if this is the first frame, it will begin with array opener '['
     frame = (string.sub(str_ltrim(frame), 1, 1) == "[" and string.sub(str_ltrim(frame), 2)) or frame
 
     -- it may start with ',' which is the start of the new frame
     frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame
 
-    -- finally, it may end with the array terminator ']' indicating the finished stream
-    frame = (string.sub(str_ltrim(frame), -1) == "]" and string.sub(str_ltrim(frame), 1, -2)) or frame
+    -- it may end with the array terminator ']' indicating the finished stream
+    if string.sub(str_rtrim(frame), -1) == "]" then
+      frame = string.sub(str_rtrim(frame), 1, -2)
+      done = true
+    end
 
     -- for multiple events that arrive in the same frame, split by top-level comma
     for _, v in ipairs(split(frame, "\n,")) do
       events[#events+1] = { data = v }
     end
 
+    if done then
+      -- add the done signal here
+      -- but we have to retrieve the metadata from a previous filter run
+      events[#events+1] = { data = "[DONE]" }
+    end
+
   elseif provider == "bedrock" then
     local parser = aws_stream:new(frame)
     while true do
@@ -609,7 +624,7 @@ function _M.post_request(conf, response_object)
   end
 
   if response_object.usage.prompt_tokens and response_object.usage.completion_tokens
-      and conf.model.options.input_cost and conf.model.options.output_cost then
+      and conf.model.options and conf.model.options.input_cost and conf.model.options.output_cost then
     request_analytics_plugin[log_entry_keys.USAGE_CONTAINER][log_entry_keys.COST] =
       (response_object.usage.prompt_tokens * conf.model.options.input_cost +
         response_object.usage.completion_tokens * conf.model.options.output_cost) / 1000000  -- 1 million
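For reviewers: Gemini delivers the whole streamed response as one large JSON array, chunked at arbitrary points, so `frame_to_events` strips the array punctuation itself and now flags the closing `]`. Roughly, for a hypothetical final chunk:

```lua
-- a hypothetical closing chunk of the Gemini array stream
local frame = ',\n{"candidates": [{"content": {"parts": [{"text": "bye"}]}}]}\n]'

-- frame_to_events(frame, "gemini") trims the leading ',' and the trailing ']'
-- and returns two events: the JSON object, then the synthetic terminator:
--   { data = '\n{"candidates": [{"content": {"parts": [{"text": "bye"}]}}]}\n' }
--   { data = "[DONE]" }
```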
diff --git a/kong/llm/init.lua b/kong/llm/init.lua
index 5a3454d293db..266f5e355a5c 100644
--- a/kong/llm/init.lua
+++ b/kong/llm/init.lua
@@ -137,9 +137,6 @@ do
         }
       end
 
-      -- needed for some drivers later
-      self.conf.model.source = "transformer-plugins"
-
       -- convert it to the specified driver format
       ai_request, _, err = self.driver.to_format(ai_request, self.conf.model, "llm/v1/chat")
       if err then
diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua
index da3cb4f77ab8..fccb66545f8a 100644
--- a/kong/plugins/ai-proxy/handler.lua
+++ b/kong/plugins/ai-proxy/handler.lua
@@ -5,7 +5,6 @@ local kong_utils = require("kong.tools.gzip")
 local kong_meta = require("kong.meta")
 local buffer = require "string.buffer"
 local strip = require("kong.tools.utils").strip
-local to_hex = require("resty.string").to_hex
 
 -- cloud auth/sdk providers
 local GCP_SERVICE_ACCOUNT do
@@ -230,14 +229,21 @@ local function handle_streaming_frame(conf)
       end
 
       if conf.logging and conf.logging.log_statistics and metadata then
-        kong_ctx_plugin.ai_stream_completion_tokens =
-          (kong_ctx_plugin.ai_stream_completion_tokens or 0) +
-          (metadata.completion_tokens or 0)
-          or kong_ctx_plugin.ai_stream_completion_tokens
-        kong_ctx_plugin.ai_stream_prompt_tokens =
-          (kong_ctx_plugin.ai_stream_prompt_tokens or 0) +
-          (metadata.prompt_tokens or 0)
-          or kong_ctx_plugin.ai_stream_prompt_tokens
+        -- gemini metadata works differently: it reports cumulative
+        -- token counts per frame, so overwrite instead of accumulating
+        if conf.model.provider == "gemini" then
+          kong_ctx_plugin.ai_stream_completion_tokens = metadata.completion_tokens or 0
+          kong_ctx_plugin.ai_stream_prompt_tokens = metadata.prompt_tokens or 0
+        else
+          kong_ctx_plugin.ai_stream_completion_tokens =
+            (kong_ctx_plugin.ai_stream_completion_tokens or 0) +
+            (metadata.completion_tokens or 0)
+            or kong_ctx_plugin.ai_stream_completion_tokens
+          kong_ctx_plugin.ai_stream_prompt_tokens =
+            (kong_ctx_plugin.ai_stream_prompt_tokens or 0) +
+            (metadata.prompt_tokens or 0)
+            or kong_ctx_plugin.ai_stream_prompt_tokens
+        end
       end
     end
   end
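For reviewers: the provider split above exists because Gemini's `usageMetadata` counters are running totals for the stream so far, while the delta-style providers report per-frame increments; overwriting with the latest Gemini frame and summing the deltas both land on the same final figure. A toy illustration with made-up counts:

```lua
-- made-up completion-token counts for the same three-frame stream
local cumulative = { 10, 20, 30 }  -- Gemini: each frame carries the running total
local deltas     = { 10, 10, 10 }  -- delta-style: each frame carries its increment

local g, d = 0, 0
for i = 1, 3 do
  g = cumulative[i]  -- overwrite, as in the "gemini" branch
  d = d + deltas[i]  -- accumulate, as in the else branch
end
assert(g == 30 and d == 30)  -- both accountings agree at end of stream
```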
diff --git a/kong/tools/aws_stream.lua b/kong/tools/aws_stream.lua
index cee1c9ed4be0..ebefc2c26566 100644
--- a/kong/tools/aws_stream.lua
+++ b/kong/tools/aws_stream.lua
@@ -98,7 +98,7 @@ function Stream:next_int(size)
     return nil, nil, "cannot work on integers smaller than 8 bits long"
   end
 
-  local int, err = self:next_bytes(size / 8, trim)
+  local int, err = self:next_bytes(size / 8)
   if err then
     return nil, nil, err
   end
@@ -123,22 +123,20 @@ function Stream:next_message()
   -- this is a chicken and egg problem, because we need to
   -- read the message to get the length, to then re-read the
   -- whole message at correct offset
-  local msg_len, orig_len, err = self:next_int(32)
+  local msg_len, _, err = self:next_int(32)
   if err then
     return err
   end
 
   -- get the headers length
-  local headers_len, orig_headers_len, err = self:next_int(32)
+  local headers_len, _, err = self:next_int(32)
+  if err then
+    return err
+  end
 
   -- get the preamble checksum
-  local preamble_checksum, orig_preamble_checksum, err = self:next_int(32)
-
-  -- TODO: calculate checksum
-  -- local result = crc32(orig_len .. origin_headers_len, preamble_checksum)
-  -- if not result then
-  --   return nil, "preamble checksum failed - message is corrupted"
-  -- end
+  -- skip it; checksum validation is not implemented here
+  self:next_int(32)
 
   -- pull the headers from the buf
   local headers = {}
@@ -167,12 +165,9 @@ function Stream:next_message()
   local body = self:next_utf_8(msg_len - self.read_count - 4)
 
   -- last 4 bytes is a body checksum
-  local msg_checksum = self:next_int(32)
-  -- TODO CHECK FULL MESSAGE CHECKSUM
-  -- local result = crc32(original_full_msg, msg_checksum)
-  -- if not result then
-  --   return nil, "preamble checksum failed - message is corrupted"
-  -- end
+  -- skip it; checksum validation is not implemented here
+  self:next_int(32)
+
   -- rewind the tape
   self.read_count = 0
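For reviewers: the two discarded `next_int(32)` reads are the CRC32 fields of the AWS event-stream envelope; the patch reads past them without validating. For orientation, the wire layout `next_message` walks (per the AWS event-stream encoding):

```lua
-- AWS event-stream message layout, as consumed by Stream:next_message():
--
--   total length   : uint32  -> msg_len
--   headers length : uint32  -> headers_len
--   prelude CRC32  : uint32  -> read and discarded (no validation yet)
--   headers        : headers_len bytes
--   payload        : msg_len - read_count - 4 bytes
--   message CRC32  : uint32  -> read and discarded (no validation yet)
```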
\n" } - ] - } - ] - } \ No newline at end of file + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 14, + "candidatesTokenCount": 128, + "totalTokenCount": 142 + } +} diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json index ba6a64384d95..f35aaf6f9dba 100644 --- a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json @@ -4,5 +4,8 @@ }, { "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" not a limit.\\n\\nIf you're interested in learning more about relativity, I encourage you to explore further resources online or in books. There are many excellent introductory materials available. \\n\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 547,\n \"totalTokenCount\": 553\n }\n}\n" + }, + { + "data": "[DONE]" } ] \ No newline at end of file