Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: google #5045

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 66 additions & 50 deletions app/api/google/[...path]/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../../auth";
import { getServerSideConfig } from "@/app/config/server";
import { GEMINI_BASE_URL, Google, ModelProvider } from "@/app/constant";
import {
ApiPath,
GEMINI_BASE_URL,
Google,
ModelProvider,
} from "@/app/constant";
import { prettyObject } from "@/app/utils/format";

const serverConfig = getServerSideConfig();

async function handle(
req: NextRequest,
Expand All @@ -13,32 +21,6 @@ async function handle(
return NextResponse.json({ body: "OK" }, { status: 200 });
}

const controller = new AbortController();

const serverConfig = getServerSideConfig();

let baseUrl = serverConfig.googleUrl || GEMINI_BASE_URL;

if (!baseUrl.startsWith("http")) {
baseUrl = `https://${baseUrl}`;
}

if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, -1);
}

let path = `${req.nextUrl.pathname}`.replaceAll("/api/google/", "");

console.log("[Proxy] ", path);
console.log("[Base Url]", baseUrl);

const timeoutId = setTimeout(
() => {
controller.abort();
},
10 * 60 * 1000,
);

const authResult = auth(req, ModelProvider.GeminiPro);
if (authResult.error) {
return NextResponse.json(authResult, {
Expand All @@ -49,9 +31,9 @@ async function handle(
const bearToken = req.headers.get("Authorization") ?? "";
const token = bearToken.trim().replaceAll("Bearer ", "").trim();

const key = token ? token : serverConfig.googleApiKey;
const apiKey = token ? token : serverConfig.googleApiKey;

if (!key) {
if (!apiKey) {
return NextResponse.json(
{
error: true,
Expand All @@ -62,10 +44,63 @@ async function handle(
},
);
}
try {
const response = await request(req, apiKey);
return response;
} catch (e) {
console.error("[Google] ", e);
return NextResponse.json(prettyObject(e));
}
}

const fetchUrl = `${baseUrl}/${path}?key=${key}${
req?.nextUrl?.searchParams?.get("alt") == "sse" ? "&alt=sse" : ""
export const GET = handle;
export const POST = handle;

Comment on lines +56 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separation of HTTP methods into distinct exports.

Exporting GET and POST as separate functions, although both point to handle, could be confusing since they perform the same operation. Consider renaming or merging them if their functionality does not differ, or provide additional context if they are intended to diverge in the future.

export const runtime = "edge";
export const preferredRegion = [
"bom1",
"cle1",
"cpt1",
"gru1",
"hnd1",
"iad1",
"icn1",
"kix1",
"pdx1",
"sfo1",
"sin1",
"syd1",
];

async function request(req: NextRequest, apiKey: string) {
const controller = new AbortController();

let baseUrl = serverConfig.googleUrl || GEMINI_BASE_URL;

let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Google, "");

if (!baseUrl.startsWith("http")) {
baseUrl = `https://${baseUrl}`;
}

if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, -1);
}

console.log("[Proxy] ", path);
console.log("[Base Url]", baseUrl);

const timeoutId = setTimeout(
() => {
controller.abort();
},
10 * 60 * 1000,
);
const fetchUrl = `${baseUrl}${path}?key=${apiKey}${
req?.nextUrl?.searchParams?.get("alt") === "sse" ? "&alt=sse" : ""
}`;

console.log("[Fetch Url] ", fetchUrl);
const fetchOptions: RequestInit = {
headers: {
"Content-Type": "application/json",
Expand Down Expand Up @@ -97,22 +132,3 @@ async function handle(
clearTimeout(timeoutId);
}
}

export const GET = handle;
export const POST = handle;

export const runtime = "edge";
export const preferredRegion = [
"bom1",
"cle1",
"cpt1",
"gru1",
"hnd1",
"iad1",
"icn1",
"kix1",
"pdx1",
"sfo1",
"sin1",
"syd1",
];
69 changes: 33 additions & 36 deletions app/client/platforms/google.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant";
import { ApiPath, Google, REQUEST_TIMEOUT_MS } from "@/app/constant";
import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { getClientConfig } from "@/app/config/client";
Expand All @@ -16,6 +16,34 @@ import {
} from "@/app/utils";

export class GeminiProApi implements LLMApi {
path(path: string): string {
const accessStore = useAccessStore.getState();

let baseUrl = "";
if (accessStore.useCustomConfig) {
baseUrl = accessStore.googleUrl;
}

if (baseUrl.length === 0) {
const isApp = !!getClientConfig()?.isApp;
baseUrl = isApp
? DEFAULT_API_HOST + `/api/proxy/google?key=${accessStore.googleApiKey}`
: ApiPath.Google;
}
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, baseUrl.length - 1);
}
if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Google)) {
baseUrl = "https://" + baseUrl;
}

console.log("[Proxy Endpoint] ", baseUrl, path);

let chatPath = [baseUrl, path].join("/");

chatPath += chatPath.includes("?") ? "&alt=sse" : "?alt=sse";
return chatPath;
}
extractMessage(res: any) {
console.log("[Response] gemini-pro response: ", res);

Expand Down Expand Up @@ -108,30 +136,13 @@ export class GeminiProApi implements LLMApi {
],
};

const accessStore = useAccessStore.getState();

let baseUrl = "";

if (accessStore.useCustomConfig) {
baseUrl = accessStore.googleUrl;
}

const isApp = !!getClientConfig()?.isApp;

let shouldStream = !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);
try {
if (!baseUrl && isApp) {
baseUrl = DEFAULT_API_HOST + "/api/proxy/google/";
}
baseUrl = `${baseUrl}/${Google.ChatPath(modelConfig.model)}`.replaceAll(
"//",
"/",
);
if (isApp) {
baseUrl += `?key=${accessStore.googleApiKey}`;
}
// https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb
const chatPath = this.path(Google.ChatPath(modelConfig.model));

const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
Expand Down Expand Up @@ -181,10 +192,6 @@ export class GeminiProApi implements LLMApi {

controller.signal.onabort = finish;

// https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb
const chatPath =
baseUrl.replace("generateContent", "streamGenerateContent") +
(baseUrl.indexOf("?") > -1 ? "&alt=sse" : "?alt=sse");
fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
Expand Down Expand Up @@ -259,7 +266,7 @@ export class GeminiProApi implements LLMApi {
openWhenHidden: true,
});
} else {
const res = await fetch(baseUrl, chatPayload);
const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId);
const resJson = await res.json();
if (resJson?.promptFeedback?.blockReason) {
Expand All @@ -285,14 +292,4 @@ export class GeminiProApi implements LLMApi {
async models(): Promise<LLMModel[]> {
return [];
}
path(path: string): string {
return "/api/google/" + path;
}
}

function ensureProperEnding(str: string) {
if (str.startsWith("[") && !str.endsWith("]")) {
return str + "]";
}
return str;
}
3 changes: 2 additions & 1 deletion app/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ export const Azure = {

export const Google = {
ExampleEndpoint: "https://generativelanguage.googleapis.com/",
ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
ChatPath: (modelName: string) =>
`v1beta/models/${modelName}:streamGenerateContent`,
};

export const Baidu = {
Expand Down
Loading