diff --git a/README.md b/README.md index 467bfbbe0f6..0815b723f62 100644 --- a/README.md +++ b/README.md @@ -225,6 +225,14 @@ Baidu Secret Key. Baidu Api Url. +### `BYTEDANCE_API_KEY` (optional) + +ByteDance Api Key. + +### `BYTEDANCE_URL` (optional) + +ByteDance Api Url. + ### `HIDE_USER_API_KEY` (optional) > Default: Empty @@ -261,6 +269,9 @@ User `-all` to disable all default models, `+all` to enable all default models. For Azure: use `modelName@azure=deploymentName` to customize model name and deployment name. > Example: `+gpt-3.5-turbo@azure=gpt35` will show option `gpt35(Azure)` in model list. +For ByteDance: use `modelName@bytedance=deploymentName` to customize model name and deployment name. +> Example: `+Doubao-lite-4k@bytedance=ep-xxxxx-xxx` will show option `Doubao-lite-4k(ByteDance)` in model list. + ### `DEFAULT_MODEL` (optional) Change default model diff --git a/README_CN.md b/README_CN.md index e6c4d2011d8..321efe441dd 100644 --- a/README_CN.md +++ b/README_CN.md @@ -139,6 +139,14 @@ Baidu Secret Key. Baidu Api Url. +### `BYTEDANCE_API_KEY` (可选) + +ByteDance Api Key. + +### `BYTEDANCE_URL` (可选) + +ByteDance Api Url. + ### `HIDE_USER_API_KEY` (可选) 如果你不想让用户自行填入 API Key,将此环境变量设置为 1 即可。 @@ -172,6 +180,9 @@ Baidu Api Url. 在Azure的模式下,支持使用`modelName@azure=deploymentName`的方式配置模型名称和部署名称(deploy-name) > 示例:`+gpt-3.5-turbo@azure=gpt35`这个配置会在模型列表显示一个`gpt35(Azure)`的选项 +在ByteDance的模式下,支持使用`modelName@bytedance=deploymentName`的方式配置模型名称和部署名称(deploy-name) +> 示例: `+Doubao-lite-4k@bytedance=ep-xxxxx-xxx`这个配置会在模型列表显示一个`Doubao-lite-4k(ByteDance)`的选项 + ### `DEFAULT_MODEL` (可选) diff --git a/app/api/auth.ts b/app/api/auth.ts index cce8847f4dd..f54b33cf197 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { case ModelProvider.Claude: systemApiKey = serverConfig.anthropicApiKey; break; + case ModelProvider.Doubao: + systemApiKey = serverConfig.bytedanceApiKey; + break; case ModelProvider.Ernie: systemApiKey = serverConfig.baiduApiKey; break; diff --git a/app/api/bytedance/[...path]/route.ts b/app/api/bytedance/[...path]/route.ts new file mode 100644 index 00000000000..336c837f037 --- /dev/null +++ b/app/api/bytedance/[...path]/route.ts @@ -0,0 +1,153 @@ +import { getServerSideConfig } from "@/app/config/server"; +import { + BYTEDANCE_BASE_URL, + ApiPath, + ModelProvider, + ServiceProvider, +} from "@/app/constant"; +import { prettyObject } from "@/app/utils/format"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "@/app/api/auth"; +import { isModelAvailableInServer } from "@/app/utils/model"; + +const serverConfig = getServerSideConfig(); + +async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + console.log("[ByteDance Route] params ", params); + + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + + const authResult = auth(req, ModelProvider.Doubao); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + const response = await request(req); + return response; + } catch (e) { + console.error("[ByteDance] ", e); + return NextResponse.json(prettyObject(e)); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + +async function request(req: NextRequest) { + const controller = new AbortController(); + + let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.ByteDance, ""); + + let baseUrl = serverConfig.bytedanceUrl || BYTEDANCE_BASE_URL; + + if (!baseUrl.startsWith("http")) { + baseUrl = `https://${baseUrl}`; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, -1); + } + + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + + const fetchUrl = `${baseUrl}${path}`; + + const fetchOptions: RequestInit = { + headers: { + "Content-Type": "application/json", + Authorization: req.headers.get("Authorization") ?? "", + }, + method: req.method, + body: req.body, + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }; + + // #1815 try to refuse some request to some models + if (serverConfig.customModels && req.body) { + try { + const clonedBody = await req.text(); + fetchOptions.body = clonedBody; + + const jsonBody = JSON.parse(clonedBody) as { model?: string }; + + // not undefined and is false + if ( + isModelAvailableInServer( + serverConfig.customModels, + jsonBody?.model as string, + ServiceProvider.ByteDance as string, + ) + ) { + return NextResponse.json( + { + error: true, + message: `you are not allowed to use ${jsonBody?.model} model`, + }, + { + status: 403, + }, + ); + } + } catch (e) { + console.error(`[ByteDance] filter`, e); + } + } + + try { + const res = await fetch(fetchUrl, fetchOptions); + + // to prevent browser prompt for credentials + const newHeaders = new Headers(res.headers); + newHeaders.delete("www-authenticate"); + // to disable nginx buffering + newHeaders.set("X-Accel-Buffering", "no"); + + return new Response(res.body, { + status: res.status, + statusText: res.statusText, + headers: newHeaders, + }); + } finally { + clearTimeout(timeoutId); + } +} diff --git a/app/client/api.ts b/app/client/api.ts index f650139f906..147b11ad211 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -10,6 +10,7 @@ import { ChatGPTApi } from "./platforms/openai"; import { GeminiProApi } from "./platforms/google"; import { ClaudeApi } from "./platforms/anthropic"; import { ErnieApi } from "./platforms/baidu"; +import { DoubaoApi } from "./platforms/bytedance"; export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; @@ -109,6 +110,9 @@ export class ClientApi { case ModelProvider.Ernie: this.llm = new ErnieApi(); break; + case ModelProvider.Doubao: + this.llm = new DoubaoApi(); + break; default: this.llm = new ChatGPTApi(); } @@ -175,6 +179,8 @@ export function getHeaders() { const isGoogle = modelConfig.providerName == ServiceProvider.Google; const isAzure = modelConfig.providerName === ServiceProvider.Azure; const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic; + const isBaidu = modelConfig.providerName == ServiceProvider.Baidu; + const isByteDance = modelConfig.providerName === ServiceProvider.ByteDance; const isEnabledAccessControl = accessStore.enabledAccessControl(); const apiKey = isGoogle ? accessStore.googleApiKey @@ -182,8 +188,18 @@ export function getHeaders() { ? accessStore.azureApiKey : isAnthropic ? accessStore.anthropicApiKey + : isByteDance + ? accessStore.bytedanceApiKey : accessStore.openaiApiKey; - return { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl }; + return { + isGoogle, + isAzure, + isAnthropic, + isBaidu, + isByteDance, + apiKey, + isEnabledAccessControl, + }; } function getAuthHeader(): string { @@ -199,10 +215,18 @@ export function getHeaders() { function validString(x: string): boolean { return x?.length > 0; } - const { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl } = - getConfig(); + const { + isGoogle, + isAzure, + isAnthropic, + isBaidu, + apiKey, + isEnabledAccessControl, + } = getConfig(); // when using google api in app, not set auth header if (isGoogle && clientConfig?.isApp) return headers; + // when using baidu api in app, not set auth header + if (isBaidu && clientConfig?.isApp) return headers; const authHeader = getAuthHeader(); @@ -227,6 +251,8 @@ export function getClientApi(provider: ServiceProvider): ClientApi { return new ClientApi(ModelProvider.Claude); case ServiceProvider.Baidu: return new ClientApi(ModelProvider.Ernie); + case ServiceProvider.ByteDance: + return new ClientApi(ModelProvider.Doubao); default: return new ClientApi(ModelProvider.GPT); } diff --git a/app/client/platforms/bytedance.ts b/app/client/platforms/bytedance.ts new file mode 100644 index 00000000000..7677cafe12b --- /dev/null +++ b/app/client/platforms/bytedance.ts @@ -0,0 +1,255 @@ +"use client"; +import { + ApiPath, + ByteDance, + BYTEDANCE_BASE_URL, + REQUEST_TIMEOUT_MS, +} from "@/app/constant"; +import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; + +import { + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + MultimodalContent, +} from "../api"; +import Locale from "../../locales"; +import { + EventStreamContentType, + fetchEventSource, +} from "@fortaine/fetch-event-source"; +import { prettyObject } from "@/app/utils/format"; +import { getClientConfig } from "@/app/config/client"; +import { getMessageTextContent } from "@/app/utils"; + +export interface OpenAIListModelResponse { + object: string; + data: Array<{ + id: string; + object: string; + root: string; + }>; +} + +interface RequestPayload { + messages: { + role: "system" | "user" | "assistant"; + content: string | MultimodalContent[]; + }[]; + stream?: boolean; + model: string; + temperature: number; + presence_penalty: number; + frequency_penalty: number; + top_p: number; + max_tokens?: number; +} + +export class DoubaoApi implements LLMApi { + path(path: string): string { + const accessStore = useAccessStore.getState(); + + let baseUrl = ""; + + if (accessStore.useCustomConfig) { + baseUrl = accessStore.bytedanceUrl; + } + + if (baseUrl.length === 0) { + const isApp = !!getClientConfig()?.isApp; + baseUrl = isApp ? BYTEDANCE_BASE_URL : ApiPath.ByteDance; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.ByteDance)) { + baseUrl = "https://" + baseUrl; + } + + console.log("[Proxy Endpoint] ", baseUrl, path); + + return [baseUrl, path].join("/"); + } + + extractMessage(res: any) { + return res.choices?.at(0)?.message?.content ?? ""; + } + + async chat(options: ChatOptions) { + const messages = options.messages.map((v) => ({ + role: v.role, + content: getMessageTextContent(v), + })); + + const modelConfig = { + ...useAppConfig.getState().modelConfig, + ...useChatStore.getState().currentSession().mask.modelConfig, + ...{ + model: options.config.model, + }, + }; + + const shouldStream = !!options.config.stream; + const requestPayload: RequestPayload = { + messages, + stream: shouldStream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + }; + + const controller = new AbortController(); + options.onController?.(controller); + + try { + const chatPath = this.path(ByteDance.ChatPath); + const chatPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), + }; + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + if (shouldStream) { + let responseText = ""; + let remainText = ""; + let finished = false; + + // animate response to make it looks smooth + function animateResponseText() { + if (finished || controller.signal.aborted) { + responseText += remainText; + console.log("[Response Animation] finished"); + if (responseText?.length === 0) { + options.onError?.(new Error("empty response from server")); + } + return; + } + + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + options.onUpdate?.(responseText, fetchText); + } + + requestAnimationFrame(animateResponseText); + } + + // start animaion + animateResponseText(); + + const finish = () => { + if (!finished) { + finished = true; + options.onFinish(responseText + remainText); + } + }; + + controller.signal.onabort = finish; + + fetchEventSource(chatPath, { + ...chatPayload, + async onopen(res) { + clearTimeout(requestTimeoutId); + const contentType = res.headers.get("content-type"); + console.log( + "[ByteDance] request response content type: ", + contentType, + ); + + if (contentType?.startsWith("text/plain")) { + responseText = await res.clone().text(); + return finish(); + } + + if ( + !res.ok || + !res.headers + .get("content-type") + ?.startsWith(EventStreamContentType) || + res.status !== 200 + ) { + const responseTexts = [responseText]; + let extraInfo = await res.clone().text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + responseText = responseTexts.join("\n\n"); + + return finish(); + } + }, + onmessage(msg) { + if (msg.data === "[DONE]" || finished) { + return finish(); + } + const text = msg.data; + try { + const json = JSON.parse(text); + const choices = json.choices as Array<{ + delta: { content: string }; + }>; + const delta = choices[0]?.delta?.content; + if (delta) { + remainText += delta; + } + } catch (e) { + console.error("[Request] parse error", text, msg); + } + }, + onclose() { + finish(); + }, + onerror(e) { + options.onError?.(e); + throw e; + }, + openWhenHidden: true, + }); + } else { + const res = await fetch(chatPath, chatPayload); + clearTimeout(requestTimeoutId); + + const resJson = await res.json(); + const message = this.extractMessage(resJson); + options.onFinish(message); + } + } catch (e) { + console.log("[Request] failed to make a chat request", e); + options.onError?.(e as Error); + } + } + async usage() { + return { + used: 0, + total: 0, + }; + } + + async models(): Promise { + return []; + } +} +export { ByteDance }; diff --git a/app/components/chat.tsx b/app/components/chat.tsx index b1bdf757f44..40e02cb57ac 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -467,6 +467,14 @@ export function ChatActions(props: { return filteredModels; } }, [allModels]); + const currentModelName = useMemo(() => { + const model = models.find( + (m) => + m.name == currentModel && + m?.provider?.providerName == currentProviderName, + ); + return model?.displayName ?? ""; + }, [models, currentModel, currentProviderName]); const [showModelSelector, setShowModelSelector] = useState(false); const [showUploadImage, setShowUploadImage] = useState(false); @@ -489,7 +497,11 @@ export function ChatActions(props: { session.mask.modelConfig.providerName = nextModel?.provider ?.providerName as ServiceProvider; }); - showToast(nextModel.name); + showToast( + nextModel?.provider?.providerName == "ByteDance" + ? nextModel.displayName + : nextModel.name, + ); } }, [chatStore, currentModel, models]); @@ -571,7 +583,7 @@ export function ChatActions(props: { setShowModelSelector(true)} - text={currentModel} + text={currentModelName} icon={} /> @@ -596,7 +608,15 @@ export function ChatActions(props: { providerName as ServiceProvider; session.mask.syncGlobalConfig = false; }); - showToast(model); + if (providerName == "ByteDance") { + const selectedModel = models.find( + (m) => + m.name == model && m?.provider?.providerName == providerName, + ); + showToast(selectedModel?.displayName ?? ""); + } else { + showToast(model); + } }} /> )} diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 7db09940d4b..4d19fa76ed6 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -54,6 +54,7 @@ import { Anthropic, Azure, Baidu, + ByteDance, Google, OPENAI_BASE_URL, Path, @@ -1249,6 +1250,51 @@ export function Settings() { )} + + {accessStore.provider === ServiceProvider.ByteDance && ( + <> + + + accessStore.update( + (access) => + (access.bytedanceUrl = e.currentTarget.value), + ) + } + > + + + { + accessStore.update( + (access) => + (access.bytedanceApiKey = + e.currentTarget.value), + ); + }} + /> + + + )} )} diff --git a/app/config/server.ts b/app/config/server.ts index 2d09c547961..866b733c365 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -45,6 +45,10 @@ declare global { BAIDU_API_KEY?: string; BAIDU_SECRET_KEY?: string; + // bytedance only + BYTEDANCE_URL?: string; + BYTEDANCE_API_KEY?: string; + // custom template for preprocessing user input DEFAULT_INPUT_TEMPLATE?: string; } @@ -103,6 +107,7 @@ export const getServerSideConfig = () => { const isGoogle = !!process.env.GOOGLE_API_KEY; const isAnthropic = !!process.env.ANTHROPIC_API_KEY; const isBaidu = !!process.env.BAIDU_API_KEY; + const isBytedance = !!process.env.BYTEDANCE_API_KEY; // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? ""; // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim()); // const randomIndex = Math.floor(Math.random() * apiKeys.length); @@ -139,6 +144,10 @@ export const getServerSideConfig = () => { baiduApiKey: getApiKey(process.env.BAIDU_API_KEY), baiduSecretKey: process.env.BAIDU_SECRET_KEY, + isBytedance, + bytedanceApiKey: getApiKey(process.env.BYTEDANCE_API_KEY), + bytedanceUrl: process.env.BYTEDANCE_URL, + gtmId: process.env.GTM_ID, needCode: ACCESS_CODES.size > 0, diff --git a/app/constant.ts b/app/constant.ts index 3d48dbb6235..8a98599edc0 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -15,9 +15,10 @@ export const ANTHROPIC_BASE_URL = "https://api.anthropic.com"; export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/"; export const BAIDU_BASE_URL = "https://aip.baidubce.com"; - export const BAIDU_OATUH_URL = `${BAIDU_BASE_URL}/oauth/2.0/token`; +export const BYTEDANCE_BASE_URL = "https://ark.cn-beijing.volces.com"; + export enum Path { Home = "/", Chat = "/chat", @@ -33,6 +34,7 @@ export enum ApiPath { OpenAI = "/api/openai", Anthropic = "/api/anthropic", Baidu = "/api/baidu", + ByteDance = "/api/bytedance", } export enum SlotID { @@ -77,6 +79,7 @@ export enum ServiceProvider { Google = "Google", Anthropic = "Anthropic", Baidu = "Baidu", + ByteDance = "ByteDance", } export enum ModelProvider { @@ -84,6 +87,7 @@ export enum ModelProvider { GeminiPro = "GeminiPro", Claude = "Claude", Ernie = "Ernie", + Doubao = "Doubao", } export const Anthropic = { @@ -128,6 +132,11 @@ export const Baidu = { }, }; +export const ByteDance = { + ExampleEndpoint: "https://ark.cn-beijing.volces.com/api/", + ChatPath: "api/v3/chat/completions", +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang // export const DEFAULT_SYSTEM_TEMPLATE = ` // You are ChatGPT, a large language model trained by {{ServiceProvider}}. @@ -207,6 +216,15 @@ const baiduModels = [ "ernie-3.5-8k-0205", ]; +const bytedanceModels = [ + "Doubao-lite-4k", + "Doubao-lite-32k", + "Doubao-lite-128k", + "Doubao-pro-4k", + "Doubao-pro-32k", + "Doubao-pro-128k", +]; + export const DEFAULT_MODELS = [ ...openaiModels.map((name) => ({ name, @@ -253,6 +271,15 @@ export const DEFAULT_MODELS = [ providerType: "baidu", }, })), + ...bytedanceModels.map((name) => ({ + name, + available: true, + provider: { + id: "bytedance", + providerName: "ByteDance", + providerType: "bytedance", + }, + })), ] as const; export const CHAT_PAGE_SIZE = 15; diff --git a/app/locales/cn.ts b/app/locales/cn.ts index d7268807cb3..d605268703a 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -363,6 +363,17 @@ const cn = { SubTitle: "样例:", }, }, + ByteDance: { + ApiKey: { + Title: "接口密钥", + SubTitle: "使用自定义 ByteDance API Key", + Placeholder: "ByteDance API Key", + }, + Endpoint: { + Title: "接口地址", + SubTitle: "样例:", + }, + }, CustomModel: { Title: "自定义模型名", SubTitle: "增加自定义模型可选项,使用英文逗号隔开", diff --git a/app/locales/en.ts b/app/locales/en.ts index 3c0d8851fa4..136a5bbaccb 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -350,6 +350,17 @@ const en: LocaleType = { SubTitle: "Example:", }, }, + ByteDance: { + ApiKey: { + Title: "ByteDance API Key", + SubTitle: "Use a custom ByteDance API Key", + Placeholder: "ByteDance API Key", + }, + Endpoint: { + Title: "Endpoint Address", + SubTitle: "Example:", + }, + }, CustomModel: { Title: "Custom Models", SubTitle: "Custom model options, seperated by comma", diff --git a/app/store/access.ts b/app/store/access.ts index 7e6d01b3461..51082b33a76 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -52,6 +52,10 @@ const DEFAULT_ACCESS_STATE = { baiduApiKey: "", baiduSecretKey: "", + // bytedance + bytedanceApiKey: "", + bytedanceUrl: "", + // server config needCode: true, hideUserApiKey: false, @@ -92,6 +96,10 @@ export const useAccessStore = createPersistStore( return ensure(get(), ["baiduApiKey", "baiduSecretKey"]); }, + isValidByteDance() { + return ensure(get(), ["bytedanceApiKey"]); + }, + isAuthorized() { this.fetch(); @@ -102,6 +110,7 @@ export const useAccessStore = createPersistStore( this.isValidGoogle() || this.isValidAnthropic() || this.isValidBaidu() || + this.isValidByteDance() || !this.enabledAccessControl() || (this.enabledAccessControl() && ensure(get(), ["accessCode"])) ); diff --git a/app/utils/model.ts b/app/utils/model.ts index 0b160f1013b..a3a014877a9 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -39,7 +39,7 @@ export function collectModelTable( const available = !m.startsWith("-"); const nameConfig = m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m; - const [name, displayName] = nameConfig.split("="); + let [name, displayName] = nameConfig.split("="); // enable or disable all models if (name === "all") { @@ -59,6 +59,11 @@ export function collectModelTable( ) { count += 1; modelTable[fullName]["available"] = available; + // swap name and displayName for bytedance + if (providerName === "bytedance") { + [name, displayName] = [displayName, name]; + modelTable[fullName]["name"] = name; + } if (displayName) { modelTable[fullName]["displayName"] = displayName; }