[backport -> release/3.9.x] fix: [AG-178] AI Gateway bugs, 3.9.0 rollup #13949

Merged

10 commits merged on Nov 29, 2024
27 changes: 14 additions & 13 deletions kong/llm/drivers/gemini.lua
@@ -268,22 +268,23 @@ local function from_gemini_chat_openai(response, model_info, route_type)
     messages.model = model_info.name

   elseif is_tool_content(response) then
+    messages.choices[1] = {
+      index = 0,
+      message = {
+        role = "assistant",
+        tool_calls = {},
+      },
+    }
+
     local function_call_responses = response.candidates[1].content.parts
     for i, v in ipairs(function_call_responses) do
-      messages.choices[i] = {
-        index = 0,
-        message = {
-          role = "assistant",
-          tool_calls = {
-            {
-              ['function'] = {
-                name = v.functionCall.name,
-                arguments = cjson.encode(v.functionCall.args),
-              },
-            },
-          },
-        },
-      }
+      messages.choices[1].message.tool_calls[i] =
+        {
+          ['function'] = {
+            name = v.functionCall.name,
+            arguments = cjson.encode(v.functionCall.args),
+          },
+        }
     end
   end
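For illustration only (the function names and arguments below are invented, not taken from this PR): with this change, a Gemini candidate carrying two functionCall parts is normalized into a single OpenAI-style choice holding both tool calls, instead of one choice per call, each stuck at index 0.

-- hypothetical normalized output for a response with two functionCall parts
local normalized = {
  choices = {
    {
      index = 0,
      message = {
        role = "assistant",
        tool_calls = {
          { ["function"] = { name = "get_weather",  arguments = '{"city":"Berlin"}' } },
          { ["function"] = { name = "get_forecast", arguments = '{"city":"Berlin","days":3}' } },
        },
      },
    },
  },
}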

11 changes: 10 additions & 1 deletion kong/llm/drivers/shared.lua
@@ -78,6 +78,12 @@ _M._CONST = {
   ["AWS_STREAM_CONTENT_TYPE"] = "application/vnd.amazon.eventstream",
 }

+_M._SUPPORTED_STREAMING_CONTENT_TYPES = {
+  ["text/event-stream"] = true,
+  ["application/vnd.amazon.eventstream"] = true,
+  ["application/json"] = true,
+}
+
 _M.streaming_has_token_counts = {
   ["cohere"] = true,
   ["llama2"] = true,
@@ -719,7 +725,10 @@ function _M.post_request(conf, response_object)
     meta_container[log_entry_keys.LLM_LATENCY] = llm_latency

     if response_object.usage and response_object.usage.completion_tokens then
-      local time_per_token = math.floor(llm_latency / response_object.usage.completion_tokens)
+      local time_per_token = 0
+      if response_object.usage.completion_tokens > 0 then
+        time_per_token = math.floor(llm_latency / response_object.usage.completion_tokens)
+      end
       request_analytics_plugin[log_entry_keys.USAGE_CONTAINER][log_entry_keys.TIME_PER_TOKEN] = time_per_token
     end
   end
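A minimal standalone sketch of the guard above (safe_time_per_token is an illustrative helper, not part of the codebase), assuming llm_latency is in milliseconds: a response that reports zero completion tokens no longer yields an inf/NaN time_per_token.

local function safe_time_per_token(llm_latency, completion_tokens)
  if completion_tokens and completion_tokens > 0 then
    return math.floor(llm_latency / completion_tokens)
  end
  return 0
end

print(safe_time_per_token(1200, 60)) --> 20
print(safe_time_per_token(1200, 0))  --> 0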
9 changes: 8 additions & 1 deletion kong/llm/plugin/base.lua
@@ -19,6 +19,13 @@ local STAGES = {
   RES_POST_PROCESSING = 7,
 }

+-- Filters in these stages are allowed to execute more than once in a request
+-- TODO: implement singleton support, so that in one iteration of body_filter each
+-- filter runs only once. This is not an issue today as they are only used in one plugin.
+local REPEATED_PHASES = {
+  [STAGES.STREAMING] = true,
+}
+
 local MetaPlugin = {}

 local all_filters = {}
@@ -38,7 +45,7 @@ local function run_stage(stage, sub_plugin, conf)
     if not f then
       kong.log.err("no filter named '" .. name .. "' registered")

-    elseif not ai_executed_filters[name] then
+    elseif not ai_executed_filters[name] or REPEATED_PHASES[stage] then
       ai_executed_filters[name] = true

       kong.log.debug("executing filter ", name)
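A rough sketch of the gating logic above (the stage and filter names are invented, and this is not the actual plugin runner): filters normally run at most once per request, but filters in a repeated phase such as STREAMING may now run on every body_filter iteration.

local REPEATED_PHASES = { streaming = true }
local executed = {}

local function should_run(stage, name)
  if not executed[name] or REPEATED_PHASES[stage] then
    executed[name] = true
    return true
  end
  return false
end

print(should_run("rewrite", "normalize-request"))      --> true
print(should_run("rewrite", "normalize-request"))      --> false
print(should_run("streaming", "normalize-sse-chunk"))  --> true
print(should_run("streaming", "normalize-sse-chunk"))  --> true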
8 changes: 6 additions & 2 deletions kong/llm/plugin/observability.lua
@@ -70,7 +70,11 @@ function _M.metrics_get(key)
   -- process automatic calculation
   if not metrics[key] then
     if key == "llm_tpot_latency" then
-      return math.floor(_M.metrics_get("llm_e2e_latency") / _M.metrics_get("llm_completion_tokens_count"))
+      local llm_completion_tokens_count = _M.metrics_get("llm_completion_tokens_count")
+      if llm_completion_tokens_count > 0 then
+        return _M.metrics_get("llm_e2e_latency") / llm_completion_tokens_count
+      end
+      return 0
     elseif key == "llm_total_tokens_count" then
       return _M.metrics_get("llm_prompt_tokens_count") + _M.metrics_get("llm_completion_tokens_count")
     end
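Worked example with invented numbers: for llm_e2e_latency = 1500 ms and llm_completion_tokens_count = 50, llm_tpot_latency is 1500 / 50 = 30 ms per output token; with 0 completion tokens it now returns 0 instead of an inf/NaN value. Note the ratio on this path is no longer passed through math.floor.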
@@ -102,4 +106,4 @@ function _M.record_request_end()
   return latency
 end

-return _M
\ No newline at end of file
+return _M
18 changes: 18 additions & 0 deletions kong/llm/plugin/shared-filters/normalize-json-response.lua
@@ -38,6 +38,24 @@ local function transform_body(conf)
     response_body = cjson.encode({ error = { message = err }})
   end

+  -- TODO: avoid json encode and decode when transforming
+  -- deduplicate body usage parsing from parse-json-response
+  local t, err
+  if response_body then
+    t, err = cjson.decode(response_body)
+    if err then
+      kong.log.warn("failed to decode response body for usage introspection: ", err)
+    end
+
+    if t and t.usage and t.usage.prompt_tokens then
+      ai_plugin_o11y.metrics_set("llm_prompt_tokens_count", t.usage.prompt_tokens)
+    end
+
+    if t and t.usage and t.usage.completion_tokens then
+      ai_plugin_o11y.metrics_set("llm_completion_tokens_count", t.usage.completion_tokens)
+    end
+  end
+
   set_global_ctx("response_body", response_body) -- to be sent out later or consumed by other plugins
 end
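A standalone sketch of the usage introspection added above, assuming an OpenAI-style usage object and cjson.safe semantics (decode returns nil plus an error string instead of raising); the sample body is made up.

local cjson = require "cjson.safe"

local response_body = '{"usage":{"prompt_tokens":9,"completion_tokens":12}}'
local t, err = cjson.decode(response_body)
if err then
  print("failed to decode response body: ", err)
elseif t.usage then
  print(t.usage.prompt_tokens, t.usage.completion_tokens)  --> 9    12
end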

23 changes: 16 additions & 7 deletions kong/llm/plugin/shared-filters/normalize-sse-chunk.lua
@@ -73,7 +73,7 @@ local function handle_streaming_frame(conf, chunk, finished)
     -- how do we know if this is false but some other filter will need the body?
     if conf.logging and conf.logging.log_payloads and not body_buffer then
       body_buffer = buffer.new()
-      set_global_ctx("sse_body_buffer", buffer)
+      set_global_ctx("sse_body_buffer", body_buffer)
     else
       kong.log.debug("using existing body buffer created by: ", source)
end
Expand All @@ -83,6 +83,8 @@ local function handle_streaming_frame(conf, chunk, finished)


for _, event in ipairs(events) do
-- TODO: currently only subset of driver follow the body, err, metadata pattern
-- unify this so that it was always extracted from the body
local formatted, _, metadata = ai_driver.from_format(event, conf.model, "stream/" .. conf.route_type)

if formatted then
@@ -106,12 +108,6 @@
           if body_buffer then
             body_buffer:put(token_t)
           end
-
-          -- incredibly loose estimate based on https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
-          -- but this is all we can do until OpenAI fixes this...
-          --
-          -- essentially, every 4 characters is a token, with minimum of 1*4 per event
-          ai_plugin_o11y.metrics_add("llm_completion_tokens_count", math.ceil(#strip(token_t) / 4))
         end
       end

@@ -141,6 +137,19 @@

     local prompt_tokens_count = ai_plugin_o11y.metrics_get("llm_prompt_tokens_count")
     local completion_tokens_count = ai_plugin_o11y.metrics_get("llm_completion_tokens_count")
+
+    if conf.logging and conf.logging.log_statistics then
+      -- no metadata populated in the event streams, so do our own estimation
+      if completion_tokens_count == 0 then
+        -- incredibly loose estimate based on https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
+        -- but this is all we can do until OpenAI fixes this...
+        --
+        -- essentially, every 4 characters is a token, with a minimum of 1*4 per event
+        completion_tokens_count = math.ceil(#strip(response) / 4)
+        ai_plugin_o11y.metrics_set("llm_completion_tokens_count", completion_tokens_count)
+      end
+    end
+
     -- populate cost
     if conf.model.options and conf.model.options.input_cost and conf.model.options.output_cost then
       local cost = (prompt_tokens_count * conf.model.options.input_cost +
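A standalone sketch of the end-of-stream fallback above (the helper name and sample text are invented): when no completion token count was reported by the provider, the accumulated response text is estimated at roughly 4 characters per token.

local function estimate_completion_tokens(response_text, reported_count)
  if reported_count and reported_count > 0 then
    return reported_count
  end
  -- ~4 characters per token, rounded up
  return math.ceil(#response_text / 4)
end

print(estimate_completion_tokens("Hello, how can I help you today?", 0))  --> 8
print(estimate_completion_tokens("", 137))                                --> 137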
2 changes: 1 addition & 1 deletion kong/llm/plugin/shared-filters/parse-json-response.lua
@@ -46,4 +46,4 @@ function _M:run(_)
   return true
 end

-return _M
\ No newline at end of file
+return _M
2 changes: 1 addition & 1 deletion kong/llm/plugin/shared-filters/parse-sse-chunk.lua
@@ -20,7 +20,7 @@ local function handle_streaming_frame(conf, chunk, finished)

   local content_type = kong.service.response.get_header("Content-Type")
   local normalized_content_type = content_type and content_type:sub(1, (content_type:find(";") or 0) - 1)
-  if normalized_content_type and normalized_content_type ~= "text/event-stream" and normalized_content_type ~= ai_shared._CONST.AWS_STREAM_CONTENT_TYPE then
+  if normalized_content_type and (not ai_shared._SUPPORTED_STREAMING_CONTENT_TYPES[normalized_content_type]) then
     return true
   end
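Sketch of the Content-Type handling above: parameters after ";" are stripped, then the normalized value is checked against the shared allow-list, which now also accepts application/json (presumably for providers that stream plain JSON bodies rather than SSE). The sample values are illustrative.

local SUPPORTED = {
  ["text/event-stream"] = true,
  ["application/vnd.amazon.eventstream"] = true,
  ["application/json"] = true,
}

local function normalize(content_type)
  return content_type and content_type:sub(1, (content_type:find(";") or 0) - 1)
end

print(SUPPORTED[normalize("text/event-stream; charset=utf-8")] or false)  --> true
print(SUPPORTED[normalize("application/json")] or false)                  --> true
print(SUPPORTED[normalize("text/html")] or false)                         --> false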

2 changes: 2 additions & 0 deletions kong/llm/plugin/shared-filters/serialize-analytics.lua
@@ -66,6 +66,8 @@ function _M:run(conf)
     cost = ai_plugin_o11y.metrics_get("llm_usage_cost"),
   }

+  kong.log.inspect(usage)
+
   kong.log.set_serialize_value(string.format("ai.%s.usage", ai_plugin_o11y.NAMESPACE), usage)


(additional changed file, name not shown)
@@ -70,7 +70,7 @@ function _M:run(conf)
   local identity_interface = _KEYBASTION[conf.llm]

   if identity_interface and identity_interface.error then
-    kong.log.err("error authenticating with ", conf.model.provider, " using native provider auth, ", identity_interface.error)
+    kong.log.err("error authenticating with ", conf.llm.model.provider, " using native provider auth, ", identity_interface.error)
     return kong.response.exit(500, "LLM request failed before proxying")
   end

(additional changed file, name not shown)
@@ -122,7 +122,7 @@ function _M:run(conf)
   local identity_interface = _KEYBASTION[conf.llm]

   if identity_interface and identity_interface.error then
-    kong.log.err("error authenticating with ", conf.model.provider, " using native provider auth, ", identity_interface.error)
+    kong.log.err("error authenticating with ", conf.llm.model.provider, " using native provider auth, ", identity_interface.error)
     return kong.response.exit(500, "LLM request failed before proxying")
   end
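Hedged illustration of why the old log line in these two handlers could itself fail: the credential lookup above keys on conf.llm, which suggests the model block is nested under conf.llm in these plugins, so conf.model is presumably nil and indexing conf.model.provider would raise.

local conf = { llm = { model = { provider = "gemini" } } }

print(conf.llm.model.provider)                           --> gemini
print(pcall(function() return conf.model.provider end))  --> false   (attempt to index a nil value)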

35 changes: 35 additions & 0 deletions request.json
@@ -0,0 +1,35 @@
{
"contents": [
{
"role": "user",
"parts": [
{
"text": "Hi Gemini"
}
]
}
]
, "generationConfig": {
"temperature": 1
,"maxOutputTokens": 8192
,"topP": 0.95
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "OFF"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "OFF"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "OFF"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "OFF"
}
]
}
8 changes: 4 additions & 4 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -871,14 +871,14 @@ for _, strategy in helpers.all_strategies() do
           local _, first_got = next(log_message.ai)
           local actual_llm_latency = first_got.meta.llm_latency
           local actual_time_per_token = first_got.usage.time_per_token
-          local time_per_token = math.floor(actual_llm_latency / first_got.usage.completion_tokens)
+          local time_per_token = actual_llm_latency / first_got.usage.completion_tokens

           first_got.meta.llm_latency = 1
           first_got.usage.time_per_token = 1

           assert.same(first_expected, first_got)
           assert.is_true(actual_llm_latency >= 0)
-          assert.same(actual_time_per_token, time_per_token)
+          assert.same(tonumber(string.format("%.3f", actual_time_per_token)), tonumber(string.format("%.3f", time_per_token)))
           assert.same(first_got.meta.request_model, "gpt-3.5-turbo")
         end)
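Sketch of what the relaxed assertion above tolerates (the numbers are invented): time_per_token is now a float, so the test compares values rounded to three decimal places rather than expecting bit-identical floored integers.

local actual   = 1234 / 57   -- e.g. llm_latency / completion_tokens
local expected = 21.64912280701754

print(string.format("%.3f", actual))  --> 21.649
print(tonumber(string.format("%.3f", actual)) == tonumber(string.format("%.3f", expected)))  --> true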

@@ -1529,14 +1529,14 @@ for _, strategy in helpers.all_strategies() do

           local actual_llm_latency = first_got.meta.llm_latency
           local actual_time_per_token = first_got.usage.time_per_token
-          local time_per_token = math.floor(actual_llm_latency / first_got.usage.completion_tokens)
+          local time_per_token = actual_llm_latency / first_got.usage.completion_tokens

           first_got.meta.llm_latency = 1
           first_got.usage.time_per_token = 1

           assert.same(first_expected, first_got)
           assert.is_true(actual_llm_latency >= 0)
-          assert.same(actual_time_per_token, time_per_token)
+          assert.same(tonumber(string.format("%.3f", actual_time_per_token)), tonumber(string.format("%.3f", time_per_token)))
         end)

it("logs payloads", function()