From 09feeea99a3fcefc5fb92eda75859289d2505f72 Mon Sep 17 00:00:00 2001 From: SPWwj Date: Fri, 12 May 2023 00:20:12 +0800 Subject: [PATCH] Add support for DALL-E #1358 --- app/components/ImageList.module.scss | 16 +++ app/components/ImageList.tsx | 20 ++++ app/components/chat.tsx | 2 + app/components/markdown.tsx | 29 +++++- app/components/markdown.tsx.module.scss | 10 ++ app/constant.ts | 2 + app/icons/image-error.svg | 33 ++++++ app/icons/image-placeholder.svg | 28 +++++ app/locales/cn.ts | 3 +- app/locales/en.ts | 3 +- app/requests.ts | 97 ++++++++++++++++- app/store/chat.ts | 132 +++++++++++++++++------- app/store/config.ts | 2 +- 13 files changed, 332 insertions(+), 45 deletions(-) create mode 100644 app/components/ImageList.module.scss create mode 100644 app/components/ImageList.tsx create mode 100644 app/components/markdown.tsx.module.scss create mode 100644 app/icons/image-error.svg create mode 100644 app/icons/image-placeholder.svg diff --git a/app/components/ImageList.module.scss b/app/components/ImageList.module.scss new file mode 100644 index 00000000000..3b7f1f93396 --- /dev/null +++ b/app/components/ImageList.module.scss @@ -0,0 +1,16 @@ +.imageGrid, +.imageGridSingle { + display: inline-flex; + flex-wrap: wrap; + justify-content: flex-start; +} + +.imageGrid img { + max-width: calc(50% - 10px); + margin: 5px; +} + +.imageGridSingle img { + max-width: 100%; + margin: 5px; +} diff --git a/app/components/ImageList.tsx b/app/components/ImageList.tsx new file mode 100644 index 00000000000..01e30013eb8 --- /dev/null +++ b/app/components/ImageList.tsx @@ -0,0 +1,20 @@ +import { ImagesResponseDataInner } from "openai"; +import React, { FC } from "react"; +import styles from "./ImageList.module.scss"; + +interface ImageListProps { + images?: ImagesResponseDataInner[]; +} +const ImageList: FC = ({ images }) => { + const singleImage = images && images.length === 1; + + return ( +
+ {images && + images.map((image, index) => ( + {`Image + ))} +
+ ); +}; +export default ImageList; diff --git a/app/components/chat.tsx b/app/components/chat.tsx index d38990372be..1b4f52b509e 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -760,6 +760,8 @@ export function Chat() { )} void }) { const ref = useRef(null); @@ -119,6 +125,8 @@ export const MarkdownContent = React.memo(_MarkDownContent); export function Markdown( props: { content: string; + images?: ImagesResponseDataInner[]; + image_alt?: string; loading?: boolean; fontSize?: number; parentRef: RefObject; @@ -171,9 +179,24 @@ export function Markdown( > {inView.current && (props.loading ? ( - +
+ {props.image_alt && } + {props.image_alt && + (props.image_alt === IMAGE_PLACEHOLDER ? ( + + ) : ( + + ))} + +
) : ( - +
+ +
+ {props.images && } + {props.image_alt && } +
+
))} ); diff --git a/app/components/markdown.tsx.module.scss b/app/components/markdown.tsx.module.scss new file mode 100644 index 00000000000..1e19384b121 --- /dev/null +++ b/app/components/markdown.tsx.module.scss @@ -0,0 +1,10 @@ +.loader { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + height: 100%; // Adjust as necessary +} +.content_image { + text-align: center; +} diff --git a/app/constant.ts b/app/constant.ts index d0f9fc743d1..560d57a1b5e 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -40,3 +40,5 @@ export const NARROW_SIDEBAR_WIDTH = 100; export const ACCESS_CODE_PREFIX = "ak-"; export const LAST_INPUT_KEY = "last-input"; +export const IMAGE_PLACEHOLDER = "Loading your image..."; +export const IMAGE_ERROR = "IMAGE_ERROR"; diff --git a/app/icons/image-error.svg b/app/icons/image-error.svg new file mode 100644 index 00000000000..389fc5fad2b --- /dev/null +++ b/app/icons/image-error.svg @@ -0,0 +1,33 @@ + + + + + design-and-ux/error-handling + Created with Sketch. + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/app/icons/image-placeholder.svg b/app/icons/image-placeholder.svg new file mode 100644 index 00000000000..fc5e94fc374 --- /dev/null +++ b/app/icons/image-placeholder.svg @@ -0,0 +1,28 @@ + + + + + + + + + + \ No newline at end of file diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 0cf3b55ec88..9d8f5c50443 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -170,7 +170,8 @@ const cn = { }, Store: { DefaultTopic: "新的聊天", - BotHello: "有什么可以帮你的吗", + BotHello: + "您好!今天我能为您做些什么呢?\n 要生成图片,请使用 `/Image {关键词}。`", Error: "出错了,稍后重试吧", Prompt: { History: (content: string) => diff --git a/app/locales/en.ts b/app/locales/en.ts index 420f169054b..1860cd0777a 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -172,7 +172,8 @@ const en: LocaleType = { }, Store: { DefaultTopic: "New Conversation", - BotHello: "Hello! How can I assist you today?", + BotHello: + "Hello! How can I assist you today?\n To generate images, use `/Image {keyword}.`", Error: "Something went wrong, please try again later.", Prompt: { History: (content: string) => diff --git a/app/requests.ts b/app/requests.ts index d38a91fd4f5..d576de0502b 100644 --- a/app/requests.ts +++ b/app/requests.ts @@ -8,8 +8,14 @@ import { useChatStore, } from "./store"; import { showToast } from "./components/ui-lib"; -import { ACCESS_CODE_PREFIX } from "./constant"; - +import { ACCESS_CODE_PREFIX, IMAGE_ERROR, IMAGE_PLACEHOLDER } from "./constant"; +import { + CreateImageRequest, + CreateImageRequestResponseFormatEnum, + CreateImageRequestSizeEnum, + ImagesResponse, + ImagesResponseDataInner, +} from "openai"; const TIME_OUT_MS = 60000; const makeRequestParam = ( @@ -144,6 +150,93 @@ export async function requestUsage() { subscription: total.hard_limit_usd, }; } +const makeImageRequestParam = ( + prompt: string, + options?: Omit, +): CreateImageRequest => { + // Set default values + const defaultOptions: Omit = { + n: 4, + size: CreateImageRequestSizeEnum._512x512, + response_format: CreateImageRequestResponseFormatEnum.Url, + user: "default_user", + }; + + // Override default values with provided options + const finalOptions = { ...defaultOptions, ...options }; + + const request: CreateImageRequest = { + prompt, + ...finalOptions, + }; + + return request; +}; +export async function requestImage( + keyword: string, + options?: { + onMessage: ( + message: string | null, + image: ImagesResponseDataInner[] | null, + image_alt: string | null, + done: boolean, + ) => void; + onError: (error: Error, statusCode?: number) => void; + onController?: (controller: AbortController) => void; + }, +) { + if (keyword.length < 1) { + options?.onMessage( + "Please enter a keyword after `/image`", + null, + null, + true, + ); + } else { + const controller = new AbortController(); + const reqTimeoutId = setTimeout(() => controller.abort(), TIME_OUT_MS); + options?.onController?.(controller); + + async function fetchImageAndUpdateMessage() { + try { + options?.onMessage(null, null, IMAGE_PLACEHOLDER, false); + + const sanitizedMessage = keyword.replace(/[\n\r]+/g, " "); + const req = makeImageRequestParam(sanitizedMessage); + + const res = await requestOpenaiClient("v1/images/generations")(req); + + clearTimeout(reqTimeoutId); + + const finish = (images: ImagesResponseDataInner[]) => { + options?.onMessage("Here is your images", images, null, true); + controller.abort(); + }; + + if (res.ok) { + const responseData = (await res.json()) as ImagesResponse; + finish(responseData.data); + } else if (res.status === 401) { + console.error("Unauthorized"); + options?.onError(new Error("Unauthorized"), res.status); + } else { + console.error("Stream Error", res.body); + options?.onError(new Error("Stream Error"), res.status); + } + } catch (err) { + console.error("NetWork Error", err); + options?.onError(err as Error); + options?.onMessage( + "Image generation has been cancelled.", + null, + IMAGE_ERROR, + true, + ); + } + } + fetchImageAndUpdateMessage(); + } +} export async function requestChatStream( messages: Message[], diff --git a/app/store/chat.ts b/app/store/chat.ts index cb11087d4ef..bea33d9ae94 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -1,10 +1,14 @@ import { create } from "zustand"; import { persist } from "zustand/middleware"; -import { type ChatCompletionResponseMessage } from "openai"; +import { + ImagesResponseDataInner, + type ChatCompletionResponseMessage, +} from "openai"; import { ControllerPool, requestChatStream, + requestImage, requestWithPrompt, } from "../requests"; import { trimTopic } from "../utils"; @@ -17,6 +21,8 @@ import { StoreKey } from "../constant"; export type Message = ChatCompletionResponseMessage & { date: string; + images?: ImagesResponseDataInner[]; + image_alt?: string; streaming?: boolean; isError?: boolean; id?: number; @@ -275,48 +281,100 @@ export const useChatStore = create()( session.messages.push(botMessage); }); - // make request - console.log("[User Input] ", sendMessages); - requestChatStream(sendMessages, { - onMessage(content, done) { - // stream response - if (done) { + if (userMessage.content.startsWith("/image")) { + const keyword = userMessage.content.substring("/image".length); + console.log("keyword", keyword); + requestImage(keyword, { + onMessage(content, images, image_alt, done) { + // stream response + if (done) { + botMessage.streaming = false; + botMessage.content = content!; + botMessage.images = images!; + botMessage.image_alt = image_alt!; + get().onNewMessage(botMessage); + ControllerPool.remove( + sessionIndex, + botMessage.id ?? messageIndex, + ); + } else { + botMessage.image_alt = image_alt!; + set(() => ({})); + } + }, + onError(error, statusCode) { + const isAborted = error.message.includes("aborted"); + if (statusCode === 401) { + botMessage.content = Locale.Error.Unauthorized; + } else if (!isAborted) { + botMessage.content += "\n\n" + Locale.Store.Error; + } botMessage.streaming = false; - botMessage.content = content; - get().onNewMessage(botMessage); + userMessage.isError = !isAborted; + botMessage.isError = !isAborted; + + set(() => ({})); ControllerPool.remove( sessionIndex, botMessage.id ?? messageIndex, ); - } else { - botMessage.content = content; + }, + onController(controller) { + // collect controller for stop/retry + ControllerPool.addController( + sessionIndex, + botMessage.id ?? messageIndex, + controller, + ); + }, + }); + } else { + // make request + console.log("[User Input] ", sendMessages); + requestChatStream(sendMessages, { + onMessage(content, done) { + // stream response + if (done) { + botMessage.streaming = false; + botMessage.content = content; + get().onNewMessage(botMessage); + ControllerPool.remove( + sessionIndex, + botMessage.id ?? messageIndex, + ); + } else { + botMessage.content = content; + set(() => ({})); + } + }, + onError(error, statusCode) { + const isAborted = error.message.includes("aborted"); + if (statusCode === 401) { + botMessage.content = Locale.Error.Unauthorized; + } else if (!isAborted) { + botMessage.content += "\n\n" + Locale.Store.Error; + } + botMessage.streaming = false; + userMessage.isError = !isAborted; + botMessage.isError = !isAborted; + set(() => ({})); - } - }, - onError(error, statusCode) { - const isAborted = error.message.includes("aborted"); - if (statusCode === 401) { - botMessage.content = Locale.Error.Unauthorized; - } else if (!isAborted) { - botMessage.content += "\n\n" + Locale.Store.Error; - } - botMessage.streaming = false; - userMessage.isError = !isAborted; - botMessage.isError = !isAborted; - - set(() => ({})); - ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex); - }, - onController(controller) { - // collect controller for stop/retry - ControllerPool.addController( - sessionIndex, - botMessage.id ?? messageIndex, - controller, - ); - }, - modelConfig: { ...modelConfig }, - }); + ControllerPool.remove( + sessionIndex, + botMessage.id ?? messageIndex, + ); + }, + onController(controller) { + // collect controller for stop/retry + ControllerPool.addController( + sessionIndex, + botMessage.id ?? messageIndex, + controller, + ); + }, + modelConfig: { ...modelConfig }, + }); + } }, getMemoryPrompt() { diff --git a/app/store/config.ts b/app/store/config.ts index 1e960456ff4..ddc8ef1c7e0 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -17,7 +17,7 @@ export enum Theme { } export const DEFAULT_CONFIG = { - submitKey: SubmitKey.CtrlEnter as SubmitKey, + submitKey: SubmitKey.Enter as SubmitKey, avatar: "1f603", fontSize: 14, theme: Theme.Auto as Theme,