Skip to content

Commit

Permalink
Merge pull request #4983 from OpenAI-Next/dev-sd
Browse files Browse the repository at this point in the history
[Feature] Stable Diffusion
  • Loading branch information
lloydzhou authored Jul 25, 2024
2 parents cf63619 + 6cc0a5a commit 941a03e
Show file tree
Hide file tree
Showing 35 changed files with 1,714 additions and 140 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ dev

*.key
*.key.pub

masks.json
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,14 @@ You can use this option if you want to increase the number of webdav service add

Customize the default template used to initialize the User Input Preprocessing configuration item in Settings.

### `STABILITY_API_KEY` (optional)

Stability API key.

### `STABILITY_URL` (optional)

Customize Stability API url.

## Requirements

NodeJS >= 18, Docker >= 20
Expand Down
9 changes: 9 additions & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ ByteDance Api Url.

自定义默认的 template,用于初始化『设置』中的『用户输入预处理』配置项

### `STABILITY_API_KEY` (optional)

Stability API密钥

### `STABILITY_URL` (optional)

自定义的Stability API请求地址


## 开发

点击下方按钮,开始二次开发:
Expand Down
3 changes: 3 additions & 0 deletions app/api/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
let systemApiKey: string | undefined;

switch (modelProvider) {
case ModelProvider.Stability:
systemApiKey = serverConfig.stabilityApiKey;
break;
case ModelProvider.GeminiPro:
systemApiKey = serverConfig.googleApiKey;
break;
Expand Down
104 changes: 104 additions & 0 deletions app/api/stability/[...path]/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "@/app/config/server";
import { ModelProvider, STABILITY_BASE_URL } from "@/app/constant";
import { auth } from "@/app/api/auth";

async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
console.log("[Stability] params ", params);

if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}

const controller = new AbortController();

const serverConfig = getServerSideConfig();

let baseUrl = serverConfig.stabilityUrl || STABILITY_BASE_URL;

if (!baseUrl.startsWith("http")) {
baseUrl = `https://${baseUrl}`;
}

if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.slice(0, -1);
}

let path = `${req.nextUrl.pathname}`.replaceAll("/api/stability/", "");

console.log("[Stability Proxy] ", path);
console.log("[Stability Base Url]", baseUrl);

const timeoutId = setTimeout(
() => {
controller.abort();
},
10 * 60 * 1000,
);

const authResult = auth(req, ModelProvider.Stability);

if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}

const bearToken = req.headers.get("Authorization") ?? "";
const token = bearToken.trim().replaceAll("Bearer ", "").trim();

const key = token ? token : serverConfig.stabilityApiKey;

if (!key) {
return NextResponse.json(
{
error: true,
message: `missing STABILITY_API_KEY in server env vars`,
},
{
status: 401,
},
);
}

const fetchUrl = `${baseUrl}/${path}`;
console.log("[Stability Url] ", fetchUrl);
const fetchOptions: RequestInit = {
headers: {
"Content-Type": req.headers.get("Content-Type") || "multipart/form-data",
Accept: req.headers.get("Accept") || "application/json",
Authorization: `Bearer ${key}`,
},
method: req.method,
body: req.body,
// to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body
redirect: "manual",
// @ts-ignore
duplex: "half",
signal: controller.signal,
};

try {
const res = await fetch(fetchUrl, fetchOptions);
// to prevent browser prompt for credentials
const newHeaders = new Headers(res.headers);
newHeaders.delete("www-authenticate");
// to disable nginx buffering
newHeaders.set("X-Accel-Buffering", "no");
return new Response(res.body, {
status: res.status,
statusText: res.statusText,
headers: newHeaders,
});
} finally {
clearTimeout(timeoutId);
}
}

export const GET = handle;
export const POST = handle;

export const runtime = "edge";
8 changes: 6 additions & 2 deletions app/api/webdav/[...path]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ async function handle(
const normalizedAllowedEndpoint = normalizeUrl(allowedEndpoint);
const normalizedEndpoint = normalizeUrl(endpoint as string);

return normalizedEndpoint &&
return (
normalizedEndpoint &&
normalizedEndpoint.hostname === normalizedAllowedEndpoint?.hostname &&
normalizedEndpoint.pathname.startsWith(normalizedAllowedEndpoint.pathname);
normalizedEndpoint.pathname.startsWith(
normalizedAllowedEndpoint.pathname,
)
);
})
) {
return NextResponse.json(
Expand Down
22 changes: 13 additions & 9 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@ export class ClientApi {
}
}

export function getBearerToken(
apiKey: string,
noBearer: boolean = false,
): string {
return validString(apiKey)
? `${noBearer ? "" : "Bearer "}${apiKey.trim()}`
: "";
}

export function validString(x: string): boolean {
return x?.length > 0;
}

export function getHeaders() {
const accessStore = useAccessStore.getState();
const chatStore = useChatStore.getState();
Expand Down Expand Up @@ -214,15 +227,6 @@ export function getHeaders() {
return isAzure ? "api-key" : isAnthropic ? "x-api-key" : "Authorization";
}

function getBearerToken(apiKey: string, noBearer: boolean = false): string {
return validString(apiKey)
? `${noBearer ? "" : "Bearer "}${apiKey.trim()}`
: "";
}

function validString(x: string): boolean {
return x?.length > 0;
}
const {
isGoogle,
isAzure,
Expand Down
3 changes: 3 additions & 0 deletions app/components/button.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import * as React from "react";

import styles from "./button.module.scss";
import { CSSProperties } from "react";

export type ButtonType = "primary" | "danger" | null;

Expand All @@ -16,6 +17,7 @@ export function IconButton(props: {
disabled?: boolean;
tabIndex?: number;
autoFocus?: boolean;
style?: CSSProperties;
}) {
return (
<button
Expand All @@ -31,6 +33,7 @@ export function IconButton(props: {
role="button"
tabIndex={props.tabIndex}
autoFocus={props.autoFocus}
style={props.style}
>
{props.icon && (
<div
Expand Down
9 changes: 8 additions & 1 deletion app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
import BottomIcon from "../icons/bottom.svg";
import StopIcon from "../icons/pause.svg";
import RobotIcon from "../icons/robot.svg";
import PluginIcon from "../icons/plugin.svg";

import {
ChatMessage,
Expand Down Expand Up @@ -338,7 +339,7 @@ function ClearContextDivider() {
);
}

function ChatAction(props: {
export function ChatAction(props: {
text: string;
icon: JSX.Element;
onClick: () => void;
Expand Down Expand Up @@ -587,6 +588,12 @@ export function ChatActions(props: {
icon={<RobotIcon />}
/>

<ChatAction
onClick={() => showToast(Locale.WIP)}
text={Locale.Plugin.Name}
icon={<PluginIcon />}
/>

{showModelSelector && (
<Selector
defaultSelectedValue={`${currentModel}@${currentProviderName}`}
Expand Down
2 changes: 2 additions & 0 deletions app/components/error.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"use client";

import React from "react";
import { IconButton } from "./button";
import GithubIcon from "../icons/github.svg";
Expand Down
64 changes: 39 additions & 25 deletions app/components/home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ const MaskPage = dynamic(async () => (await import("./mask")).MaskPage, {
loading: () => <Loading noLogo />,
});

const Sd = dynamic(async () => (await import("./sd")).Sd, {
loading: () => <Loading noLogo />,
});

export function useSwitchTheme() {
const config = useAppConfig();

Expand Down Expand Up @@ -122,11 +126,22 @@ const loadAsyncGoogleFont = () => {
document.head.appendChild(linkEl);
};

export function WindowContent(props: { children: React.ReactNode }) {
return (
<div className={styles["window-content"]} id={SlotID.AppBody}>
{props?.children}
</div>
);
}

function Screen() {
const config = useAppConfig();
const location = useLocation();
const isHome = location.pathname === Path.Home;
const isAuth = location.pathname === Path.Auth;
const isSd = location.pathname === Path.Sd;
const isSdNew = location.pathname === Path.SdNew;

const isMobileScreen = useMobileScreen();
const shouldTightBorder =
getClientConfig()?.isApp || (config.tightBorder && !isMobileScreen);
Expand All @@ -135,34 +150,33 @@ function Screen() {
loadAsyncGoogleFont();
}, []);

const renderContent = () => {
if (isAuth) return <AuthPage />;
if (isSd) return <Sd />;
if (isSdNew) return <Sd />;
return (
<>
<SideBar className={isHome ? styles["sidebar-show"] : ""} />
<WindowContent>
<Routes>
<Route path={Path.Home} element={<Chat />} />
<Route path={Path.NewChat} element={<NewChat />} />
<Route path={Path.Masks} element={<MaskPage />} />
<Route path={Path.Chat} element={<Chat />} />
<Route path={Path.Settings} element={<Settings />} />
</Routes>
</WindowContent>
</>
);
};

return (
<div
className={
styles.container +
` ${shouldTightBorder ? styles["tight-container"] : styles.container} ${
getLang() === "ar" ? styles["rtl-screen"] : ""
}`
}
className={`${styles.container} ${
shouldTightBorder ? styles["tight-container"] : styles.container
} ${getLang() === "ar" ? styles["rtl-screen"] : ""}`}
>
{isAuth ? (
<>
<AuthPage />
</>
) : (
<>
<SideBar className={isHome ? styles["sidebar-show"] : ""} />

<div className={styles["window-content"]} id={SlotID.AppBody}>
<Routes>
<Route path={Path.Home} element={<Chat />} />
<Route path={Path.NewChat} element={<NewChat />} />
<Route path={Path.Masks} element={<MaskPage />} />
<Route path={Path.Chat} element={<Chat />} />
<Route path={Path.Settings} element={<Settings />} />
</Routes>
</div>
</>
)}
{renderContent()}
</div>
);
}
Expand Down
2 changes: 2 additions & 0 deletions app/components/sd/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export * from "./sd";
export * from "./sd-panel";
Loading

0 comments on commit 941a03e

Please sign in to comment.