From 801b62543a909e2cf6c71d635e49f839012723fa Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Mon, 2 Sep 2024 21:45:47 +0800 Subject: [PATCH] claude support function call --- app/api/auth.ts | 1 + app/client/platforms/anthropic.ts | 228 ++++++++++++++++-------------- app/client/platforms/openai.ts | 2 +- app/components/chat.tsx | 15 +- app/utils.ts | 11 ++ app/utils/chat.ts | 2 +- 6 files changed, 146 insertions(+), 113 deletions(-) diff --git a/app/api/auth.ts b/app/api/auth.ts index 95965ceec2d..c8fa6787d9a 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -38,6 +38,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { console.log("[Auth] hashed access code:", hashedCode); console.log("[User IP] ", getIP(req)); console.log("[Time] ", new Date().toLocaleString()); + console.log("[ModelProvider] ", modelProvider); if (serverConfig.needCode && !serverConfig.codes.has(hashedCode) && !apiKey) { return { diff --git a/app/client/platforms/anthropic.ts b/app/client/platforms/anthropic.ts index b079ba1ada2..91434ffcc06 100644 --- a/app/client/platforms/anthropic.ts +++ b/app/client/platforms/anthropic.ts @@ -1,6 +1,12 @@ import { ACCESS_CODE_PREFIX, Anthropic, ApiPath } from "@/app/constant"; import { ChatOptions, getHeaders, LLMApi, MultimodalContent } from "../api"; -import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; +import { + useAccessStore, + useAppConfig, + useChatStore, + usePluginStore, + ChatMessageTool, +} from "@/app/store"; import { getClientConfig } from "@/app/config/client"; import { DEFAULT_API_HOST } from "@/app/constant"; import { @@ -11,8 +17,9 @@ import { import Locale from "../../locales"; import { prettyObject } from "@/app/utils/format"; import { getMessageTextContent, isVisionModel } from "@/app/utils"; -import { preProcessImageContent } from "@/app/utils/chat"; +import { preProcessImageContent, stream } from "@/app/utils/chat"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; +import { RequestPayload } from "./openai"; export type MultiBlockContent = { type: "image" | "text"; @@ -191,112 +198,123 @@ export class ClaudeApi implements LLMApi { const controller = new AbortController(); options.onController?.(controller); - const payload = { - method: "POST", - body: JSON.stringify(requestBody), - signal: controller.signal, - headers: { - ...getHeaders(), // get common headers - "anthropic-version": accessStore.anthropicApiVersion, - // do not send `anthropicApiKey` in browser!!! - // Authorization: getAuthKey(accessStore.anthropicApiKey), - }, - }; - if (shouldStream) { - try { - const context = { - text: "", - finished: false, - }; - - const finish = () => { - if (!context.finished) { - options.onFinish(context.text); - context.finished = true; - } - }; - - controller.signal.onabort = finish; - fetchEventSource(path, { - ...payload, - async onopen(res) { - const contentType = res.headers.get("content-type"); - console.log("response content type: ", contentType); - - if (contentType?.startsWith("text/plain")) { - context.text = await res.clone().text(); - return finish(); - } - - if ( - !res.ok || - !res.headers - .get("content-type") - ?.startsWith(EventStreamContentType) || - res.status !== 200 - ) { - const responseTexts = [context.text]; - let extraInfo = await res.clone().text(); - try { - const resJson = await res.clone().json(); - extraInfo = prettyObject(resJson); - } catch {} - - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - - if (extraInfo) { - responseTexts.push(extraInfo); - } - - context.text = responseTexts.join("\n\n"); - - return finish(); - } - }, - onmessage(msg) { - let chunkJson: - | undefined - | { - type: "content_block_delta" | "content_block_stop"; - delta?: { - type: "text_delta"; - text: string; - }; - index: number; + let index = -1; + const [tools, funcs] = usePluginStore + .getState() + .getAsTools( + useChatStore.getState().currentSession().mask?.plugin as string[], + ); + console.log("getAsTools", tools, funcs); + return stream( + path, + requestBody, + { + ...getHeaders(), + "anthropic-version": accessStore.anthropicApiVersion, + }, + // @ts-ignore + tools.map((tool) => ({ + name: tool?.function?.name, + description: tool?.function?.description, + input_schema: tool?.function?.parameters, + })), + funcs, + controller, + // parseSSE + (text: string, runTools: ChatMessageTool[]) => { + // console.log("parseSSE", text, runTools); + let chunkJson: + | undefined + | { + type: "content_block_delta" | "content_block_stop"; + content_block?: { + type: "tool_use"; + id: string; + name: string; }; - try { - chunkJson = JSON.parse(msg.data); - } catch (e) { - console.error("[Response] parse error", msg.data); - } - - if (!chunkJson || chunkJson.type === "content_block_stop") { - return finish(); - } - - const { delta } = chunkJson; - if (delta?.text) { - context.text += delta.text; - options.onUpdate?.(context.text, delta.text); - } - }, - onclose() { - finish(); - }, - onerror(e) { - options.onError?.(e); - throw e; - }, - openWhenHidden: true, - }); - } catch (e) { - console.error("failed to chat", e); - options.onError?.(e as Error); - } + delta?: { + type: "text_delta" | "input_json_delta"; + text?: string; + partial_json?: string; + }; + index: number; + }; + chunkJson = JSON.parse(text); + + if (chunkJson?.content_block?.type == "tool_use") { + index += 1; + const id = chunkJson?.content_block.id; + const name = chunkJson?.content_block.name; + runTools.push({ + id, + type: "function", + function: { + name, + arguments: "", + }, + }); + } + if ( + chunkJson?.delta?.type == "input_json_delta" && + chunkJson?.delta?.partial_json + ) { + // @ts-ignore + runTools[index]["function"]["arguments"] += + chunkJson?.delta?.partial_json; + } + return chunkJson?.delta?.text; + }, + // processToolMessage, include tool_calls message and tool call results + ( + requestPayload: RequestPayload, + toolCallMessage: any, + toolCallResult: any[], + ) => { + // @ts-ignore + requestPayload?.messages?.splice( + // @ts-ignore + requestPayload?.messages?.length, + 0, + { + role: "assistant", + content: toolCallMessage.tool_calls.map( + (tool: ChatMessageTool) => ({ + type: "tool_use", + id: tool.id, + name: tool?.function?.name, + input: JSON.parse(tool?.function?.arguments as string), + }), + ), + }, + // @ts-ignore + ...toolCallResult.map((result) => ({ + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: result.tool_call_id, + content: result.content, + }, + ], + })), + ); + }, + options, + ); } else { + const payload = { + method: "POST", + body: JSON.stringify(requestBody), + signal: controller.signal, + headers: { + ...getHeaders(), // get common headers + "anthropic-version": accessStore.anthropicApiVersion, + // do not send `anthropicApiKey` in browser!!! + // Authorization: getAuthKey(accessStore.anthropicApiKey), + }, + }; + try { controller.signal.onabort = () => options.onFinish(""); diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 4c5831fe3e9..b3b306d1d11 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -246,7 +246,7 @@ export class ChatGPTApi implements LLMApi { .getAsTools( useChatStore.getState().currentSession().mask?.plugin as string[], ); - console.log("getAsTools", tools, funcs); + // console.log("getAsTools", tools, funcs); stream( chatPath, requestPayload, diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 7bac62bc4c3..7d180f0b739 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -66,6 +66,7 @@ import { getMessageImages, isVisionModel, isDalle3, + showPlugins, } from "../utils"; import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; @@ -741,12 +742,14 @@ export function ChatActions(props: { value: ArtifactsPlugin.Artifacts as string, }, ].concat( - pluginStore.getAll().map((item) => ({ - // @ts-ignore - title: `${item?.title}@${item?.version}`, - // @ts-ignore - value: item?.id, - })), + showPlugins(currentProviderName, currentModel) + ? pluginStore.getAll().map((item) => ({ + // @ts-ignore + title: `${item?.title}@${item?.version}`, + // @ts-ignore + value: item?.id, + })) + : [], )} onClose={() => setShowPluginSelector(false)} onSelection={(s) => { diff --git a/app/utils.ts b/app/utils.ts index 2a292290755..b9884c70644 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -2,6 +2,7 @@ import { useEffect, useState } from "react"; import { showToast } from "./components/ui-lib"; import Locale from "./locales"; import { RequestMessage } from "./client/api"; +import { ServiceProvider } from "./constant"; export function trimTopic(topic: string) { // Fix an issue where double quotes still show in the Indonesian language @@ -270,3 +271,13 @@ export function isVisionModel(model: string) { export function isDalle3(model: string) { return "dall-e-3" === model; } + +export function showPlugins(provider: ServiceProvider, model: string) { + if (provider == ServiceProvider.OpenAI || provider == ServiceProvider.Azure) { + return true; + } + if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) { + return true; + } + return false; +} diff --git a/app/utils/chat.ts b/app/utils/chat.ts index d8ab5770cc2..49e5060d432 100644 --- a/app/utils/chat.ts +++ b/app/utils/chat.ts @@ -334,7 +334,7 @@ export function stream( remainText += chunk; } } catch (e) { - console.error("[Request] parse error", text, msg); + console.error("[Request] parse error", text, msg, e); } }, onclose() {