Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support baidu model #4936

Merged
merged 11 commits into from
Jul 9, 2024
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ dev
.env

*.key
*.key.pub
*.key.pub
3 changes: 3 additions & 0 deletions app/api/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
case ModelProvider.Claude:
systemApiKey = serverConfig.anthropicApiKey;
break;
case ModelProvider.Ernie:
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
systemApiKey = serverConfig.baiduApiKey;
break;
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
case ModelProvider.GPT:
default:
if (req.nextUrl.pathname.includes("azure/deployments")) {
Expand Down
176 changes: 176 additions & 0 deletions app/api/baidu/[...path]/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import { getServerSideConfig } from "@/app/config/server";
import {
BAIDU_BASE_URL,
ApiPath,
ModelProvider,
BAIDU_OATUH_URL,
ServiceProvider,
} from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server";
import { auth } from "@/app/api/auth";
import { isModelAvailableInServer } from "@/app/utils/model";

const serverConfig = getServerSideConfig();

async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
console.log("[Baidu Route] params ", params);

if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}

const authResult = auth(req, ModelProvider.Ernie);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}

try {
const response = await request(req);
return response;
} catch (e) {
console.error("[Baidu] ", e);
return NextResponse.json(prettyObject(e));
}
}

export const GET = handle;
export const POST = handle;

export const runtime = "edge";
export const preferredRegion = [
"arn1",
"bom1",
"cdg1",
"cle1",
"cpt1",
"dub1",
"fra1",
"gru1",
"hnd1",
"iad1",
"icn1",
"kix1",
"lhr1",
"pdx1",
"sfo1",
"sin1",
"syd1",
];

async function request(req: NextRequest) {
const controller = new AbortController();

let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Baidu, "");

let baseUrl = serverConfig.baiduUrl || BAIDU_BASE_URL;

if (!baseUrl.startsWith("http")) {
baseUrl = `https://${baseUrl}`;
}

if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, -1);
}

console.log("[Proxy] ", path);
console.log("[Base Url]", baseUrl);

const timeoutId = setTimeout(
() => {
controller.abort();
},
10 * 60 * 1000,
);

const { access_token } = await getAccessToken();
const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`;

const fetchOptions: RequestInit = {
headers: {
"Content-Type": "application/json",
},
method: req.method,
body: req.body,
redirect: "manual",
// @ts-ignore
duplex: "half",
signal: controller.signal,
};

// #1815 try to refuse some request to some models
if (serverConfig.customModels && req.body) {
try {
const clonedBody = await req.text();
fetchOptions.body = clonedBody;

const jsonBody = JSON.parse(clonedBody) as { model?: string };

// not undefined and is false
if (
isModelAvailableInServer(
serverConfig.customModels,
jsonBody?.model as string,
ServiceProvider.Baidu as string,
)
) {
return NextResponse.json(
{
error: true,
message: `you are not allowed to use ${jsonBody?.model} model`,
},
{
status: 403,
},
);
}
} catch (e) {
console.error(`[Baidu] filter`, e);
}
}
console.log("[Baidu request]", fetchOptions.headers, req.method);
try {
const res = await fetch(fetchUrl, fetchOptions);

console.log("[Baidu response]", res.status, " ", res.headers, res.url);
// to prevent browser prompt for credentials
const newHeaders = new Headers(res.headers);
newHeaders.delete("www-authenticate");
// to disable nginx buffering
newHeaders.set("X-Accel-Buffering", "no");

return new Response(res.body, {
status: res.status,
statusText: res.statusText,
headers: newHeaders,
});
} finally {
clearTimeout(timeoutId);
}
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* 使用 AK,SK 生成鉴权签名(Access Token)
* @return 鉴权签名信息
*/
async function getAccessToken(): Promise<{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数需要放一个公共的位置:

  1. 如果是app端运行,或者用户直接在前端配置了百度的密钥,这个时候,需要在前端生成access_token,直接发送给配置的地址。
  2. 如果用户没有在前端配置自己的密钥,这个时候,发到/api/baidu/的请求是不带access_token的,这个时候,需要node server生成这个并放到请求里面。

access_token: string;
expires_in: number;
error?: number;
}> {
const AK = serverConfig.baiduApiKey;
const SK = serverConfig.baiduSecretKey;
const res = await fetch(
`${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`,
{
method: "POST",
},
);
const resJson = await res.json();
return resJson;
lloydzhou marked this conversation as resolved.
Show resolved Hide resolved
}
5 changes: 5 additions & 0 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
import { ChatGPTApi } from "./platforms/openai";
import { GeminiProApi } from "./platforms/google";
import { ClaudeApi } from "./platforms/anthropic";
import { ErnieApi } from "./platforms/baidu";

export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number];

Expand Down Expand Up @@ -104,6 +106,9 @@ export class ClientApi {
case ModelProvider.Claude:
this.llm = new ClaudeApi();
break;
case ModelProvider.Ernie:
this.llm = new ErnieApi();
break;
default:
this.llm = new ChatGPTApi();
}
Expand Down
Loading
Loading