Skip to content

Commit

Permalink
claude support function call
Browse files Browse the repository at this point in the history
  • Loading branch information
lloydzhou committed Sep 2, 2024
1 parent 877668b commit 801b625
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 113 deletions.
1 change: 1 addition & 0 deletions app/api/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
console.log("[Auth] hashed access code:", hashedCode);
console.log("[User IP] ", getIP(req));
console.log("[Time] ", new Date().toLocaleString());
console.log("[ModelProvider] ", modelProvider);

if (serverConfig.needCode && !serverConfig.codes.has(hashedCode) && !apiKey) {
return {
Expand Down
228 changes: 123 additions & 105 deletions app/client/platforms/anthropic.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import { ACCESS_CODE_PREFIX, Anthropic, ApiPath } from "@/app/constant";
import { ChatOptions, getHeaders, LLMApi, MultimodalContent } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import {
useAccessStore,
useAppConfig,
useChatStore,
usePluginStore,
ChatMessageTool,
} from "@/app/store";
import { getClientConfig } from "@/app/config/client";
import { DEFAULT_API_HOST } from "@/app/constant";
import {
Expand All @@ -11,8 +17,9 @@ import {
import Locale from "../../locales";
import { prettyObject } from "@/app/utils/format";
import { getMessageTextContent, isVisionModel } from "@/app/utils";
import { preProcessImageContent } from "@/app/utils/chat";
import { preProcessImageContent, stream } from "@/app/utils/chat";
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
import { RequestPayload } from "./openai";

export type MultiBlockContent = {
type: "image" | "text";
Expand Down Expand Up @@ -191,112 +198,123 @@ export class ClaudeApi implements LLMApi {
const controller = new AbortController();
options.onController?.(controller);

const payload = {
method: "POST",
body: JSON.stringify(requestBody),
signal: controller.signal,
headers: {
...getHeaders(), // get common headers
"anthropic-version": accessStore.anthropicApiVersion,
// do not send `anthropicApiKey` in browser!!!
// Authorization: getAuthKey(accessStore.anthropicApiKey),
},
};

if (shouldStream) {
try {
const context = {
text: "",
finished: false,
};

const finish = () => {
if (!context.finished) {
options.onFinish(context.text);
context.finished = true;
}
};

controller.signal.onabort = finish;
fetchEventSource(path, {
...payload,
async onopen(res) {
const contentType = res.headers.get("content-type");
console.log("response content type: ", contentType);

if (contentType?.startsWith("text/plain")) {
context.text = await res.clone().text();
return finish();
}

if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [context.text];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}

if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}

if (extraInfo) {
responseTexts.push(extraInfo);
}

context.text = responseTexts.join("\n\n");

return finish();
}
},
onmessage(msg) {
let chunkJson:
| undefined
| {
type: "content_block_delta" | "content_block_stop";
delta?: {
type: "text_delta";
text: string;
};
index: number;
let index = -1;
const [tools, funcs] = usePluginStore
.getState()
.getAsTools(
useChatStore.getState().currentSession().mask?.plugin as string[],
);
console.log("getAsTools", tools, funcs);
return stream(
path,
requestBody,
{
...getHeaders(),
"anthropic-version": accessStore.anthropicApiVersion,
},
// @ts-ignore
tools.map((tool) => ({
name: tool?.function?.name,
description: tool?.function?.description,
input_schema: tool?.function?.parameters,
})),
funcs,
controller,
// parseSSE
(text: string, runTools: ChatMessageTool[]) => {
// console.log("parseSSE", text, runTools);
let chunkJson:
| undefined
| {
type: "content_block_delta" | "content_block_stop";
content_block?: {
type: "tool_use";
id: string;
name: string;
};
try {
chunkJson = JSON.parse(msg.data);
} catch (e) {
console.error("[Response] parse error", msg.data);
}

if (!chunkJson || chunkJson.type === "content_block_stop") {
return finish();
}

const { delta } = chunkJson;
if (delta?.text) {
context.text += delta.text;
options.onUpdate?.(context.text, delta.text);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
},
openWhenHidden: true,
});
} catch (e) {
console.error("failed to chat", e);
options.onError?.(e as Error);
}
delta?: {
type: "text_delta" | "input_json_delta";
text?: string;
partial_json?: string;
};
index: number;
};
chunkJson = JSON.parse(text);

if (chunkJson?.content_block?.type == "tool_use") {
index += 1;
const id = chunkJson?.content_block.id;
const name = chunkJson?.content_block.name;
runTools.push({
id,
type: "function",
function: {
name,
arguments: "",
},
});
}
if (
chunkJson?.delta?.type == "input_json_delta" &&
chunkJson?.delta?.partial_json
) {
// @ts-ignore
runTools[index]["function"]["arguments"] +=
chunkJson?.delta?.partial_json;
}
return chunkJson?.delta?.text;
},
// processToolMessage, include tool_calls message and tool call results
(
requestPayload: RequestPayload,
toolCallMessage: any,
toolCallResult: any[],
) => {
// @ts-ignore
requestPayload?.messages?.splice(
// @ts-ignore
requestPayload?.messages?.length,
0,
{
role: "assistant",
content: toolCallMessage.tool_calls.map(
(tool: ChatMessageTool) => ({
type: "tool_use",
id: tool.id,
name: tool?.function?.name,
input: JSON.parse(tool?.function?.arguments as string),
}),
),
},
// @ts-ignore
...toolCallResult.map((result) => ({
role: "user",
content: [
{
type: "tool_result",
tool_use_id: result.tool_call_id,
content: result.content,
},
],
})),
);
},
options,
);
} else {
const payload = {
method: "POST",
body: JSON.stringify(requestBody),
signal: controller.signal,
headers: {
...getHeaders(), // get common headers
"anthropic-version": accessStore.anthropicApiVersion,
// do not send `anthropicApiKey` in browser!!!
// Authorization: getAuthKey(accessStore.anthropicApiKey),
},
};

try {
controller.signal.onabort = () => options.onFinish("");

Expand Down
2 changes: 1 addition & 1 deletion app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ export class ChatGPTApi implements LLMApi {
.getAsTools(
useChatStore.getState().currentSession().mask?.plugin as string[],
);
console.log("getAsTools", tools, funcs);
// console.log("getAsTools", tools, funcs);
stream(
chatPath,
requestPayload,
Expand Down
15 changes: 9 additions & 6 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import {
getMessageImages,
isVisionModel,
isDalle3,
showPlugins,
} from "../utils";

import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
Expand Down Expand Up @@ -741,12 +742,14 @@ export function ChatActions(props: {
value: ArtifactsPlugin.Artifacts as string,
},
].concat(
pluginStore.getAll().map((item) => ({
// @ts-ignore
title: `${item?.title}@${item?.version}`,
// @ts-ignore
value: item?.id,
})),
showPlugins(currentProviderName, currentModel)
? pluginStore.getAll().map((item) => ({
// @ts-ignore
title: `${item?.title}@${item?.version}`,
// @ts-ignore
value: item?.id,
}))
: [],
)}
onClose={() => setShowPluginSelector(false)}
onSelection={(s) => {
Expand Down
11 changes: 11 additions & 0 deletions app/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { useEffect, useState } from "react";
import { showToast } from "./components/ui-lib";
import Locale from "./locales";
import { RequestMessage } from "./client/api";
import { ServiceProvider } from "./constant";

export function trimTopic(topic: string) {
// Fix an issue where double quotes still show in the Indonesian language
Expand Down Expand Up @@ -270,3 +271,13 @@ export function isVisionModel(model: string) {
export function isDalle3(model: string) {
return "dall-e-3" === model;
}

export function showPlugins(provider: ServiceProvider, model: string) {
if (provider == ServiceProvider.OpenAI || provider == ServiceProvider.Azure) {
return true;
}
if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) {
return true;
}
return false;
}
2 changes: 1 addition & 1 deletion app/utils/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ export function stream(
remainText += chunk;
}
} catch (e) {
console.error("[Request] parse error", text, msg);
console.error("[Request] parse error", text, msg, e);
}
},
onclose() {
Expand Down

0 comments on commit 801b625

Please sign in to comment.