From c66e1729c13957ae493a75e0b890e0b2ba15ef3d Mon Sep 17 00:00:00 2001 From: Yuen Sze Hong <40477634+YuenSzeHong@users.noreply.github.com> Date: Mon, 13 Jan 2025 08:08:34 +0000 Subject: [PATCH] add function call support --- src/api_proxy/worker.mjs | 100 ++++++++++++++++++++++++++++++++------- 1 file changed, 83 insertions(+), 17 deletions(-) diff --git a/src/api_proxy/worker.mjs b/src/api_proxy/worker.mjs index a6c6406..38fcdd6 100644 --- a/src/api_proxy/worker.mjs +++ b/src/api_proxy/worker.mjs @@ -334,11 +334,44 @@ const transformMessages = async (messages) => { return { system_instruction, contents }; }; -const transformRequest = async (req) => ({ - ...await transformMessages(req.messages), - safetySettings: safetySettings(req.model), - generationConfig: transformConfig(req), -}); +const transformRequest = async (req) => { + const base = { + ...await transformMessages(req.messages), + safetySettings: safetySettings(req.model), + generationConfig: transformConfig(req), + }; + + // Handle functions/tools if present + const tools = req.functions || req.tools; + if (tools) { + // Convert to array if it's a single object + const toolsArray = tools.length ? tools : [tools]; + + base.tools = [{ + functionDeclarations: toolsArray.map(tool => { + // Handle both formats: {type: 'function', function: {name: 'x'}} and {name: 'x'} + const funcDef = tool.type === 'function' ? tool.function : tool; + + return { + name: funcDef.name, + description: funcDef.description || "", + parameters: { + type: "object", + properties: { + query: { + type: "string", + description: "The search query to be executed" + } + }, + required: ["query"] + } + }; + }) + }]; + } + + return base; +}; const generateChatcmplId = () => { const characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; @@ -356,15 +389,28 @@ const reasonsMap = { //https://ai.google.dev/api/rest/v1/GenerateContentResponse // :"function_call", }; const SEP = "\n\n|>"; -const transformCandidates = (key, cand) => ({ - index: cand.index || 0, // 0-index is absent in new -002 models response - [key]: { - role: "assistant", - content: cand.content?.parts.map(p => p.text).join(SEP) - }, - logprobs: null, - finish_reason: reasonsMap[cand.finishReason] || cand.finishReason, -}); +const transformCandidates = (key, cand) => { + const base = { + index: cand.index || 0, + [key]: { + role: "assistant", + content: cand.content?.parts?.map(p => p.text).join(SEP) || null + }, + logprobs: null, + finish_reason: reasonsMap[cand.finishReason] || cand.finishReason + }; + + // Add function_call if present + if (cand.content?.parts?.some(p => p.functionCall)) { + const functionCall = cand.content.parts.find(p => p.functionCall).functionCall; + base[key].function_call = { + name: functionCall.name, + arguments: functionCall.args ? JSON.stringify(functionCall.args) : "{}" + }; + } + + return base; +}; const transformCandidatesMessage = transformCandidates.bind(null, "message"); const transformCandidatesDelta = transformCandidates.bind(null, "delta"); @@ -407,21 +453,41 @@ async function parseStreamFlush(controller) { function transformResponseStream(data, stop, first) { const item = transformCandidatesDelta(data.candidates[0]); - if (stop) { item.delta = {}; } else { item.finish_reason = null; } - if (first) { item.delta.content = ""; } else { delete item.delta.role; } + if (stop) { + item.delta = {}; + } else { + item.finish_reason = null; + } + + if (first) { + item.delta.content = ""; + // Add function_call structure if present + if (data.candidates[0].content?.parts?.some(p => p.functionCall)) { + const functionCall = data.candidates[0].content.parts.find(p => p.functionCall).functionCall; + item.delta.function_call = { + name: functionCall.name, + arguments: functionCall.args ? JSON.stringify(functionCall.args) : "{}" + }; + } + } else { + delete item.delta.role; + } + const output = { id: this.id, choices: [item], created: Math.floor(Date.now() / 1000), model: this.model, - //system_fingerprint: "fp_69829325d0", object: "chat.completion.chunk", }; + if (data.usageMetadata && this.streamIncludeUsage) { output.usage = stop ? transformUsage(data.usageMetadata) : null; } + return "data: " + JSON.stringify(output) + delimiter; } + const delimiter = "\n\n"; async function toOpenAiStream(chunk, controller) { const transform = transformResponseStream.bind(this);