-
Notifications
You must be signed in to change notification settings - Fork 60k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Implement Vertex AI support for Anthropic and Google models #5794
base: main
Are you sure you want to change the base?
Changes from all commits
5222c2b
6856061
bc2f36e
c30fd63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ import { NextRequest, NextResponse } from "next/server"; | |
import { auth } from "./auth"; | ||
import { isModelAvailableInServer } from "@/app/utils/model"; | ||
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; | ||
import { getGCloudToken } from "./common"; | ||
|
||
const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]); | ||
|
||
|
@@ -67,10 +68,20 @@ async function request(req: NextRequest) { | |
serverConfig.anthropicApiKey || | ||
""; | ||
|
||
let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Anthropic, ""); | ||
// adjust header and url when using vertex ai | ||
if (serverConfig.isVertexAI) { | ||
authHeaderName = "Authorization"; | ||
const gCloudToken = await getGCloudToken(); | ||
authValue = `Bearer ${gCloudToken}`; | ||
} | ||
|
||
let path = serverConfig.vertexAIUrl | ||
? "" | ||
: `${req.nextUrl.pathname}`.replaceAll(ApiPath.Anthropic, ""); | ||
|
||
let baseUrl = | ||
serverConfig.anthropicUrl || serverConfig.baseUrl || ANTHROPIC_BASE_URL; | ||
let baseUrl = serverConfig.vertexAIUrl | ||
? serverConfig.vertexAIUrl | ||
: serverConfig.anthropicUrl || serverConfig.baseUrl || ANTHROPIC_BASE_URL; | ||
|
||
if (!baseUrl.startsWith("http")) { | ||
baseUrl = `https://${baseUrl}`; | ||
|
@@ -112,13 +123,16 @@ async function request(req: NextRequest) { | |
signal: controller.signal, | ||
}; | ||
|
||
// #1815 try to refuse some request to some models | ||
// #1815 try to refuse some request to some models or tick json body for vertex ai | ||
if (serverConfig.customModels && req.body) { | ||
try { | ||
const clonedBody = await req.text(); | ||
fetchOptions.body = clonedBody; | ||
|
||
const jsonBody = JSON.parse(clonedBody) as { model?: string }; | ||
const jsonBody = JSON.parse(clonedBody) as { | ||
model?: string; | ||
anthropic_version?: string; | ||
}; | ||
Comment on lines
+132
to
+135
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle JSON parsing errors appropriately If the request body cannot be parsed as JSON, the code logs the error but continues processing with the original body. This could lead to unexpected behavior downstream. Consider returning an error response when JSON parsing fails to ensure the request is handled correctly. Apply this diff to return an error response when JSON parsing fails: } catch (e) {
console.error(`[Anthropic] filter`, e);
+ return NextResponse.json(
+ { error: true, message: "Invalid JSON in request body" },
+ { status: 400 },
+ );
}
|
||
|
||
// not undefined and is false | ||
if ( | ||
|
@@ -138,6 +152,14 @@ async function request(req: NextRequest) { | |
}, | ||
); | ||
} | ||
|
||
// tick json body for vertex ai and update fetch options | ||
if (serverConfig.isVertexAI) { | ||
delete jsonBody.model; | ||
jsonBody.anthropic_version = | ||
serverConfig.anthropicApiVersion || "vertex-2023-10-16"; | ||
fetchOptions.body = JSON.stringify(jsonBody); | ||
} | ||
} catch (e) { | ||
console.error(`[Anthropic] filter`, e); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from "next/server"; | |
import { getServerSideConfig } from "../config/server"; | ||
import { OPENAI_BASE_URL, ServiceProvider } from "../constant"; | ||
import { cloudflareAIGatewayUrl } from "../utils/cloudflare"; | ||
import { GoogleToken } from "../utils/gtoken"; | ||
import { getModelProvider, isModelAvailableInServer } from "../utils/model"; | ||
|
||
const serverConfig = getServerSideConfig(); | ||
|
@@ -185,3 +186,25 @@ export async function requestOpenai(req: NextRequest) { | |
clearTimeout(timeoutId); | ||
} | ||
} | ||
|
||
let gTokenClient: GoogleToken | undefined; | ||
|
||
/** | ||
* Get access token for google cloud, | ||
* requires GOOGLE_CLOUD_JSON_KEY to be set | ||
* @returns access token for google cloud | ||
*/ | ||
export async function getGCloudToken() { | ||
if (!gTokenClient) { | ||
if (!serverConfig.googleCloudJsonKey) | ||
throw new Error("GOOGLE_CLOUD_JSON_KEY is not set"); | ||
const keys = JSON.parse(serverConfig.googleCloudJsonKey); | ||
gTokenClient = new GoogleToken({ | ||
email: keys.client_email, | ||
key: keys.private_key, | ||
scope: ["https://www.googleapis.com/auth/cloud-platform"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Review the requested OAuth scope. The current scope For Vertex AI, you might only need:
|
||
}); | ||
} | ||
const credentials = await gTokenClient?.getToken(); | ||
return credentials?.access_token; | ||
} | ||
Comment on lines
+197
to
+210
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add robust error handling and type safety. The current implementation has several potential issues:
Consider applying these improvements: export async function getGCloudToken() {
if (!gTokenClient) {
if (!serverConfig.googleCloudJsonKey)
throw new Error("GOOGLE_CLOUD_JSON_KEY is not set");
- const keys = JSON.parse(serverConfig.googleCloudJsonKey);
+ let keys;
+ try {
+ keys = JSON.parse(serverConfig.googleCloudJsonKey);
+ if (!keys.client_email || !keys.private_key) {
+ throw new Error("Invalid key structure");
+ }
+ } catch (error) {
+ throw new Error("Invalid GOOGLE_CLOUD_JSON_KEY format");
+ }
gTokenClient = new GoogleToken({
email: keys.client_email,
key: keys.private_key,
scope: ["https://www.googleapis.com/auth/cloud-platform"],
});
}
- const credentials = await gTokenClient?.getToken();
- return credentials?.access_token;
+ if (!gTokenClient) {
+ throw new Error("Failed to initialize Google Token client");
+ }
+ const credentials = await gTokenClient.getToken();
+ if (!credentials?.access_token) {
+ throw new Error("Failed to obtain access token");
+ }
+ return credentials.access_token;
} Also, consider adding JSDoc type definitions for the expected JSON structure: interface GoogleCloudKey {
client_email: string;
private_key: string;
// add other required fields
} |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3,6 +3,7 @@ import { auth } from "./auth"; | |||||||||||||||||||||||||||||||||||||||||||||||||
import { getServerSideConfig } from "@/app/config/server"; | ||||||||||||||||||||||||||||||||||||||||||||||||||
import { ApiPath, GEMINI_BASE_URL, ModelProvider } from "@/app/constant"; | ||||||||||||||||||||||||||||||||||||||||||||||||||
import { prettyObject } from "@/app/utils/format"; | ||||||||||||||||||||||||||||||||||||||||||||||||||
import { getGCloudToken } from "./common"; | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
const serverConfig = getServerSideConfig(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -29,7 +30,9 @@ export async function handle( | |||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
const apiKey = token ? token : serverConfig.googleApiKey; | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
if (!apiKey) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
// When using Vertex AI, the API key is not required. | ||||||||||||||||||||||||||||||||||||||||||||||||||
// Instead, a GCloud token will be used later in the request. | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (!apiKey && !serverConfig.isVertexAI) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
return NextResponse.json( | ||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||
error: true, | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -73,7 +76,9 @@ async function request(req: NextRequest, apiKey: string) { | |||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
let baseUrl = serverConfig.googleUrl || GEMINI_BASE_URL; | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Google, ""); | ||||||||||||||||||||||||||||||||||||||||||||||||||
let path = serverConfig.vertexAIUrl | ||||||||||||||||||||||||||||||||||||||||||||||||||
? "" | ||||||||||||||||||||||||||||||||||||||||||||||||||
: `${req.nextUrl.pathname}`.replaceAll(ApiPath.Google, ""); | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
if (!baseUrl.startsWith("http")) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
baseUrl = `https://${baseUrl}`; | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -92,18 +97,30 @@ async function request(req: NextRequest, apiKey: string) { | |||||||||||||||||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||||||||||||||||||
10 * 60 * 1000, | ||||||||||||||||||||||||||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||||||||||||||||||||||||||
const fetchUrl = `${baseUrl}${path}${ | ||||||||||||||||||||||||||||||||||||||||||||||||||
req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : "" | ||||||||||||||||||||||||||||||||||||||||||||||||||
}`; | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
let authHeaderName = "x-goog-api-key"; | ||||||||||||||||||||||||||||||||||||||||||||||||||
let authValue = | ||||||||||||||||||||||||||||||||||||||||||||||||||
req.headers.get(authHeaderName) || | ||||||||||||||||||||||||||||||||||||||||||||||||||
(req.headers.get("Authorization") ?? "").replace("Bearer ", ""); | ||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+101
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Robust extraction of the token from the 'Authorization' header The current token extraction might incorrectly handle 'Authorization' headers that do not start with 'Bearer '. To ensure accurate token retrieval, consider refining the extraction logic. Apply this diff to improve token extraction: let authValue =
req.headers.get(authHeaderName) ||
- (req.headers.get("Authorization") ?? "").replace("Bearer ", "");
+ (req.headers.get("Authorization") ?? "").replace(/^Bearer\s+/i, ""); This change uses a regular expression to remove the 'Bearer ' prefix only if it appears at the start of the string, regardless of case and handling any extra whitespace. This prevents unintended replacements elsewhere in the header value. 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
// adjust header and url when use with vertex ai | ||||||||||||||||||||||||||||||||||||||||||||||||||
if (serverConfig.vertexAIUrl) { | ||||||||||||||||||||||||||||||||||||||||||||||||||
authHeaderName = "Authorization"; | ||||||||||||||||||||||||||||||||||||||||||||||||||
const gCloudToken = await getGCloudToken(); | ||||||||||||||||||||||||||||||||||||||||||||||||||
authValue = `Bearer ${gCloudToken}`; | ||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+106
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling for 'getGCloudToken()' to prevent unhandled exceptions If Apply this diff to handle potential errors: authHeaderName = "Authorization";
- const gCloudToken = await getGCloudToken();
+ let gCloudToken;
+ try {
+ gCloudToken = await getGCloudToken();
+ } catch (error) {
+ console.error("[Google] Failed to get GCloud token", error);
+ return NextResponse.json(
+ {
+ error: true,
+ message: "Failed to obtain GCloud token",
+ },
+ { status: 500 },
+ );
+ }
authValue = `Bearer ${gCloudToken}`; This ensures that any exceptions thrown by 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
const fetchUrl = serverConfig.vertexAIUrl | ||||||||||||||||||||||||||||||||||||||||||||||||||
? serverConfig.vertexAIUrl | ||||||||||||||||||||||||||||||||||||||||||||||||||
: `${baseUrl}${path}${ | ||||||||||||||||||||||||||||||||||||||||||||||||||
req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : "" | ||||||||||||||||||||||||||||||||||||||||||||||||||
}`; | ||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+112
to
+116
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Append query parameters to 'vertexAIUrl' when necessary When using Apply this diff to include the query parameter: const fetchUrl = serverConfig.vertexAIUrl
? serverConfig.vertexAIUrl +
+ (req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : "")
: `${baseUrl}${path}${
req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : ""
}`; This modification appends 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
console.log("[Fetch Url] ", fetchUrl); | ||||||||||||||||||||||||||||||||||||||||||||||||||
const fetchOptions: RequestInit = { | ||||||||||||||||||||||||||||||||||||||||||||||||||
headers: { | ||||||||||||||||||||||||||||||||||||||||||||||||||
"Content-Type": "application/json", | ||||||||||||||||||||||||||||||||||||||||||||||||||
"Cache-Control": "no-store", | ||||||||||||||||||||||||||||||||||||||||||||||||||
"x-goog-api-key": | ||||||||||||||||||||||||||||||||||||||||||||||||||
req.headers.get("x-goog-api-key") || | ||||||||||||||||||||||||||||||||||||||||||||||||||
(req.headers.get("Authorization") ?? "").replace("Bearer ", ""), | ||||||||||||||||||||||||||||||||||||||||||||||||||
[authHeaderName]: authValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||||||||||||||||||
method: req.method, | ||||||||||||||||||||||||||||||||||||||||||||||||||
body: req.body, | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add security warning and format guidance for GOOGLE_CLOUD_JSON_KEY.
The service account key requires specific formatting and contains sensitive information. This should be clearly documented.
Apply this diff to add security guidance:
📝 Committable suggestion