diff --git a/app/api/auth.ts b/app/api/auth.ts index b750f2d17319..2b4702aedc37 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -75,7 +75,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { break; case ModelProvider.GPT: default: - if (serverConfig.isAzure) { + if (req.nextUrl.pathname.includes("azure/deployments")) { systemApiKey = serverConfig.azureApiKey; } else { systemApiKey = serverConfig.apiKey; diff --git a/app/api/azure/[...path]/route.ts b/app/api/azure/[...path]/route.ts new file mode 100644 index 000000000000..4a17de0c8abc --- /dev/null +++ b/app/api/azure/[...path]/route.ts @@ -0,0 +1,57 @@ +import { getServerSideConfig } from "@/app/config/server"; +import { ModelProvider } from "@/app/constant"; +import { prettyObject } from "@/app/utils/format"; +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "../../auth"; +import { requestOpenai } from "../../common"; + +async function handle( + req: NextRequest, + { params }: { params: { path: string[] } }, +) { + console.log("[Azure Route] params ", params); + + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + + const subpath = params.path.join("/"); + + const authResult = auth(req, ModelProvider.GPT); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + try { + return await requestOpenai(req); + } catch (e) { + console.error("[Azure] ", e); + return NextResponse.json(prettyObject(e)); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; diff --git a/app/api/common.ts b/app/api/common.ts index 1454fde2ed13..b2fae6df24f7 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -7,16 +7,17 @@ import { ServiceProvider, } from "../constant"; import { isModelAvailableInServer } from "../utils/model"; -import { makeAzurePath } from "../azure"; const serverConfig = getServerSideConfig(); export async function requestOpenai(req: NextRequest) { const controller = new AbortController(); + const isAzure = req.nextUrl.pathname.includes("azure/deployments"); + var authValue, authHeaderName = ""; - if (serverConfig.isAzure) { + if (isAzure) { authValue = req.headers .get("Authorization") @@ -56,14 +57,15 @@ export async function requestOpenai(req: NextRequest) { 10 * 60 * 1000, ); - if (serverConfig.isAzure) { - if (!serverConfig.azureApiVersion) { - return NextResponse.json({ - error: true, - message: `missing AZURE_API_VERSION in server env vars`, - }); - } - path = makeAzurePath(path, serverConfig.azureApiVersion); + if (isAzure) { + const azureApiVersion = + req?.nextUrl?.searchParams?.get("api-version") || + serverConfig.azureApiVersion; + baseUrl = baseUrl.split("/deployments").shift() as string; + path = `${req.nextUrl.pathname.replaceAll( + "/api/azure/", + "", + )}?api-version=${azureApiVersion}`; } const fetchUrl = `${baseUrl}/${path}`; diff --git a/app/azure.ts b/app/azure.ts deleted file mode 100644 index 48406c55ba5d..000000000000 --- a/app/azure.ts +++ /dev/null @@ -1,9 +0,0 @@ -export function makeAzurePath(path: string, apiVersion: string) { - // should omit /v1 prefix - path = path.replaceAll("v1/", ""); - - // should add api-key to query string - path += `${path.includes("?") ? "&" : "?"}api-version=${apiVersion}`; - - return path; -} diff --git a/app/client/api.ts b/app/client/api.ts index edee993424a1..528a5598aec9 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -30,6 +30,7 @@ export interface RequestMessage { export interface LLMConfig { model: string; + providerName?: string; temperature?: number; top_p?: number; stream?: boolean; @@ -54,6 +55,7 @@ export interface LLMUsage { export interface LLMModel { name: string; + displayName?: string; available: boolean; provider: LLMModelProvider; } @@ -155,41 +157,70 @@ export class ClientApi { export function getHeaders() { const accessStore = useAccessStore.getState(); + const chatStore = useChatStore.getState(); const headers: Record = { "Content-Type": "application/json", Accept: "application/json", }; - const modelConfig = useChatStore.getState().currentSession().mask.modelConfig; - const isGoogle = modelConfig.model.startsWith("gemini"); - const isAzure = accessStore.provider === ServiceProvider.Azure; - const isAnthropic = accessStore.provider === ServiceProvider.Anthropic; - const authHeader = isAzure ? "api-key" : isAnthropic ? 'x-api-key' : "Authorization"; - const apiKey = isGoogle - ? accessStore.googleApiKey - : isAzure - ? accessStore.azureApiKey - : isAnthropic - ? accessStore.anthropicApiKey - : accessStore.openaiApiKey; + const clientConfig = getClientConfig(); - const makeBearer = (s: string) => `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`; - const validString = (x: string) => x && x.length > 0; + function getConfig() { + const modelConfig = chatStore.currentSession().mask.modelConfig; + const isGoogle = modelConfig.providerName == ServiceProvider.Google; + const isAzure = modelConfig.providerName === ServiceProvider.Azure; + const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic; + const isEnabledAccessControl = accessStore.enabledAccessControl(); + const apiKey = isGoogle + ? accessStore.googleApiKey + : isAzure + ? accessStore.azureApiKey + : isAnthropic + ? accessStore.anthropicApiKey + : accessStore.openaiApiKey; + return { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl }; + } + + function getAuthHeader(): string { + return isAzure ? "api-key" : isAnthropic ? "x-api-key" : "Authorization"; + } + + function getBearerToken(apiKey: string, noBearer: boolean = false): string { + return validString(apiKey) + ? `${noBearer ? "" : "Bearer "}${apiKey.trim()}` + : ""; + } + + function validString(x: string): boolean { + return x?.length > 0; + } + const { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl } = + getConfig(); // when using google api in app, not set auth header - if (!(isGoogle && clientConfig?.isApp)) { - // use user's api key first - if (validString(apiKey)) { - headers[authHeader] = makeBearer(apiKey); - } else if ( - accessStore.enabledAccessControl() && - validString(accessStore.accessCode) - ) { - // access_code must send with header named `Authorization`, will using in auth middleware. - headers['Authorization'] = makeBearer( - ACCESS_CODE_PREFIX + accessStore.accessCode, - ); - } + if (isGoogle && clientConfig?.isApp) return headers; + + const authHeader = getAuthHeader(); + + const bearerToken = getBearerToken(apiKey, isAzure || isAnthropic); + + if (bearerToken) { + headers[authHeader] = bearerToken; + } else if (isEnabledAccessControl && validString(accessStore.accessCode)) { + headers["Authorization"] = getBearerToken( + ACCESS_CODE_PREFIX + accessStore.accessCode, + ); } return headers; } + +export function getClientApi(provider: ServiceProvider): ClientApi { + switch (provider) { + case ServiceProvider.Google: + return new ClientApi(ModelProvider.GeminiPro); + case ServiceProvider.Anthropic: + return new ClientApi(ModelProvider.Claude); + default: + return new ClientApi(ModelProvider.GPT); + } +} diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index f3599263023c..8615172a311d 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -1,13 +1,16 @@ "use client"; +// azure and openai, using same models. so using same LLMApi. import { ApiPath, DEFAULT_API_HOST, DEFAULT_MODELS, OpenaiPath, + Azure, REQUEST_TIMEOUT_MS, ServiceProvider, } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; +import { collectModelsWithDefaultModel } from "@/app/utils/model"; import { ChatOptions, @@ -24,7 +27,6 @@ import { } from "@fortaine/fetch-event-source"; import { prettyObject } from "@/app/utils/format"; import { getClientConfig } from "@/app/config/client"; -import { makeAzurePath } from "@/app/azure"; import { getMessageTextContent, getMessageImages, @@ -62,33 +64,31 @@ export class ChatGPTApi implements LLMApi { let baseUrl = ""; + const isAzure = path.includes("deployments"); if (accessStore.useCustomConfig) { - const isAzure = accessStore.provider === ServiceProvider.Azure; - if (isAzure && !accessStore.isValidAzure()) { throw Error( "incomplete azure config, please check it in your settings page", ); } - if (isAzure) { - path = makeAzurePath(path, accessStore.azureApiVersion); - } - baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl; } if (baseUrl.length === 0) { const isApp = !!getClientConfig()?.isApp; - baseUrl = isApp - ? DEFAULT_API_HOST + "/proxy" + ApiPath.OpenAI - : ApiPath.OpenAI; + const apiPath = isAzure ? ApiPath.Azure : ApiPath.OpenAI; + baseUrl = isApp ? DEFAULT_API_HOST + "/proxy" + apiPath : apiPath; } if (baseUrl.endsWith("/")) { baseUrl = baseUrl.slice(0, baseUrl.length - 1); } - if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.OpenAI)) { + if ( + !baseUrl.startsWith("http") && + !isAzure && + !baseUrl.startsWith(ApiPath.OpenAI) + ) { baseUrl = "https://" + baseUrl; } @@ -113,6 +113,7 @@ export class ChatGPTApi implements LLMApi { ...useChatStore.getState().currentSession().mask.modelConfig, ...{ model: options.config.model, + providerName: options.config.providerName, }, }; @@ -140,7 +141,35 @@ export class ChatGPTApi implements LLMApi { options.onController?.(controller); try { - const chatPath = this.path(OpenaiPath.ChatPath); + let chatPath = ""; + if (modelConfig.providerName === ServiceProvider.Azure) { + // find model, and get displayName as deployName + const { models: configModels, customModels: configCustomModels } = + useAppConfig.getState(); + const { + defaultModel, + customModels: accessCustomModels, + useCustomConfig, + } = useAccessStore.getState(); + const models = collectModelsWithDefaultModel( + configModels, + [configCustomModels, accessCustomModels].join(","), + defaultModel, + ); + const model = models.find( + (model) => + model.name === modelConfig.model && + model?.provider?.providerName === ServiceProvider.Azure, + ); + chatPath = this.path( + Azure.ChatPath( + (model?.displayName ?? model?.name) as string, + useCustomConfig ? useAccessStore.getState().azureApiVersion : "", + ), + ); + } else { + chatPath = this.path(OpenaiPath.ChatPath); + } const chatPayload = { method: "POST", body: JSON.stringify(requestPayload), diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 061192504652..b1bdf757f442 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -88,6 +88,7 @@ import { Path, REQUEST_TIMEOUT_MS, UNFINISHED_INPUT, + ServiceProvider, } from "../constant"; import { Avatar } from "./emoji"; import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask"; @@ -448,6 +449,9 @@ export function ChatActions(props: { // switch model const currentModel = chatStore.currentSession().mask.modelConfig.model; + const currentProviderName = + chatStore.currentSession().mask.modelConfig?.providerName || + ServiceProvider.OpenAI; const allModels = useAllModels(); const models = useMemo(() => { const filteredModels = allModels.filter((m) => m.available); @@ -479,13 +483,13 @@ export function ChatActions(props: { const isUnavaliableModel = !models.some((m) => m.name === currentModel); if (isUnavaliableModel && models.length > 0) { // show next model to default model if exist - let nextModel: ModelType = ( - models.find((model) => model.isDefault) || models[0] - ).name; - chatStore.updateCurrentSession( - (session) => (session.mask.modelConfig.model = nextModel), - ); - showToast(nextModel); + let nextModel = models.find((model) => model.isDefault) || models[0]; + chatStore.updateCurrentSession((session) => { + session.mask.modelConfig.model = nextModel.name; + session.mask.modelConfig.providerName = nextModel?.provider + ?.providerName as ServiceProvider; + }); + showToast(nextModel.name); } }, [chatStore, currentModel, models]); @@ -573,19 +577,26 @@ export function ChatActions(props: { {showModelSelector && ( ({ - title: m.displayName, - value: m.name, + title: `${m.displayName}${ + m?.provider?.providerName + ? "(" + m?.provider?.providerName + ")" + : "" + }`, + value: `${m.name}@${m?.provider?.providerName}`, }))} onClose={() => setShowModelSelector(false)} onSelection={(s) => { if (s.length === 0) return; + const [model, providerName] = s[0].split("@"); chatStore.updateCurrentSession((session) => { - session.mask.modelConfig.model = s[0] as ModelType; + session.mask.modelConfig.model = model as ModelType; + session.mask.modelConfig.providerName = + providerName as ServiceProvider; session.mask.syncGlobalConfig = false; }); - showToast(s[0]); + showToast(model); }} /> )} diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx index 20e240d93b0d..948807d4ca23 100644 --- a/app/components/exporter.tsx +++ b/app/components/exporter.tsx @@ -36,11 +36,10 @@ import { toBlob, toPng } from "html-to-image"; import { DEFAULT_MASK_AVATAR } from "../store/mask"; import { prettyObject } from "../utils/format"; -import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant"; +import { EXPORT_MESSAGE_CLASS_NAME } from "../constant"; import { getClientConfig } from "../config/client"; -import { ClientApi } from "../client/api"; +import { type ClientApi, getClientApi } from "../client/api"; import { getMessageTextContent } from "../utils"; -import { identifyDefaultClaudeModel } from "../utils/checkers"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { loading: () => , @@ -313,14 +312,7 @@ export function PreviewActions(props: { const onRenderMsgs = (msgs: ChatMessage[]) => { setShouldExport(false); - var api: ClientApi; - if (config.modelConfig.model.startsWith("gemini")) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (identifyDefaultClaudeModel(config.modelConfig.model)) { - api = new ClientApi(ModelProvider.Claude); - } else { - api = new ClientApi(ModelProvider.GPT); - } + const api: ClientApi = getClientApi(config.modelConfig.providerName); api .share(msgs) diff --git a/app/components/home.tsx b/app/components/home.tsx index ffac64fdac0c..e127c65f8fd3 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg"; import { getCSSVar, useMobileScreen } from "../utils"; import dynamic from "next/dynamic"; -import { ModelProvider, Path, SlotID } from "../constant"; +import { Path, SlotID } from "../constant"; import { ErrorBoundary } from "./error"; import { getISOLang, getLang } from "../locales"; @@ -27,9 +27,8 @@ import { SideBar } from "./sidebar"; import { useAppConfig } from "../store/config"; import { AuthPage } from "./auth"; import { getClientConfig } from "../config/client"; -import { ClientApi } from "../client/api"; +import { type ClientApi, getClientApi } from "../client/api"; import { useAccessStore } from "../store"; -import { identifyDefaultClaudeModel } from "../utils/checkers"; export function Loading(props: { noLogo?: boolean }) { return ( @@ -171,14 +170,8 @@ function Screen() { export function useLoadData() { const config = useAppConfig(); - var api: ClientApi; - if (config.modelConfig.model.startsWith("gemini")) { - api = new ClientApi(ModelProvider.GeminiPro); - } else if (identifyDefaultClaudeModel(config.modelConfig.model)) { - api = new ClientApi(ModelProvider.Claude); - } else { - api = new ClientApi(ModelProvider.GPT); - } + const api: ClientApi = getClientApi(config.modelConfig.providerName); + useEffect(() => { (async () => { const models = await api.llm.models(); diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index e46a018f4639..346fd3a7129d 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -1,3 +1,4 @@ +import { ServiceProvider } from "@/app/constant"; import { ModalConfigValidator, ModelConfig } from "../store"; import Locale from "../locales"; @@ -10,25 +11,25 @@ export function ModelConfigList(props: { updateConfig: (updater: (config: ModelConfig) => void) => void; }) { const allModels = useAllModels(); + const value = `${props.modelConfig.model}@${props.modelConfig?.providerName}`; return ( <>