Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement Vertex AI support for Anthropic and Google models #5794

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions app/api/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { NextRequest, NextResponse } from "next/server";
import { auth } from "./auth";
import { isModelAvailableInServer } from "@/app/utils/model";
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
import { getGCloudToken } from "./common";

const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]);

Expand Down Expand Up @@ -67,10 +68,20 @@ async function request(req: NextRequest) {
serverConfig.anthropicApiKey ||
"";

let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Anthropic, "");
// adjust header and url when using vertex ai
if (serverConfig.isVertexAI) {
authHeaderName = "Authorization";
const gCloudToken = await getGCloudToken();
authValue = `Bearer ${gCloudToken}`;
}

let path = serverConfig.vertexAIUrl
? ""
: `${req.nextUrl.pathname}`.replaceAll(ApiPath.Anthropic, "");

let baseUrl =
serverConfig.anthropicUrl || serverConfig.baseUrl || ANTHROPIC_BASE_URL;
let baseUrl = serverConfig.vertexAIUrl
? serverConfig.vertexAIUrl
: serverConfig.anthropicUrl || serverConfig.baseUrl || ANTHROPIC_BASE_URL;

if (!baseUrl.startsWith("http")) {
baseUrl = `https://${baseUrl}`;
Expand Down Expand Up @@ -112,13 +123,16 @@ async function request(req: NextRequest) {
signal: controller.signal,
};

// #1815 try to refuse some request to some models
// #1815 try to refuse some request to some models or tick json body for vertex ai
if (serverConfig.customModels && req.body) {
try {
const clonedBody = await req.text();
fetchOptions.body = clonedBody;

const jsonBody = JSON.parse(clonedBody) as { model?: string };
const jsonBody = JSON.parse(clonedBody) as {
model?: string;
anthropic_version?: string;
};
Comment on lines +132 to +135
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Handle JSON parsing errors appropriately

If the request body cannot be parsed as JSON, the code logs the error but continues processing with the original body. This could lead to unexpected behavior downstream. Consider returning an error response when JSON parsing fails to ensure the request is handled correctly.

Apply this diff to return an error response when JSON parsing fails:

        } catch (e) {
          console.error(`[Anthropic] filter`, e);
+         return NextResponse.json(
+           { error: true, message: "Invalid JSON in request body" },
+           { status: 400 },
+         );
        }

Committable suggestion skipped: line range outside the PR's diff.


// not undefined and is false
if (
Expand All @@ -138,6 +152,14 @@ async function request(req: NextRequest) {
},
);
}

// tick json body for vertex ai and update fetch options
if (serverConfig.isVertexAI) {
delete jsonBody.model;
jsonBody.anthropic_version =
serverConfig.anthropicApiVersion || "vertex-2023-10-16";
fetchOptions.body = JSON.stringify(jsonBody);
}
} catch (e) {
console.error(`[Anthropic] filter`, e);
}
Expand Down
23 changes: 23 additions & 0 deletions app/api/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "../config/server";
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
import { GoogleToken } from "../utils/gtoken";
import { getModelProvider, isModelAvailableInServer } from "../utils/model";

const serverConfig = getServerSideConfig();
Expand Down Expand Up @@ -185,3 +186,25 @@ export async function requestOpenai(req: NextRequest) {
clearTimeout(timeoutId);
}
}

let gTokenClient: GoogleToken | undefined;

/**
* Get access token for google cloud,
* requires GOOGLE_CLOUD_JSON_KEY to be set
* @returns access token for google cloud
*/
export async function getGCloudToken() {
if (!gTokenClient) {
if (!serverConfig.googleCloudJsonKey)
throw new Error("GOOGLE_CLOUD_JSON_KEY is not set");
const keys = JSON.parse(serverConfig.googleCloudJsonKey);
gTokenClient = new GoogleToken({
email: keys.client_email,
key: keys.private_key,
scope: ["https://www.googleapis.com/auth/cloud-platform"],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Review the requested OAuth scope.

The current scope https://www.googleapis.com/auth/cloud-platform provides broad access to Google Cloud Platform resources. Consider using more specific scopes that align with the minimum permissions required for Vertex AI operations.

For Vertex AI, you might only need:

  • https://www.googleapis.com/auth/cloud-platform.read-only (if only calling models)
  • https://www.googleapis.com/auth/aiplatform (if using other Vertex AI features)

});
}
const credentials = await gTokenClient?.getToken();
return credentials?.access_token;
}
Comment on lines +197 to +210
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add robust error handling and type safety.

The current implementation has several potential issues:

  1. The JSON.parse call could throw on invalid JSON
  2. No validation of the parsed JSON structure
  3. Potential undefined access in gTokenClient?.getToken()
  4. No explicit handling of token expiration

Consider applying these improvements:

 export async function getGCloudToken() {
   if (!gTokenClient) {
     if (!serverConfig.googleCloudJsonKey)
       throw new Error("GOOGLE_CLOUD_JSON_KEY is not set");
-    const keys = JSON.parse(serverConfig.googleCloudJsonKey);
+    let keys;
+    try {
+      keys = JSON.parse(serverConfig.googleCloudJsonKey);
+      if (!keys.client_email || !keys.private_key) {
+        throw new Error("Invalid key structure");
+      }
+    } catch (error) {
+      throw new Error("Invalid GOOGLE_CLOUD_JSON_KEY format");
+    }
     gTokenClient = new GoogleToken({
       email: keys.client_email,
       key: keys.private_key,
       scope: ["https://www.googleapis.com/auth/cloud-platform"],
     });
   }
-  const credentials = await gTokenClient?.getToken();
-  return credentials?.access_token;
+  if (!gTokenClient) {
+    throw new Error("Failed to initialize Google Token client");
+  }
+  const credentials = await gTokenClient.getToken();
+  if (!credentials?.access_token) {
+    throw new Error("Failed to obtain access token");
+  }
+  return credentials.access_token;
 }

Also, consider adding JSDoc type definitions for the expected JSON structure:

interface GoogleCloudKey {
  client_email: string;
  private_key: string;
  // add other required fields
}

33 changes: 25 additions & 8 deletions app/api/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { auth } from "./auth";
import { getServerSideConfig } from "@/app/config/server";
import { ApiPath, GEMINI_BASE_URL, ModelProvider } from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import { getGCloudToken } from "./common";

const serverConfig = getServerSideConfig();

Expand All @@ -29,7 +30,9 @@ export async function handle(

const apiKey = token ? token : serverConfig.googleApiKey;

if (!apiKey) {
// When using Vertex AI, the API key is not required.
// Instead, a GCloud token will be used later in the request.
if (!apiKey && !serverConfig.isVertexAI) {
return NextResponse.json(
{
error: true,
Expand Down Expand Up @@ -73,7 +76,9 @@ async function request(req: NextRequest, apiKey: string) {

let baseUrl = serverConfig.googleUrl || GEMINI_BASE_URL;

let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Google, "");
let path = serverConfig.vertexAIUrl
? ""
: `${req.nextUrl.pathname}`.replaceAll(ApiPath.Google, "");

if (!baseUrl.startsWith("http")) {
baseUrl = `https://${baseUrl}`;
Expand All @@ -92,18 +97,30 @@ async function request(req: NextRequest, apiKey: string) {
},
10 * 60 * 1000,
);
const fetchUrl = `${baseUrl}${path}${
req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : ""
}`;

let authHeaderName = "x-goog-api-key";
let authValue =
req.headers.get(authHeaderName) ||
(req.headers.get("Authorization") ?? "").replace("Bearer ", "");
Comment on lines +101 to +104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Robust extraction of the token from the 'Authorization' header

The current token extraction might incorrectly handle 'Authorization' headers that do not start with 'Bearer '. To ensure accurate token retrieval, consider refining the extraction logic.

Apply this diff to improve token extraction:

let authValue =
  req.headers.get(authHeaderName) ||
- (req.headers.get("Authorization") ?? "").replace("Bearer ", "");
+ (req.headers.get("Authorization") ?? "").replace(/^Bearer\s+/i, "");

This change uses a regular expression to remove the 'Bearer ' prefix only if it appears at the start of the string, regardless of case and handling any extra whitespace. This prevents unintended replacements elsewhere in the header value.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
let authHeaderName = "x-goog-api-key";
let authValue =
req.headers.get(authHeaderName) ||
(req.headers.get("Authorization") ?? "").replace("Bearer ", "");
let authHeaderName = "x-goog-api-key";
let authValue =
req.headers.get(authHeaderName) ||
(req.headers.get("Authorization") ?? "").replace(/^Bearer\s+/i, "");


// adjust header and url when use with vertex ai
if (serverConfig.vertexAIUrl) {
authHeaderName = "Authorization";
const gCloudToken = await getGCloudToken();
authValue = `Bearer ${gCloudToken}`;
}
Comment on lines +106 to +111
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling for 'getGCloudToken()' to prevent unhandled exceptions

If getGCloudToken() fails (e.g., due to network errors or authentication issues), the absence of error handling could cause unhandled exceptions. Implementing error handling ensures the server responds gracefully.

Apply this diff to handle potential errors:

authHeaderName = "Authorization";
- const gCloudToken = await getGCloudToken();
+ let gCloudToken;
+ try {
+   gCloudToken = await getGCloudToken();
+ } catch (error) {
+   console.error("[Google] Failed to get GCloud token", error);
+   return NextResponse.json(
+     {
+       error: true,
+       message: "Failed to obtain GCloud token",
+     },
+     { status: 500 },
+   );
+ }
authValue = `Bearer ${gCloudToken}`;

This ensures that any exceptions thrown by getGCloudToken() are caught, logged, and a meaningful error response is returned to the client.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// adjust header and url when use with vertex ai
if (serverConfig.vertexAIUrl) {
authHeaderName = "Authorization";
const gCloudToken = await getGCloudToken();
authValue = `Bearer ${gCloudToken}`;
}
// adjust header and url when use with vertex ai
if (serverConfig.vertexAIUrl) {
authHeaderName = "Authorization";
let gCloudToken;
try {
gCloudToken = await getGCloudToken();
} catch (error) {
console.error("[Google] Failed to get GCloud token", error);
return NextResponse.json(
{
error: true,
message: "Failed to obtain GCloud token",
},
{ status: 500 },
);
}
authValue = `Bearer ${gCloudToken}`;
}

const fetchUrl = serverConfig.vertexAIUrl
? serverConfig.vertexAIUrl
: `${baseUrl}${path}${
req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : ""
}`;
Comment on lines +112 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Append query parameters to 'vertexAIUrl' when necessary

When using serverConfig.vertexAIUrl, the fetchUrl does not include the query parameter for alt=sse, which may be required for certain functionalities like streaming responses. Adjusting the code to append this parameter when present in the request ensures consistent behavior.

Apply this diff to include the query parameter:

const fetchUrl = serverConfig.vertexAIUrl
  ? serverConfig.vertexAIUrl +
+   (req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : "")
  : `${baseUrl}${path}${
      req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : ""
    }`;

This modification appends ?alt=sse to vertexAIUrl when the request's search parameters include alt=sse, matching the behavior of the non-Vertex AI URL handling.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
const fetchUrl = serverConfig.vertexAIUrl
? serverConfig.vertexAIUrl
: `${baseUrl}${path}${
req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : ""
}`;
const fetchUrl = serverConfig.vertexAIUrl
? serverConfig.vertexAIUrl +
(req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : "")
: `${baseUrl}${path}${
req?.nextUrl?.searchParams?.get("alt") === "sse" ? "?alt=sse" : ""
}`;


console.log("[Fetch Url] ", fetchUrl);
const fetchOptions: RequestInit = {
headers: {
"Content-Type": "application/json",
"Cache-Control": "no-store",
"x-goog-api-key":
req.headers.get("x-goog-api-key") ||
(req.headers.get("Authorization") ?? "").replace("Bearer ", ""),
[authHeaderName]: authValue,
},
method: req.method,
body: req.body,
Expand Down
9 changes: 8 additions & 1 deletion app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,14 @@ export function getHeaders(ignoreHeaders: boolean = false) {

if (bearerToken) {
headers[authHeader] = bearerToken;
} else if (isEnabledAccessControl && validString(accessStore.accessCode)) {
}
// ensure access code is being sent when access control is enabled,
// this will fix an issue where the access code is not being sent when provider is google, azure or anthropic
if (
isEnabledAccessControl &&
validString(accessStore.accessCode) &&
authHeader !== "Authorization"
) {
headers["Authorization"] = getBearerToken(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
Expand Down
9 changes: 9 additions & 0 deletions app/config/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ declare global {
ANTHROPIC_API_KEY?: string;
ANTHROPIC_API_VERSION?: string;

// google cloud vertex ai only
VERTEX_AI_URL?: string; // https://{loc}-aiaiplatfor.googleapis.com/v1/{project}/locations/{loc}/models/{model}/versions/{version}:predict
GOOGLE_CLOUD_JSON_KEY?: string; // service account json key content

// baidu only
BAIDU_URL?: string;
BAIDU_API_KEY?: string;
Expand Down Expand Up @@ -148,6 +152,7 @@ export const getServerSideConfig = () => {
const isGoogle = !!process.env.GOOGLE_API_KEY;
const isAnthropic = !!process.env.ANTHROPIC_API_KEY;
const isTencent = !!process.env.TENCENT_API_KEY;
const isVertexAI = !!process.env.VERTEX_AI_URL;

const isBaidu = !!process.env.BAIDU_API_KEY;
const isBytedance = !!process.env.BYTEDANCE_API_KEY;
Expand Down Expand Up @@ -191,6 +196,10 @@ export const getServerSideConfig = () => {
anthropicApiVersion: process.env.ANTHROPIC_API_VERSION,
anthropicUrl: process.env.ANTHROPIC_URL,

isVertexAI,
vertexAIUrl: process.env.VERTEX_AI_URL,
googleCloudJsonKey: process.env.GOOGLE_CLOUD_JSON_KEY,

isBaidu,
baiduUrl: process.env.BAIDU_URL,
baiduApiKey: getApiKey(process.env.BAIDU_API_KEY),
Expand Down
2 changes: 2 additions & 0 deletions app/layout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { Metadata, Viewport } from "next";
import { SpeedInsights } from "@vercel/speed-insights/next";
import { getServerSideConfig } from "./config/server";
import { GoogleTagManager, GoogleAnalytics } from "@next/third-parties/google";
import { Analytics } from "@vercel/analytics/react";
const serverConfig = getServerSideConfig();

export const metadata: Metadata = {
Expand Down Expand Up @@ -65,6 +66,7 @@ export default function RootLayout({
<GoogleAnalytics gaId={serverConfig.gaId} />
</>
)}
<Analytics />
</body>
</html>
);
Expand Down
4 changes: 2 additions & 2 deletions app/utils/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import {
UPLOAD_URL,
REQUEST_TIMEOUT_MS,
} from "@/app/constant";
import { RequestMessage } from "@/app/client/api";
import { ChatOptions, RequestMessage } from "@/app/client/api";
import Locale from "@/app/locales";
import {
EventStreamContentType,
Expand Down Expand Up @@ -167,7 +167,7 @@ export function stream(
toolCallMessage: any,
toolCallResult: any[],
) => void,
options: any,
options: ChatOptions,
) {
let responseText = "";
let remainText = "";
Expand Down
Loading