From 3369e3be3eb3f32b867fe2ebb0e75482761fe481 Mon Sep 17 00:00:00 2001 From: Yuen Sze Hong <40477634+YuenSzeHong@users.noreply.github.com> Date: Thu, 9 Jan 2025 04:13:13 +0000 Subject: [PATCH] =?UTF-8?q?Fix:=20Flash=202-0=20doesn=E2=80=99t=20respect?= =?UTF-8?q?=20BLOCK=5FNONE=20on=20ALL=20harm=20categories?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api_proxy/worker.mjs | 48 ++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/src/api_proxy/worker.mjs b/src/api_proxy/worker.mjs index dfce6f3..f4c259f 100644 --- a/src/api_proxy/worker.mjs +++ b/src/api_proxy/worker.mjs @@ -5,7 +5,7 @@ import { Buffer } from "node:buffer"; export default { - async fetch (request) { + async fetch(request) { if (request.method === "OPTIONS") { return handleOPTIONS(); } @@ -79,7 +79,7 @@ const makeHeaders = (apiKey, more) => ({ ...more }); -async function handleModels (apiKey) { +async function handleModels(apiKey) { const response = await fetch(`${BASE_URL}/${API_VERSION}/models`, { headers: makeHeaders(apiKey), }); @@ -100,12 +100,12 @@ async function handleModels (apiKey) { } const DEFAULT_EMBEDDINGS_MODEL = "text-embedding-004"; -async function handleEmbeddings (req, apiKey) { +async function handleEmbeddings(req, apiKey) { if (typeof req.model !== "string") { throw new HttpError("model is not specified", 400); } if (!Array.isArray(req.input)) { - req.input = [ req.input ]; + req.input = [req.input]; } let model; if (req.model.startsWith("models/")) { @@ -142,9 +142,9 @@ async function handleEmbeddings (req, apiKey) { } const DEFAULT_MODEL = "gemini-1.5-pro-latest"; -async function handleCompletions (req, apiKey) { +async function handleCompletions(req, apiKey) { let model = DEFAULT_MODEL; - switch(true) { + switch (true) { case typeof req.model !== "string": break; case req.model.startsWith("models/"): @@ -196,10 +196,15 @@ const harmCategory = [ "HARM_CATEGORY_HARASSMENT", 
"HARM_CATEGORY_CIVIC_INTEGRITY", ]; -const safetySettings = harmCategory.map(category => ({ - category, - threshold: "BLOCK_NONE", -})); + +const safetySettings = (model) => { + const threshold = model?.includes('2.0') ? 'OFF' : 'BLOCK_NONE'; + return harmCategory.map(category => ({ + category, + threshold + })); +} + const fieldsMap = { stop: "stopSequences", n: "candidateCount", // not for streaming @@ -221,14 +226,14 @@ const transformConfig = (req) => { } } if (req.response_format) { - switch(req.response_format.type) { + switch (req.response_format.type) { case "json_schema": cfg.responseSchema = req.response_format.json_schema?.schema; if (cfg.responseSchema && "enum" in cfg.responseSchema) { cfg.responseMimeType = "text/x.enum"; break; } - // eslint-disable-next-line no-fallthrough + // eslint-disable-next-line no-fallthrough case "json_object": cfg.responseMimeType = "application/json"; break; @@ -330,7 +335,7 @@ const transformMessages = async (messages) => { const transformRequest = async (req) => ({ ...await transformMessages(req.messages), - safetySettings, + safetySettings: safetySettings(req.model), generationConfig: transformConfig(req), }); @@ -354,7 +359,8 @@ const transformCandidates = (key, cand) => ({ index: cand.index || 0, // 0-index is absent in new -002 models response [key]: { role: "assistant", - content: cand.content?.parts.map(p => p.text).join(SEP) }, + content: cand.content?.parts.map(p => p.text).join(SEP) + }, logprobs: null, finish_reason: reasonsMap[cand.finishReason] || cand.finishReason, }); @@ -371,7 +377,7 @@ const processCompletionsResponse = (data, model, id) => { return JSON.stringify({ id, choices: data.candidates.map(transformCandidatesMessage), - created: Math.floor(Date.now()/1000), + created: Math.floor(Date.now() / 1000), model, //system_fingerprint: "fp_69829325d0", object: "chat.completion", @@ -380,7 +386,7 @@ const processCompletionsResponse = (data, model, id) => { }; const responseLineRE = /^data: 
(.*)(?:\n\n|\r\r|\r\n\r\n)/; -async function parseStream (chunk, controller) { +async function parseStream(chunk, controller) { chunk = await chunk; if (!chunk) { return; } this.buffer += chunk; @@ -391,21 +397,21 @@ async function parseStream (chunk, controller) { this.buffer = this.buffer.substring(match[0].length); } while (true); // eslint-disable-line no-constant-condition } -async function parseStreamFlush (controller) { +async function parseStreamFlush(controller) { if (this.buffer) { console.error("Invalid data:", this.buffer); controller.enqueue(this.buffer); } } -function transformResponseStream (data, stop, first) { +function transformResponseStream(data, stop, first) { const item = transformCandidatesDelta(data.candidates[0]); if (stop) { item.delta = {}; } else { item.finish_reason = null; } if (first) { item.delta.content = ""; } else { delete item.delta.role; } const output = { id: this.id, choices: [item], - created: Math.floor(Date.now()/1000), + created: Math.floor(Date.now() / 1000), model: this.model, //system_fingerprint: "fp_69829325d0", object: "chat.completion.chunk", @@ -416,7 +422,7 @@ function transformResponseStream (data, stop, first) { return "data: " + JSON.stringify(output) + delimiter; } const delimiter = "\n\n"; -async function toOpenAiStream (chunk, controller) { +async function toOpenAiStream(chunk, controller) { const transform = transformResponseStream.bind(this); const line = await chunk; if (!line) { return; } @@ -445,7 +451,7 @@ async function toOpenAiStream (chunk, controller) { controller.enqueue(transform(data)); } } -async function toOpenAiStreamFlush (controller) { +async function toOpenAiStreamFlush(controller) { const transform = transformResponseStream.bind(this); if (this.last.length > 0) { for (const data of this.last) {