diff --git a/changelog/unreleased/kong/ai-proxy-aws-bedrock.yml b/changelog/unreleased/kong/ai-proxy-aws-bedrock.yml new file mode 100644 index 000000000000..adc608b92b04 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-aws-bedrock.yml @@ -0,0 +1,5 @@ +message: | + Kong AI Gateway (AI Proxy and associated plugin family) now supports + all AWS Bedrock "Converse API" models. +type: feature +scope: Plugin diff --git a/kong-3.8.0-0.rockspec b/kong-3.8.0-0.rockspec index 8581f9cf0f92..f7ead8c8957b 100644 --- a/kong-3.8.0-0.rockspec +++ b/kong-3.8.0-0.rockspec @@ -203,6 +203,7 @@ build = { ["kong.tools.cjson"] = "kong/tools/cjson.lua", ["kong.tools.emmy_debugger"] = "kong/tools/emmy_debugger.lua", ["kong.tools.redis.schema"] = "kong/tools/redis/schema.lua", + ["kong.tools.aws_stream"] = "kong/tools/aws_stream.lua", ["kong.runloop.handler"] = "kong/runloop/handler.lua", ["kong.runloop.events"] = "kong/runloop/events.lua", @@ -612,8 +613,8 @@ build = { ["kong.llm.drivers.anthropic"] = "kong/llm/drivers/anthropic.lua", ["kong.llm.drivers.mistral"] = "kong/llm/drivers/mistral.lua", ["kong.llm.drivers.llama2"] = "kong/llm/drivers/llama2.lua", - ["kong.llm.drivers.gemini"] = "kong/llm/drivers/gemini.lua", + ["kong.llm.drivers.bedrock"] = "kong/llm/drivers/bedrock.lua", ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua", ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua", diff --git a/kong/clustering/compat/checkers.lua b/kong/clustering/compat/checkers.lua index 7128b0f79078..55dcbbc2bd4a 100644 --- a/kong/clustering/compat/checkers.lua +++ b/kong/clustering/compat/checkers.lua @@ -2,7 +2,7 @@ local ipairs = ipairs local type = type -local log_warn_message +local log_warn_message, _AI_PROVIDER_INCOMPATIBLE do local ngx_log = ngx.log local ngx_WARN = ngx.WARN @@ -19,8 +19,24 @@ do KONG_VERSION, hint, dp_version, action) ngx_log(ngx_WARN, _log_prefix, msg, log_suffix) end -end + local 
_AI_PROVIDERS_ADDED = { + [3008000000] = { + "gemini", + "bedrock", + }, + } + + _AI_PROVIDER_INCOMPATIBLE = function(provider, ver) + for _, v in ipairs(_AI_PROVIDERS_ADDED[ver]) do + if v == provider then + return true + end + end + + return false + end +end local compatible_checkers = { { 3008000000, --[[ 3.8.0.0 ]] @@ -40,37 +56,43 @@ local compatible_checkers = { if plugin.name == 'ai-proxy' then local config = plugin.config - if config.model.provider == "gemini" then + if _AI_PROVIDER_INCOMPATIBLE(config.model.provider, 3008000000) then + log_warn_message('configures ' .. plugin.name .. ' plugin with' .. + ' "openai preserve mode", because ' .. config.model.provider .. ' provider ' .. + ' is not supported in this release', + dp_version, log_suffix) + config.model.provider = "openai" config.route_type = "preserve" - log_warn_message('configures ' .. plugin.name .. ' plugin with' .. - ' "openai preserve mode", because gemini' .. - ' provider is not supported in this release', - dp_version, log_suffix) + has_update = true end end if plugin.name == 'ai-request-transformer' then local config = plugin.config - if config.llm.model.provider == "gemini" then - config.llm.model.provider = "openai" + if _AI_PROVIDER_INCOMPATIBLE(config.llm.model.provider, 3008000000) then log_warn_message('configures ' .. plugin.name .. ' plugin with' .. - ' "openai preserve mode", because gemini' .. - ' provider is not supported in this release', - dp_version, log_suffix) + ' "openai preserve mode", because ' .. config.llm.model.provider .. ' provider ' .. + ' is not supported in this release', + dp_version, log_suffix) + + config.llm.model.provider = "openai" + has_update = true end end if plugin.name == 'ai-response-transformer' then local config = plugin.config - if config.llm.model.provider == "gemini" then - config.llm.model.provider = "openai" + if _AI_PROVIDER_INCOMPATIBLE(config.llm.model.provider, 3008000000) then log_warn_message('configures ' .. plugin.name .. 
' plugin with' .. - ' "openai preserve mode", because gemini' .. - ' provider is not supported in this release', - dp_version, log_suffix) + ' "openai preserve mode", because ' .. config.llm.model.provider .. ' provider ' .. + ' is not supported in this release', + dp_version, log_suffix) + + config.llm.model.provider = "openai" + has_update = true end end diff --git a/kong/clustering/compat/removed_fields.lua b/kong/clustering/compat/removed_fields.lua index f98965036f5d..ade547ae02d6 100644 --- a/kong/clustering/compat/removed_fields.lua +++ b/kong/clustering/compat/removed_fields.lua @@ -172,6 +172,9 @@ return { "model.options.gemini", "auth.gcp_use_service_account", "auth.gcp_service_account_json", + "model.options.bedrock", + "auth.aws_access_key_id", + "auth.aws_secret_access_key", }, ai_prompt_decorator = { "max_request_body_size", @@ -188,12 +191,18 @@ return { "llm.model.options.gemini", "llm.auth.gcp_use_service_account", "llm.auth.gcp_service_account_json", + "llm.model.options.bedrock", + "llm.auth.aws_access_key_id", + "llm.auth.aws_secret_access_key", }, ai_response_transformer = { "max_request_body_size", "llm.model.options.gemini", "llm.auth.gcp_use_service_account", "llm.auth.gcp_service_account_json", + "llm.model.options.bedrock", + "llm.auth.aws_access_key_id", + "llm.auth.aws_secret_access_key", }, prometheus = { "ai_metrics", diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index fcc6419d33b8..77c9f363f9b6 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -225,7 +225,7 @@ local function handle_stream_event(event_t, model_info, route_type) return delta_to_event(event_data, model_info) elseif event_id == "message_stop" then - return "[DONE]", nil, nil + return ai_shared._CONST.SSE_TERMINATOR, nil, nil elseif event_id == "ping" then return nil, nil, nil diff --git a/kong/llm/drivers/bedrock.lua b/kong/llm/drivers/bedrock.lua new file mode 100644 index 000000000000..372a57fa8276 --- 
/dev/null +++ b/kong/llm/drivers/bedrock.lua @@ -0,0 +1,442 @@ +local _M = {} + +-- imports +local cjson = require("cjson.safe") +local fmt = string.format +local ai_shared = require("kong.llm.drivers.shared") +local socket_url = require("socket.url") +local string_gsub = string.gsub +local table_insert = table.insert +local string_lower = string.lower +local signer = require("resty.aws.request.sign") +-- + +-- globals +local DRIVER_NAME = "bedrock" +-- + +local _OPENAI_ROLE_MAPPING = { + ["system"] = "assistant", + ["user"] = "user", + ["assistant"] = "assistant", +} + +_M.bedrock_unsupported_system_role_patterns = { + "amazon.titan.-.*", + "cohere.command.-text.-.*", + "cohere.command.-light.-text.-.*", + "mistral.mistral.-7b.-instruct.-.*", + "mistral.mixtral.-8x7b.-instruct.-.*", +} + +local function to_bedrock_generation_config(request_table) + return { + ["maxTokens"] = request_table.max_tokens, + ["stopSequences"] = request_table.stop, + ["temperature"] = request_table.temperature, + ["topP"] = request_table.top_p, + } +end + +local function to_additional_request_fields(request_table) + return { + request_table.bedrock.additionalModelRequestFields + } +end + +local function to_tool_config(request_table) + return { + request_table.bedrock.toolConfig + } +end + +local function handle_stream_event(event_t, model_info, route_type) + local new_event, metadata + + if (not event_t) or (not event_t.data) then + return "", nil, nil + end + + -- decode and determine the event type + local event = cjson.decode(event_t.data) + local event_type = event and event.headers and event.headers[":event-type"] + + if not event_type then + return "", nil, nil + end + + local body = event.body and cjson.decode(event.body) + + if not body then + return "", nil, nil + end + + if event_type == "messageStart" then + new_event = { + choices = { + [1] = { + delta = { + content = "", + role = body.role, + }, + index = 0, + logprobs = cjson.null, + }, + }, + model = model_info.name, + 
object = "chat.completion.chunk", + system_fingerprint = cjson.null, + } + + elseif event_type == "contentBlockDelta" then + new_event = { + choices = { + [1] = { + delta = { + content = (body.delta + and body.delta.text) + or "", + }, + index = 0, + finish_reason = cjson.null, + logprobs = cjson.null, + }, + }, + model = model_info.name, + object = "chat.completion.chunk", + } + + elseif event_type == "messageStop" then + new_event = { + choices = { + [1] = { + delta = {}, + index = 0, + finish_reason = body.stopReason, + logprobs = cjson.null, + }, + }, + model = model_info.name, + object = "chat.completion.chunk", + } + + elseif event_type == "metadata" then + metadata = { + prompt_tokens = body.usage and body.usage.inputTokens or 0, + completion_tokens = body.usage and body.usage.outputTokens or 0, + } + + new_event = ai_shared._CONST.SSE_TERMINATOR + + -- "contentBlockStop" is absent because it is not used for anything here + end + + if new_event then + if new_event ~= ai_shared._CONST.SSE_TERMINATOR then + new_event = cjson.encode(new_event) + end + + return new_event, nil, metadata + else + return nil, nil, metadata -- caller code will handle "unrecognised" event types + end +end + +local function to_bedrock_chat_openai(request_table, model_info, route_type) + if not request_table then -- try-catch type mechanism + local err = "empty request table received for transformation" + ngx.log(ngx.ERR, "[bedrock] ", err) + return nil, nil, err + end + + local new_r = {} + + -- anthropic models support variable versions, just like self-hosted + new_r.anthropic_version = model_info.options and model_info.options.anthropic_version + or "bedrock-2023-05-31" + + if request_table.messages and #request_table.messages > 0 then + local system_prompts = {} + + for i, v in ipairs(request_table.messages) do + -- for 'system', we just concat them all into one Bedrock instruction + if v.role and v.role == "system" then + system_prompts[#system_prompts+1] = { text = v.content } + 
+ else + -- for any other role, just construct the chat history as 'parts.text' type + new_r.messages = new_r.messages or {} + table_insert(new_r.messages, { + role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' + content = { + { + text = v.content or "" + }, + }, + }) + end + end + + -- only works for some models + if #system_prompts > 0 then + for _, p in ipairs(_M.bedrock_unsupported_system_role_patterns) do + if model_info.name:find(p) then + return nil, nil, "system prompts are unsupported for model '" .. model_info.name + end + end + + new_r.system = system_prompts + end + end + + new_r.inferenceConfig = to_bedrock_generation_config(request_table) + + new_r.toolConfig = request_table.bedrock + and request_table.bedrock.toolConfig + and to_tool_config(request_table) + + new_r.additionalModelRequestFields = request_table.bedrock + and request_table.bedrock.additionalModelRequestFields + and to_additional_request_fields(request_table) + + return new_r, "application/json", nil +end + +local function from_bedrock_chat_openai(response, model_info, route_type) + local response, err = cjson.decode(response) + + if err then + local err_client = "failed to decode response from Bedrock" + ngx.log(ngx.ERR, fmt("[bedrock] %s: %s", err_client, err)) + return nil, err_client + end + + -- messages/choices table is only 1 size, so don't need to static allocate + local client_response = {} + client_response.choices = {} + + if response.output + and response.output.message + and response.output.message.content + and #response.output.message.content > 0 + and response.output.message.content[1].text then + + client_response.choices[1] = { + index = 0, + message = { + role = "assistant", + content = response.output.message.content[1].text, + }, + finish_reason = string_lower(response.stopReason), + } + client_response.object = "chat.completion" + client_response.model = model_info.name + + else -- probably a server fault or other unexpected response + local err = 
"no generation candidates received from Bedrock, or max_tokens too short" + ngx.log(ngx.ERR, "[bedrock] ", err) + return nil, err + end + + -- process analytics + if response.usage then + client_response.usage = { + prompt_tokens = response.usage.inputTokens, + completion_tokens = response.usage.outputTokens, + total_tokens = response.usage.totalTokens, + } + end + + return cjson.encode(client_response) +end + +local transformers_to = { + ["llm/v1/chat"] = to_bedrock_chat_openai, +} + +local transformers_from = { + ["llm/v1/chat"] = from_bedrock_chat_openai, + ["stream/llm/v1/chat"] = handle_stream_event, +} + +function _M.from_format(response_string, model_info, route_type) + ngx.log(ngx.DEBUG, "converting from ", model_info.provider, "://", route_type, " type to kong") + + -- MUST return a string, to set as the response body + if not transformers_from[route_type] then + return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) + end + + local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type) + if not ok or err then + return nil, fmt("transformation failed from type %s://%s: %s", + model_info.provider, + route_type, + err or "unexpected_error" + ) + end + + return response_string, nil, metadata +end + +function _M.to_format(request_table, model_info, route_type) + ngx.log(ngx.DEBUG, "converting from kong type to ", model_info.provider, "/", route_type) + + if route_type == "preserve" then + -- do nothing + return request_table, nil, nil + end + + if not transformers_to[route_type] then + return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type) + end + + request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + + local ok, response_object, content_type, err = pcall( + transformers_to[route_type], + request_table, + model_info + ) + if err or (not ok) then + return nil, nil, fmt("error 
transforming to %s://%s: %s", model_info.provider, route_type, err) + end + + return response_object, content_type, nil +end + +function _M.subrequest(body, conf, http_opts, return_res_table) + -- use shared/standard subrequest routine + local body_string, err + + if type(body) == "table" then + body_string, err = cjson.encode(body) + if err then + return nil, nil, "failed to parse body to json: " .. err + end + elseif type(body) == "string" then + body_string = body + else + return nil, nil, "body must be table or string" + end + + -- may be overridden + local url = (conf.model.options and conf.model.options.upstream_url) + or fmt( + "%s%s", + ai_shared.upstream_url_format[DRIVER_NAME], + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + ) + + local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method + + local headers = { + ["Accept"] = "application/json", + ["Content-Type"] = "application/json", + } + + if conf.auth and conf.auth.header_name then + headers[conf.auth.header_name] = conf.auth.header_value + end + + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) + if err then + return nil, nil, "request to ai service failed: " .. err + end + + if return_res_table then + return res, res.status, nil, httpc + else + -- At this point, the entire request / response is complete and the connection + -- will be closed or back on the connection pool. + local status = res.status + local body = res.body + + if status > 299 then + return body, res.status, "status code " .. 
status + end + + return body, res.status, nil + end +end + +function _M.header_filter_hooks(body) + -- nothing to parse in header_filter phase +end + +function _M.post_request(conf) + if ai_shared.clear_response_headers[DRIVER_NAME] then + for i, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do + kong.response.clear_header(v) + end + end +end + +function _M.pre_request(conf, body) + -- force gzip for bedrock because brotli and others break streaming + kong.service.request.set_header("Accept-Encoding", "gzip, identity") + + return true, nil +end + +-- returns err or nil +function _M.configure_request(conf, aws_sdk) + local operation = kong.ctx.shared.ai_proxy_streaming_mode and "converse-stream" + or "converse" + + local f_url = conf.model.options and conf.model.options.upstream_url + + if not f_url then -- upstream_url override is not set + local uri = fmt(ai_shared.upstream_url_format[DRIVER_NAME], aws_sdk.config.region) + local path = fmt( + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, + conf.model.name, + operation) + + f_url = fmt("%s%s", uri, path) + end + + local parsed_url = socket_url.parse(f_url) + + if conf.model.options and conf.model.options.upstream_path then + -- upstream path override is set (or templated from request params) + parsed_url.path = conf.model.options.upstream_path + end + + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") + + kong.service.request.set_path(parsed_url.path) + kong.service.request.set_scheme(parsed_url.scheme) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + + -- do the IAM auth and signature headers + aws_sdk.config.signatureVersion = "v4" + aws_sdk.config.endpointPrefix = "bedrock" + + local r = { + headers = {}, + method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method, + path = parsed_url.path, + host = parsed_url.host, + port = tonumber(parsed_url.port) or 443, + body 
= kong.request.get_raw_body() + } + + local signature, err = signer(aws_sdk.config, r) + if not signature then + return nil, "failed to sign AWS request: " .. (err or "NONE") + end + + kong.service.request.set_header("Authorization", signature.headers["Authorization"]) + if signature.headers["X-Amz-Security-Token"] then + kong.service.request.set_header("X-Amz-Security-Token", signature.headers["X-Amz-Security-Token"]) + end + if signature.headers["X-Amz-Date"] then + kong.service.request.set_header("X-Amz-Date", signature.headers["X-Amz-Date"]) + end + + return true +end + +return _M diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index b96cbbbc2d46..1aafc9405b0c 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -97,7 +97,7 @@ local function handle_stream_event(event_t, model_info, route_type) elseif event.event_type == "stream-end" then -- return a metadata object, with the OpenAI termination event - new_event = "[DONE]" + new_event = ai_shared._CONST.SSE_TERMINATOR metadata = { completion_tokens = event.response @@ -123,7 +123,7 @@ local function handle_stream_event(event_t, model_info, route_type) end if new_event then - if new_event ~= "[DONE]" then + if new_event ~= ai_shared._CONST.SSE_TERMINATOR then new_event = cjson.encode(new_event) end diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua index 59296ee9160b..57ca7127ef29 100644 --- a/kong/llm/drivers/gemini.lua +++ b/kong/llm/drivers/gemini.lua @@ -41,30 +41,32 @@ local function is_response_content(content) and content.candidates[1].content.parts[1].text end -local function is_response_finished(content) - return content - and content.candidates - and #content.candidates > 0 - and content.candidates[1].finishReason -end - -local function handle_stream_event(event_t, model_info, route_type) +local function handle_stream_event(event_t, model_info, route_type) -- discard empty frames, it should either be a random new line, or comment if 
(not event_t.data) or (#event_t.data < 1) then return end - + + if event_t.data == ai_shared._CONST.SSE_TERMINATOR then + return ai_shared._CONST.SSE_TERMINATOR, nil, nil + end + local event, err = cjson.decode(event_t.data) if err then ngx.log(ngx.WARN, "failed to decode stream event frame from gemini: " .. err) return nil, "failed to decode stream event frame from gemini", nil end - local new_event - local metadata = nil - if is_response_content(event) then - new_event = { + local metadata = {} + metadata.finished_reason = event.candidates + and #event.candidates > 0 + and event.candidates[1].finishReason + or "STOP" + metadata.completion_tokens = event.usageMetadata and event.usageMetadata.candidatesTokenCount or 0 + metadata.prompt_tokens = event.usageMetadata and event.usageMetadata.promptTokenCount or 0 + + local new_event = { choices = { [1] = { delta = { @@ -75,28 +77,8 @@ local function handle_stream_event(event_t, model_info, route_type) }, }, } - end - if is_response_finished(event) then - metadata = metadata or {} - metadata.finished_reason = event.candidates[1].finishReason - new_event = "[DONE]" - end - - if event.usageMetadata then - metadata = metadata or {} - metadata.completion_tokens = event.usageMetadata.candidatesTokenCount or 0 - metadata.prompt_tokens = event.usageMetadata.promptTokenCount or 0 - end - - if new_event then - if new_event ~= "[DONE]" then - new_event = cjson.encode(new_event) - end - - return new_event, nil, metadata - else - return nil, nil, metadata -- caller code will handle "unrecognised" event types + return cjson.encode(new_event), nil, metadata end end @@ -206,6 +188,15 @@ local function from_gemini_chat_openai(response, model_info, route_type) messages.object = "chat.completion" messages.model = model_info.name + -- process analytics + if response.usageMetadata then + messages.usage = { + prompt_tokens = response.usageMetadata.promptTokenCount, + completion_tokens = response.usageMetadata.candidatesTokenCount, + 
total_tokens = response.usageMetadata.totalTokenCount, + } + end + else -- probably a server fault or other unexpected response local err = "no generation candidates received from Gemini, or max_tokens too short" ngx.log(ngx.ERR, err) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 0e1d0d18a962..6f9341884f25 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -1,11 +1,12 @@ local _M = {} -- imports -local cjson = require("cjson.safe") -local http = require("resty.http") -local fmt = string.format -local os = os -local parse_url = require("socket.url").parse +local cjson = require("cjson.safe") +local http = require("resty.http") +local fmt = string.format +local os = os +local parse_url = require("socket.url").parse +local aws_stream = require("kong.tools.aws_stream") -- -- static @@ -18,6 +19,10 @@ local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy local function str_ltrim(s) -- remove leading whitespace from string. return type(s) == "string" and s:gsub("^%s*", "") end + +local function str_rtrim(s) -- remove trailing whitespace from string. 
+ return type(s) == "string" and s:match('^(.-)%s*$') +end -- local log_entry_keys = { @@ -51,20 +56,26 @@ local log_entry_keys = { local openai_override = os.getenv("OPENAI_TEST_PORT") +_M._CONST = { + ["SSE_TERMINATOR"] = "[DONE]", +} + _M.streaming_has_token_counts = { ["cohere"] = true, ["llama2"] = true, ["anthropic"] = true, ["gemini"] = true, + ["bedrock"] = true, } _M.upstream_url_format = { - openai = fmt("%s://api.openai.com:%s", (openai_override and "http") or "https", (openai_override) or "443"), - anthropic = "https://api.anthropic.com:443", - cohere = "https://api.cohere.com:443", - azure = "https://%s.openai.azure.com:443/openai/deployments/%s", - gemini = "https://generativelanguage.googleapis.com", + openai = fmt("%s://api.openai.com:%s", (openai_override and "http") or "https", (openai_override) or "443"), + anthropic = "https://api.anthropic.com:443", + cohere = "https://api.cohere.com:443", + azure = "https://%s.openai.azure.com:443/openai/deployments/%s", + gemini = "https://generativelanguage.googleapis.com", gemini_vertex = "https://%s", + bedrock = "https://bedrock-runtime.%s.amazonaws.com", } _M.operation_map = { @@ -120,6 +131,12 @@ _M.operation_map = { method = "POST", }, }, + bedrock = { + ["llm/v1/chat"] = { + path = "/model/%s/%s", + method = "POST", + }, + }, } _M.clear_response_headers = { @@ -138,6 +155,9 @@ _M.clear_response_headers = { gemini = { "Set-Cookie", }, + bedrock = { + "Set-Cookie", + }, } --- @@ -219,7 +239,7 @@ end -- @param {string} frame input string to format into SSE events -- @param {boolean} raw_json sets application/json byte-parser mode -- @return {table} n number of split SSE messages, or empty table -function _M.frame_to_events(frame, raw_json_mode) +function _M.frame_to_events(frame, provider) local events = {} if (not frame) or (#frame < 1) or (type(frame)) ~= "string" then @@ -228,21 +248,44 @@ function _M.frame_to_events(frame, raw_json_mode) -- some new LLMs return the JSON object-by-object, -- because 
that totally makes sense to parse?! - if raw_json_mode then + if provider == "gemini" then + local done = false + -- if this is the first frame, it will begin with array opener '[' frame = (string.sub(str_ltrim(frame), 1, 1) == "[" and string.sub(str_ltrim(frame), 2)) or frame -- it may start with ',' which is the start of the new frame frame = (string.sub(str_ltrim(frame), 1, 1) == "," and string.sub(str_ltrim(frame), 2)) or frame - -- finally, it may end with the array terminator ']' indicating the finished stream - frame = (string.sub(str_ltrim(frame), -1) == "]" and string.sub(str_ltrim(frame), 1, -2)) or frame + -- it may end with the array terminator ']' indicating the finished stream + if string.sub(str_rtrim(frame), -1) == "]" then + frame = string.sub(str_rtrim(frame), 1, -2) + done = true + end -- for multiple events that arrive in the same frame, split by top-level comma for _, v in ipairs(split(frame, "\n,")) do events[#events+1] = { data = v } end + if done then + -- add the done signal here + -- but we have to retrieve the metadata from a previous filter run + events[#events+1] = { data = _M._CONST.SSE_TERMINATOR } + end + + elseif provider == "bedrock" then + local parser = aws_stream:new(frame) + while true do + local msg = parser:next_message() + + if not msg then + break + end + + events[#events+1] = { data = cjson.encode(msg) } + end + -- check if it's raw json and just return the split up data frame -- Cohere / Other flat-JSON format parser -- just return the split up data frame @@ -401,7 +444,7 @@ function _M.from_ollama(response_string, model_info, route_type) end end - if output and output ~= "[DONE]" then + if output and output ~= _M._CONST.SSE_TERMINATOR then output, err = cjson.encode(output) end @@ -510,6 +553,10 @@ end function _M.post_request(conf, response_object) local body_string, err + if not response_object then + return + end + if type(response_object) == "string" then -- set raw string body first, then decode body_string = 
response_object @@ -573,7 +620,7 @@ function _M.post_request(conf, response_object) end if response_object.usage.prompt_tokens and response_object.usage.completion_tokens - and conf.model.options.input_cost and conf.model.options.output_cost then + and conf.model.options and conf.model.options.input_cost and conf.model.options.output_cost then request_analytics_plugin[log_entry_keys.USAGE_CONTAINER][log_entry_keys.COST] = (response_object.usage.prompt_tokens * conf.model.options.input_cost + response_object.usage.completion_tokens * conf.model.options.output_cost) / 1000000 -- 1 million diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 85802e54b9c7..b4b7bba5ae7a 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -91,20 +91,51 @@ do function LLM:ai_introspect_body(request, system_prompt, http_opts, response_regex_match) local err, _ - -- set up the request - local ai_request = { - messages = { - [1] = { - role = "system", - content = system_prompt, + -- set up the LLM request for transformation instructions + local ai_request + + -- mistral, cohere, titan (via Bedrock) don't support system commands + if self.conf.model.provider == "bedrock" then + for _, p in ipairs(self.driver.bedrock_unsupported_system_role_patterns) do + if (self.conf.model.name or ""):find(p) then + ai_request = { + messages = { + [1] = { + role = "user", + content = system_prompt, + }, + [2] = { + role = "assistant", + content = "What is the message?", + }, + [3] = { + role = "user", + content = request, + } + }, + stream = false, + } + break + end + end + end + + -- not Bedrock, or didn't match banned pattern - continue as normal + if not ai_request then + ai_request = { + messages = { + [1] = { + role = "system", + content = system_prompt, + }, + [2] = { + role = "user", + content = request, + } }, - [2] = { - role = "user", - content = request, - } - }, - stream = false, - } + stream = false, + } + end -- convert it to the specified driver format ai_request, _, err = self.driver.to_format(ai_request, 
self.conf.model, "llm/v1/chat") @@ -204,8 +235,9 @@ do } setmetatable(self, LLM) - local provider = (self.conf.model or {}).provider or "NONE_SET" - local driver_module = "kong.llm.drivers." .. provider + self.provider = (self.conf.model or {}).provider or "NONE_SET" + local driver_module = "kong.llm.drivers." .. self.provider + local ok ok, self.driver = pcall(require, driver_module) if not ok then diff --git a/kong/llm/schemas/init.lua b/kong/llm/schemas/init.lua index 9dc68f16db8a..c975c49c26f0 100644 --- a/kong/llm/schemas/init.lua +++ b/kong/llm/schemas/init.lua @@ -2,6 +2,19 @@ local typedefs = require("kong.db.schema.typedefs") local fmt = string.format +local bedrock_options_schema = { + type = "record", + required = false, + fields = { + { aws_region = { + description = "If using AWS providers (Bedrock) you can override the `AWS_REGION` " .. + "environment variable by setting this option.", + type = "string", + required = false }}, + }, +} + + local gemini_options_schema = { type = "record", required = false, @@ -68,6 +81,22 @@ local auth_schema = { "environment variable `GCP_SERVICE_ACCOUNT`.", required = false, referenceable = true }}, + { aws_access_key_id = { + type = "string", + description = "Set this if you are using an AWS provider (Bedrock) and you are authenticating " .. + "using static IAM User credentials. Setting this will override the AWS_ACCESS_KEY_ID " .. + "environment variable for this plugin instance.", + required = false, + encrypted = true, + referenceable = true }}, + { aws_secret_access_key = { + type = "string", + description = "Set this if you are using an AWS provider (Bedrock) and you are authenticating " .. + "using static IAM User credentials. Setting this will override the AWS_SECRET_ACCESS_KEY " .. 
+ "environment variable for this plugin instance.", + required = false, + encrypted = true, + referenceable = true }}, } } @@ -144,6 +173,7 @@ local model_options_schema = { type = "string", required = false }}, { gemini = gemini_options_schema }, + { bedrock = bedrock_options_schema }, } } @@ -157,7 +187,7 @@ local model_schema = { type = "string", description = "AI provider request format - Kong translates " .. "requests to and from the specified backend compatible formats.", required = true, - one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini" }}}, + one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini", "bedrock" }}}, { name = { type = "string", description = "Model name to execute.", diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 5ff894c5e054..bc7288d30075 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -12,6 +12,11 @@ local GCP_SERVICE_ACCOUNT do end local GCP = require("resty.gcp.request.credentials.accesstoken") +local aws_config = require "resty.aws.config" -- reads environment variables whilst available +local AWS = require("resty.aws") +local AWS_REGION do + AWS_REGION = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") +end -- @@ -48,6 +53,44 @@ local _KEYBASTION = setmetatable({}, { end return { interface = nil, error = "cloud-authentication with GCP failed" } + + elseif plugin_config.model.provider == "bedrock" then + ngx.log(ngx.NOTICE, "loading aws sdk for plugin ", kong.plugin.get_id()) + local aws + + local region = plugin_config.model.options + and plugin_config.model.options.bedrock + and plugin_config.model.options.bedrock.aws_region + or AWS_REGION + + if not region then + return { interface = nil, error = "AWS region not specified anywhere" } + end + + local access_key_set = (plugin_config.auth and plugin_config.auth.aws_access_key_id) + or aws_config.global.AWS_ACCESS_KEY_ID + local 
secret_key_set = plugin_config.auth and plugin_config.auth.aws_secret_access_key + or aws_config.global.AWS_SECRET_ACCESS_KEY + + aws = AWS({ + -- if any of these are nil, they either use the SDK default or + -- are deliberately null so that a different auth chain is used + region = region, + }) + + if access_key_set and secret_key_set then + -- Override credential config according to plugin config, if set + local creds = aws:Credentials { + accessKeyId = access_key_set, + secretAccessKey = secret_key_set, + } + + aws.config.credentials = creds + end + + this_cache[plugin_config] = { interface = aws, error = nil } + + return this_cache[plugin_config] end end, }) @@ -99,8 +142,7 @@ local function handle_streaming_frame(conf) chunk = kong_utils.inflate_gzip(ngx.arg[1]) end - local is_raw_json = conf.model.provider == "gemini" - local events = ai_shared.frame_to_events(chunk, is_raw_json ) + local events = ai_shared.frame_to_events(chunk, conf.model.provider) if not events then -- usually a not-supported-transformer or empty frames. @@ -142,7 +184,7 @@ local function handle_streaming_frame(conf) local err if formatted then -- only stream relevant frames back to the user - if conf.logging and conf.logging.log_payloads and (formatted ~= "[DONE]") then + if conf.logging and conf.logging.log_payloads and (formatted ~= ai_shared._CONST.SSE_TERMINATOR) then -- append the "choice" to the buffer, for logging later. this actually works! 
if not event_t then event_t, err = cjson.decode(formatted) @@ -160,7 +202,7 @@ local function handle_streaming_frame(conf) -- handle event telemetry if conf.logging and conf.logging.log_statistics then if not ai_shared.streaming_has_token_counts[conf.model.provider] then - if formatted ~= "[DONE]" then + if formatted ~= ai_shared._CONST.SSE_TERMINATOR then if not event_t then event_t, err = cjson.decode(formatted) end @@ -183,18 +225,25 @@ local function handle_streaming_frame(conf) framebuffer:put("data: ") framebuffer:put(formatted or "") - framebuffer:put((formatted ~= "[DONE]") and "\n\n" or "") + framebuffer:put((formatted ~= ai_shared._CONST.SSE_TERMINATOR) and "\n\n" or "") end if conf.logging and conf.logging.log_statistics and metadata then - kong_ctx_plugin.ai_stream_completion_tokens = - (kong_ctx_plugin.ai_stream_completion_tokens or 0) + - (metadata.completion_tokens or 0) - or kong_ctx_plugin.ai_stream_completion_tokens - kong_ctx_plugin.ai_stream_prompt_tokens = - (kong_ctx_plugin.ai_stream_prompt_tokens or 0) + - (metadata.prompt_tokens or 0) - or kong_ctx_plugin.ai_stream_prompt_tokens + -- gemini metadata specifically, works differently + if conf.model.provider == "gemini" then + print(metadata.completion_tokens) + kong_ctx_plugin.ai_stream_completion_tokens = metadata.completion_tokens or 0 + kong_ctx_plugin.ai_stream_prompt_tokens = metadata.prompt_tokens or 0 + else + kong_ctx_plugin.ai_stream_completion_tokens = + (kong_ctx_plugin.ai_stream_completion_tokens or 0) + + (metadata.completion_tokens or 0) + or kong_ctx_plugin.ai_stream_completion_tokens + kong_ctx_plugin.ai_stream_prompt_tokens = + (kong_ctx_plugin.ai_stream_prompt_tokens or 0) + + (metadata.prompt_tokens or 0) + or kong_ctx_plugin.ai_stream_prompt_tokens + end end end end @@ -300,8 +349,10 @@ function _M:body_filter(conf) if kong_ctx_shared.skip_response_transformer and (route_type ~= "preserve") then local response_body + if kong_ctx_shared.parsed_response then response_body = 
kong_ctx_shared.parsed_response + elseif kong.response.get_status() == 200 then response_body = kong.service.response.get_raw_body() if not response_body then @@ -320,6 +371,7 @@ function _M:body_filter(conf) if err then kong.log.warn("issue when transforming the response body for analytics in the body filter phase, ", err) + elseif new_response_string then ai_shared.post_request(conf, new_response_string) end diff --git a/kong/tools/aws_stream.lua b/kong/tools/aws_stream.lua new file mode 100644 index 000000000000..ebefc2c26566 --- /dev/null +++ b/kong/tools/aws_stream.lua @@ -0,0 +1,181 @@ +--- Stream class. +-- Decodes AWS response-stream types, currently application/vnd.amazon.eventstream +-- @classmod Stream + +local buf = require("string.buffer") +local to_hex = require("resty.string").to_hex + +local Stream = {} +Stream.__index = Stream + + +local _HEADER_EXTRACTORS = { + -- bool true + [0] = function(stream) + return true, 0 + end, + + -- bool false + [1] = function(stream) + return false, 0 + end, + + -- string type + [7] = function(stream) + local header_value_len = stream:next_int(16) + return stream:next_utf_8(header_value_len), header_value_len + 2 -- add the 2 bits read for the length + end, + + -- TODO ADD THE REST OF THE DATA TYPES + -- EVEN THOUGH THEY'RE NOT REALLY USED +} + +--- Constructor. 
+-- @function aws:Stream +-- @param chunk string complete AWS response stream chunk for decoding +-- @param is_hex boolean specify if the chunk bytes are already decoded to hex +-- @usage +-- local stream_parser = stream:new("00000120af0310f.......", true) +-- local next, err = stream_parser:next_message() +function Stream:new(chunk, is_hex) + local self = {} -- override 'self' to be the new object/class + setmetatable(self, Stream) + + if #chunk < ((is_hex and 32) or 16) then + return nil, "cannot parse a chunk less than 16 bytes long" + end + + self.read_count = 0 + self.chunk = buf.new() + self.chunk:put((is_hex and chunk) or to_hex(chunk)) + + return self +end + + +--- return the next `count` ascii bytes from the front of the chunk +--- and then trims the chunk of those bytes +-- @param count number whole utf-8 bytes to return +-- @return string resulting utf-8 string +function Stream:next_utf_8(count) + local utf_bytes = self:next_bytes(count) + + local ascii_string = "" + for i = 1, #utf_bytes, 2 do + local hex_byte = utf_bytes:sub(i, i + 1) + local ascii_byte = string.char(tonumber(hex_byte, 16)) + ascii_string = ascii_string .. 
ascii_byte + end + return ascii_string +end + +--- returns the next `count` bytes from the front of the chunk +--- and then trims the chunk of those bytes +-- @param count number whole integer of bytes to return +-- @return string hex-encoded next `count` bytes +function Stream:next_bytes(count) + if not self.chunk then + return nil, "function cannot be called on its own - initialise a chunk reader with :new(chunk)" + end + + local bytes = self.chunk:get(count * 2) + self.read_count = (count) + self.read_count + + return bytes +end + +--- returns the next unsigned int from the front of the chunk +--- and then trims the chunk of those bytes +-- @param size integer bit length (8, 16, 32, etc) +-- @return number whole integer of size specified +-- @return string the original bytes, for reference/checksums +function Stream:next_int(size) + if not self.chunk then + return nil, nil, "function cannot be called on its own - initialise a chunk reader with :new(chunk)" + end + + if size < 8 then + return nil, nil, "cannot work on integers smaller than 8 bits long" + end + + local int, err = self:next_bytes(size / 8) + if err then + return nil, nil, err + end + + return tonumber(int, 16), int +end + +--- returns the next message in the chunk, as a table. +--- can be used as an iterator. 
+-- @return table formatted next message from the given constructor chunk +function Stream:next_message() + if not self.chunk then + return nil, "function cannot be called on its own - initialise a chunk reader with :new(chunk)" + end + + if #self.chunk < 1 then + return false + end + + -- get the message length and pull that many bytes + -- + -- this is a chicken and egg problem, because we need to + -- read the message to get the length, to then re-read the + -- whole message at correct offset + local msg_len, _, err = self:next_int(32) + if err then + return err + end + + -- get the headers length + local headers_len, _, err = self:next_int(32) + if err then + return err + end + + -- get the preamble checksum + -- skip it because we're not using UDP + self:next_int(32) + + -- pull the headers from the buf + local headers = {} + local headers_bytes_read = 0 + + while headers_bytes_read < headers_len do + -- the next 8-bit int is the "header key length" + local header_key_len = self:next_int(8) + local header_key = self:next_utf_8(header_key_len) + headers_bytes_read = 1 + header_key_len + headers_bytes_read + + -- next 8-bits is the header type, which is an enum + local header_type = self:next_int(8) + headers_bytes_read = 1 + headers_bytes_read + + -- depending on the header type, depends on how long the header should max out at + local header_value, header_value_len = _HEADER_EXTRACTORS[header_type](self) + headers_bytes_read = header_value_len + headers_bytes_read + + headers[header_key] = header_value + end + + -- finally, extract the body as a string by + -- subtracting what's read so far from the + -- total length obtained right at the start + local body = self:next_utf_8(msg_len - self.read_count - 4) + + -- last 4 bytes is a body checksum + -- skip it because we're not using UDP + self:next_int(32) + + + -- rewind the tape + self.read_count = 0 + + return { + headers = headers, + body = body, + } +end + +return Stream \ No newline at end of file diff 
--git a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua index 955f1d73681a..a6844b92e493 100644 --- a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua +++ b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua @@ -482,7 +482,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() end) describe("ai plugins supported providers", function() - it("[ai-proxy] tries to use unsupported gemini on older Kong versions", function() + it("[ai-proxy] tries to use unsupported providers on older Kong versions", function() -- [[ 3.8.x ]] -- local ai_proxy = admin.plugins:insert { name = "ai-proxy", @@ -516,10 +516,20 @@ describe("CP/DP config compat transformations #" .. strategy, function() local expected = cycle_aware_deep_copy(ai_proxy) + -- max body size expected.config.max_request_body_size = nil + + -- gemini fields expected.config.auth.gcp_service_account_json = nil expected.config.auth.gcp_use_service_account = nil expected.config.model.options.gemini = nil + + -- bedrock fields + expected.config.auth.aws_access_key_id = nil + expected.config.auth.aws_secret_access_key = nil + expected.config.model.options.bedrock = nil + + -- 'ai fallback' field sets expected.config.route_type = "preserve" expected.config.model.provider = "openai" @@ -535,7 +545,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() admin.plugins:remove({ id = ai_proxy.id }) end) - it("[ai-request-transformer] tries to use unsupported gemini on older Kong versions", function() + it("[ai-request-transformer] tries to use unsupported providers on older Kong versions", function() -- [[ 3.8.x ]] -- local ai_request_transformer = admin.plugins:insert { name = "ai-request-transformer", @@ -571,10 +581,20 @@ describe("CP/DP config compat transformations #" .. 
strategy, function() local expected = cycle_aware_deep_copy(ai_request_transformer) + -- max body size expected.config.max_request_body_size = nil + + -- gemini fields expected.config.llm.auth.gcp_service_account_json = nil expected.config.llm.auth.gcp_use_service_account = nil expected.config.llm.model.options.gemini = nil + + -- bedrock fields + expected.config.llm.auth.aws_access_key_id = nil + expected.config.llm.auth.aws_secret_access_key = nil + expected.config.llm.model.options.bedrock = nil + + -- 'ai fallback' field sets expected.config.llm.model.provider = "openai" do_assert(uuid(), "3.7.0", expected) @@ -588,7 +608,7 @@ describe("CP/DP config compat transformations #" .. strategy, function() admin.plugins:remove({ id = ai_request_transformer.id }) end) - it("[ai-response-transformer] tries to use unsupported gemini on older Kong versions", function() + it("[ai-response-transformer] tries to use unsupported providers on older Kong versions", function() -- [[ 3.8.x ]] -- local ai_response_transformer = admin.plugins:insert { name = "ai-response-transformer", @@ -624,10 +644,20 @@ describe("CP/DP config compat transformations #" .. strategy, function() local expected = cycle_aware_deep_copy(ai_response_transformer) + -- max body size expected.config.max_request_body_size = nil + + -- gemini fields expected.config.llm.auth.gcp_service_account_json = nil expected.config.llm.auth.gcp_use_service_account = nil expected.config.llm.model.options.gemini = nil + + -- bedrock fields + expected.config.llm.auth.aws_access_key_id = nil + expected.config.llm.auth.aws_secret_access_key = nil + expected.config.llm.model.options.bedrock = nil + + -- 'ai fallback' field sets expected.config.llm.model.provider = "openai" do_assert(uuid(), "3.7.0", expected) @@ -671,11 +701,19 @@ describe("CP/DP config compat transformations #" .. 
strategy, function() local expected = cycle_aware_deep_copy(ai_proxy) + -- max body size expected.config.max_request_body_size = nil + + -- gemini fields expected.config.auth.gcp_service_account_json = nil expected.config.auth.gcp_use_service_account = nil expected.config.model.options.gemini = nil + -- bedrock fields + expected.config.auth.aws_access_key_id = nil + expected.config.auth.aws_secret_access_key = nil + expected.config.model.options.bedrock = nil + do_assert(uuid(), "3.7.0", expected) expected.config.response_streaming = nil @@ -720,11 +758,20 @@ describe("CP/DP config compat transformations #" .. strategy, function() -- ]] local expected = cycle_aware_deep_copy(ai_request_transformer) + + -- max body size expected.config.max_request_body_size = nil + + -- gemini fields expected.config.llm.auth.gcp_service_account_json = nil expected.config.llm.auth.gcp_use_service_account = nil expected.config.llm.model.options.gemini = nil + -- bedrock fields + expected.config.llm.auth.aws_access_key_id = nil + expected.config.llm.auth.aws_secret_access_key = nil + expected.config.llm.model.options.bedrock = nil + do_assert(uuid(), "3.7.0", expected) expected.config.llm.model.options.upstream_path = nil @@ -765,11 +812,20 @@ describe("CP/DP config compat transformations #" .. 
strategy, function() --]] local expected = cycle_aware_deep_copy(ai_response_transformer) + + -- max body size expected.config.max_request_body_size = nil + + -- gemini fields expected.config.llm.auth.gcp_service_account_json = nil expected.config.llm.auth.gcp_use_service_account = nil expected.config.llm.model.options.gemini = nil + -- bedrock fields + expected.config.llm.auth.aws_access_key_id = nil + expected.config.llm.auth.aws_secret_access_key = nil + expected.config.llm.model.options.bedrock = nil + do_assert(uuid(), "3.7.0", expected) expected.config.llm.model.options.upstream_path = nil diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index aeb42600d639..009f079195d0 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -237,6 +237,20 @@ local FORMATS = { }, }, }, + bedrock = { + ["llm/v1/chat"] = { + config = { + name = "bedrock", + provider = "bedrock", + options = { + max_tokens = 8192, + temperature = 0.8, + top_k = 1, + top_p = 0.6, + }, + }, + }, + }, } local STREAMS = { @@ -664,7 +678,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("transforms truncated-json type (beginning of stream)", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/input.bin")) - local events = ai_shared.frame_to_events(input, true) + local events = ai_shared.frame_to_events(input, "gemini") local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-beginning/expected-output.json")) local expected_events = cjson.decode(expected) @@ -674,7 +688,7 @@ describe(PLUGIN_NAME .. 
": (unit)", function() it("transforms truncated-json type (end of stream)", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/input.bin")) - local events = ai_shared.frame_to_events(input, true) + local events = ai_shared.frame_to_events(input, "gemini") local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json")) local expected_events = cjson.decode(expected) @@ -684,7 +698,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("transforms complete-json type", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/input.bin")) - local events = ai_shared.frame_to_events(input, false) -- not "truncated json mode" like Gemini + local events = ai_shared.frame_to_events(input, "cohere") -- not "truncated json mode" like Gemini local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/complete-json/expected-output.json")) local expected_events = cjson.decode(expected) @@ -694,7 +708,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() it("transforms text/event-stream type", function() local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/input.bin")) - local events = ai_shared.frame_to_events(input, false) -- not "truncated json mode" like Gemini + local events = ai_shared.frame_to_events(input, "openai") -- not "truncated json mode" like Gemini local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/text-event-stream/expected-output.json")) local expected_events = cjson.decode(expected) @@ -702,6 +716,20 @@ describe(PLUGIN_NAME .. 
": (unit)", function() assert.same(events, expected_events) end) + it("transforms application/vnd.amazon.eventstream (AWS) type", function() + local input = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/aws/input.bin")) + local events = ai_shared.frame_to_events(input, "bedrock") + + local expected = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/streaming-chunk-formats/aws/expected-output.json")) + local expected_events = cjson.decode(expected) + + assert.equal(#events, #expected_events) + for i, _ in ipairs(expected_events) do + -- tables are random ordered, so we need to compare each serialized event + assert.same(cjson.decode(events[i].data), cjson.decode(expected_events[i].data)) + end + end) + end) end) diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index b67d815fa07e..b1cd81295026 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -902,12 +902,12 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good.json"), }) - + -- check we got internal server error local body = assert.res_status(500 , r) local json = cjson.decode(body) assert.is_truthy(json.error) - assert.equals(json.error.message, "transformation failed from type openai://llm/v1/chat: 'choices' not in llm/v1/chat response") + assert.same(json.error.message, "transformation failed from type openai://llm/v1/chat: 'choices' not in llm/v1/chat response") end) it("bad request", function() diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/bedrock/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/bedrock/llm-v1-chat.json new file mode 100644 index 000000000000..ad68f6b28338 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/expected-requests/bedrock/llm-v1-chat.json @@ -0,0 +1,55 @@ +{ + 
"system": [ + { + "text": "You are a mathematician." + } + ], + "messages": [ + { + "content": [ + { + "text": "What is 1 + 2?" + } + ], + "role": "user" + }, + { + "content": [ + { + "text": "The sum of 1 + 2 is 3. If you have any more math questions or if there's anything else I can help you with, feel free to ask!" + } + ], + "role": "assistant" + }, + { + "content": [ + { + "text": "Multiply that by 2" + } + ], + "role": "user" + }, + { + "content": [ + { + "text": "Certainly! If you multiply 3 by 2, the result is 6. If you have any more questions or if there's anything else I can help you with, feel free to ask!" + } + ], + "role": "assistant" + }, + { + "content": [ + { + "text": "Why can't you divide by zero?" + } + ], + "role": "user" + } + ], + "inferenceConfig": { + "maxTokens": 8192, + "temperature": 0.8, + "topP": 0.6 + }, + "anthropic_version": "bedrock-2023-05-31" +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/bedrock/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/bedrock/llm-v1-chat.json new file mode 100644 index 000000000000..948d3fb47465 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/expected-responses/bedrock/llm-v1-chat.json @@ -0,0 +1,19 @@ +{ + "choices": [ + { + "finish_reason": "end_turn", + "index": 0, + "message": { + "content": "You cannot divide by zero because it is not a valid operation in mathematics.", + "role": "assistant" + } + } + ], + "object": "chat.completion", + "usage": { + "completion_tokens": 119, + "prompt_tokens": 19, + "total_tokens": 138 + }, + "model": "bedrock" +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/bedrock/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/bedrock/llm-v1-chat.json new file mode 100644 index 000000000000..e995bbd984d1 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/real-responses/bedrock/llm-v1-chat.json @@ -0,0 +1,21 @@ +{ + "metrics": { + "latencyMs": 14767 + }, + "output": { + 
"message": { + "content": [ + { + "text": "You cannot divide by zero because it is not a valid operation in mathematics." + } + ], + "role": "assistant" + } + }, + "stopReason": "end_turn", + "usage": { + "completion_tokens": 119, + "prompt_tokens": 19, + "total_tokens": 138 + } +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json index 80781b6eb72a..96933d9835e6 100644 --- a/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json @@ -1,34 +1,39 @@ { - "candidates": [ - { - "content": { - "parts": [ - { - "text": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n" - } - ], - "role": "model" - }, - "finishReason": "STOP", - "index": 0, - "safetyRatings": [ - { - "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "probability": "NEGLIGIBLE" - }, - { - "category": "HARM_CATEGORY_HATE_SPEECH", - "probability": "NEGLIGIBLE" - }, - { - "category": "HARM_CATEGORY_HARASSMENT", - "probability": "NEGLIGIBLE" - }, + "candidates": [ + { + "content": { + "parts": [ { - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "probability": "NEGLIGIBLE" + "text": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. 
\n" } - ] - } - ] - } \ No newline at end of file + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 14, + "candidatesTokenCount": 128, + "totalTokenCount": 142 + } +} diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/aws/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/aws/expected-output.json new file mode 100644 index 000000000000..8761c5593608 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/aws/expected-output.json @@ -0,0 +1,20 @@ +[ + { + "data": "{\"body\":\"{\\\"p\\\":\\\"abcdefghijkl\\\",\\\"role\\\":\\\"assistant\\\"}\",\"headers\":{\":event-type\":\"messageStart\",\":content-type\":\"application\/json\",\":message-type\":\"event\"}}" + }, + { + "data": "{\"body\":\"{\\\"contentBlockIndex\\\":0,\\\"delta\\\":{\\\"text\\\":\\\"Hello! Relativity is a set of physical theories that are collectively known as special relativity and general relativity, proposed by Albert Einstein. These theories revolutionized our understanding of space, time, and gravity, and have had far-reach\\\"},\\\"p\\\":\\\"abcd\\\"}\",\"headers\":{\":event-type\":\"contentBlockDelta\",\":content-type\":\"application\\/json\",\":message-type\":\"event\"}}" + }, + { + "data": "{\"headers\":{\":event-type\":\"contentBlockDelta\",\":message-type\":\"event\",\":content-type\":\"application\\/json\"},\"body\":\"{\\\"contentBlockIndex\\\":0,\\\"delta\\\":{\\\"text\\\":\\\"ing implications in various scientific and technological fields. 
Special relativity applies to all physical phenomena in the absence of gravity, while general relativity explains the law of gravity and its effects on the nature of space, time, and matter.\\\"},\\\"p\\\":\\\"abcdefghijk\\\"}\"}" + }, + { + "data": "{\"body\":\"{\\\"contentBlockIndex\\\":0,\\\"p\\\":\\\"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQR\\\"}\",\"headers\":{\":content-type\":\"application\\/json\",\":event-type\":\"contentBlockStop\",\":message-type\":\"event\"}}" + }, + { + "data": "{\"body\":\"{\\\"p\\\":\\\"abcdefghijklm\\\",\\\"stopReason\\\":\\\"end_turn\\\"}\",\"headers\":{\":message-type\":\"event\",\":content-type\":\"application\\/json\",\":event-type\":\"messageStop\"}}" + }, + { + "data": "{\"headers\":{\":message-type\":\"event\",\":content-type\":\"application\\/json\",\":event-type\":\"metadata\"},\"body\":\"{\\\"metrics\\\":{\\\"latencyMs\\\":2613},\\\"p\\\":\\\"abcdefghijklmnopqrstuvwxyzABCDEF\\\",\\\"usage\\\":{\\\"inputTokens\\\":9,\\\"outputTokens\\\":97,\\\"totalTokens\\\":106}}\"}" + } +] \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/aws/input.bin b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/aws/input.bin new file mode 100644 index 000000000000..8f9d03b4f7e0 Binary files /dev/null and b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/aws/input.bin differ diff --git a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json index ba6a64384d95..f35aaf6f9dba 100644 --- a/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json +++ b/spec/fixtures/ai-proxy/unit/streaming-chunk-formats/partial-json-end/expected-output.json @@ -4,5 +4,8 @@ }, { "data": "\n{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": [\n {\n \"text\": \" not a limit.\\n\\nIf you're interested in learning more about relativity, I encourage you to 
explore further resources online or in books. There are many excellent introductory materials available. \\n\"\n }\n ],\n \"role\": \"model\"\n },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": 6,\n \"candidatesTokenCount\": 547,\n \"totalTokenCount\": 553\n }\n}\n" + }, + { + "data": "[DONE]" } ] \ No newline at end of file