Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…into feat-baidu
  • Loading branch information
Dogtiti committed Jul 6, 2024
2 parents 785d374 + 7218f13 commit 9f7d137
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 82 deletions.
88 changes: 56 additions & 32 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,46 +162,70 @@ export class ClientApi {

export function getHeaders() {
const accessStore = useAccessStore.getState();
const chatStore = useChatStore.getState();
const headers: Record<string, string> = {
"Content-Type": "application/json",
Accept: "application/json",
};
const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
const isGoogle = modelConfig.providerName == ServiceProvider.Google;
const isAzure = modelConfig.providerName === ServiceProvider.Azure;
const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
const authHeader = isAzure
? "api-key"
: isAnthropic
? "x-api-key"
: "Authorization";
const apiKey = isGoogle
? accessStore.googleApiKey
: isAzure
? accessStore.azureApiKey
: isAnthropic
? accessStore.anthropicApiKey
: accessStore.openaiApiKey;

const clientConfig = getClientConfig();
const makeBearer = (s: string) =>
`${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const validString = (x: string) => x && x.length > 0;

function getConfig() {
const modelConfig = chatStore.currentSession().mask.modelConfig;
const isGoogle = modelConfig.providerName == ServiceProvider.Google;
const isAzure = modelConfig.providerName === ServiceProvider.Azure;
const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
const isEnabledAccessControl = accessStore.enabledAccessControl();
const apiKey = isGoogle
? accessStore.googleApiKey
: isAzure
? accessStore.azureApiKey
: isAnthropic
? accessStore.anthropicApiKey
: accessStore.openaiApiKey;
return { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl };
}

function getAuthHeader(): string {
return isAzure ? "api-key" : isAnthropic ? "x-api-key" : "Authorization";
}

function getBearerToken(apiKey: string, noBearer: boolean = false): string {
return validString(apiKey)
? `${noBearer ? "" : "Bearer "}${apiKey.trim()}`
: "";
}

function validString(x: string): boolean {
return x?.length > 0;
}
const { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl } =
getConfig();
// when using google api in app, not set auth header
if (!(isGoogle && clientConfig?.isApp)) {
// use user's api key first
if (validString(apiKey)) {
headers[authHeader] = makeBearer(apiKey);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
// access_code must send with header named `Authorization`, will using in auth middleware.
headers["Authorization"] = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
if (isGoogle && clientConfig?.isApp) return headers;

const authHeader = getAuthHeader();

const bearerToken = getBearerToken(apiKey, isAzure || isAnthropic);

if (bearerToken) {
headers[authHeader] = bearerToken;
} else if (isEnabledAccessControl && validString(accessStore.accessCode)) {
headers["Authorization"] = getBearerToken(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}

return headers;
}

export function getClientApi(provider: ServiceProvider): ClientApi {
switch (provider) {
case ServiceProvider.Google:
return new ClientApi(ModelProvider.GeminiPro);
case ServiceProvider.Anthropic:
return new ClientApi(ModelProvider.Claude);
default:
return new ClientApi(ModelProvider.GPT);
}
}
19 changes: 3 additions & 16 deletions app/components/exporter.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@ import { toBlob, toPng } from "html-to-image";
import { DEFAULT_MASK_AVATAR } from "../store/mask";

import { prettyObject } from "../utils/format";
import {
EXPORT_MESSAGE_CLASS_NAME,
ModelProvider,
ServiceProvider,
} from "../constant";
import { EXPORT_MESSAGE_CLASS_NAME } from "../constant";
import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api";
import { type ClientApi, getClientApi } from "../client/api";
import { getMessageTextContent } from "../utils";

const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
Expand Down Expand Up @@ -316,16 +312,7 @@ export function PreviewActions(props: {
const onRenderMsgs = (msgs: ChatMessage[]) => {
setShouldExport(false);

var api: ClientApi;
if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else if (config.modelConfig.providerName == ServiceProvider.Baidu) {
api = new ClientApi(ModelProvider.Ernie);
} else {
api = new ClientApi(ModelProvider.GPT);
}
const api: ClientApi = getClientApi(config.modelConfig.providerName);

api
.share(msgs)
Expand Down
16 changes: 4 additions & 12 deletions app/components/home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg";
import { getCSSVar, useMobileScreen } from "../utils";

import dynamic from "next/dynamic";
import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant";
import { Path, SlotID } from "../constant";
import { ErrorBoundary } from "./error";

import { getISOLang, getLang } from "../locales";
Expand All @@ -27,7 +27,7 @@ import { SideBar } from "./sidebar";
import { useAppConfig } from "../store/config";
import { AuthPage } from "./auth";
import { getClientConfig } from "../config/client";
import { ClientApi } from "../client/api";
import { type ClientApi, getClientApi } from "../client/api";
import { useAccessStore } from "../store";

export function Loading(props: { noLogo?: boolean }) {
Expand Down Expand Up @@ -170,16 +170,8 @@ function Screen() {
export function useLoadData() {
const config = useAppConfig();

var api: ClientApi;
if (config.modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else if (config.modelConfig.providerName == ServiceProvider.Baidu) {
api = new ClientApi(ModelProvider.Ernie);
} else {
api = new ClientApi(ModelProvider.GPT);
}
const api: ClientApi = getClientApi(config.modelConfig.providerName);

useEffect(() => {
(async () => {
const models = await api.llm.models();
Expand Down
30 changes: 8 additions & 22 deletions app/store/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ import {
SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
} from "../constant";
import { ClientApi, RequestMessage, MultimodalContent } from "../client/api";
import { getClientApi } from "../client/api";
import type {
ClientApi,
RequestMessage,
MultimodalContent,
} from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";
Expand Down Expand Up @@ -363,17 +368,7 @@ export const useChatStore = createPersistStore(
]);
});

var api: ClientApi;
if (modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else if (modelConfig.providerName == ServiceProvider.Baidu) {
api = new ClientApi(ModelProvider.Ernie);
} else {
api = new ClientApi(ModelProvider.GPT);
}

const api: ClientApi = getClientApi(modelConfig.providerName);
// make request
api.llm.chat({
messages: sendMessages,
Expand Down Expand Up @@ -549,16 +544,7 @@ export const useChatStore = createPersistStore(
const session = get().currentSession();
const modelConfig = session.mask.modelConfig;

var api: ClientApi;
if (modelConfig.providerName == ServiceProvider.Google) {
api = new ClientApi(ModelProvider.GeminiPro);
} else if (modelConfig.providerName == ServiceProvider.Anthropic) {
api = new ClientApi(ModelProvider.Claude);
} else if (modelConfig.providerName == ServiceProvider.Baidu) {
api = new ClientApi(ModelProvider.Ernie);
} else {
api = new ClientApi(ModelProvider.GPT);
}
const api: ClientApi = getClientApi(modelConfig.providerName);

// remove error messages if any
const messages = session.messages;
Expand Down

0 comments on commit 9f7d137

Please sign in to comment.