From ac599aa47c49bde2d557a6f7317347c940b29b17 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 2 Aug 2024 18:00:42 +0800 Subject: [PATCH 1/5] add dalle3 model --- app/client/platforms/openai.ts | 90 ++++++++++++++++++++++++---------- app/components/chat.tsx | 34 +++++++++++++ app/constant.ts | 7 ++- app/store/chat.ts | 5 ++ app/utils.ts | 4 ++ 5 files changed, 113 insertions(+), 27 deletions(-) diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 680125fe6c4..28de30051ea 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -33,6 +33,7 @@ import { getMessageTextContent, getMessageImages, isVisionModel, + isDalle3 as _isDalle3, } from "@/app/utils"; export interface OpenAIListModelResponse { @@ -58,6 +59,13 @@ export interface RequestPayload { max_tokens?: number; } +export interface DalleRequestPayload { + model: string; + prompt: string; + n: number; + size: "1024x1024" | "1792x1024" | "1024x1792"; +} + export class ChatGPTApi implements LLMApi { private disableListModels = true; @@ -101,19 +109,25 @@ export class ChatGPTApi implements LLMApi { } extractMessage(res: any) { + if (res.error) { + return "```\n" + JSON.stringify(res, null, 4) + "\n```"; + } + // dalle3 model return url, just return + if (res.data) { + const url = res.data?.at(0)?.url ?? ""; + return [ + { + type: "image_url", + image_url: { + url, + }, + }, + ]; + } return res.choices?.at(0)?.message?.content ?? ""; } async chat(options: ChatOptions) { - const visionModel = isVisionModel(options.config.model); - const messages: ChatOptions["messages"] = []; - for (const v of options.messages) { - const content = visionModel - ? await preProcessImageContent(v.content) - : getMessageTextContent(v); - messages.push({ role: v.role, content }); - } - const modelConfig = { ...useAppConfig.getState().modelConfig, ...useChatStore.getState().currentSession().mask.modelConfig, @@ -123,26 +137,48 @@ export class ChatGPTApi implements LLMApi { }, }; - const requestPayload: RequestPayload = { - messages, - stream: options.config.stream, - model: modelConfig.model, - temperature: modelConfig.temperature, - presence_penalty: modelConfig.presence_penalty, - frequency_penalty: modelConfig.frequency_penalty, - top_p: modelConfig.top_p, - // max_tokens: Math.max(modelConfig.max_tokens, 1024), - // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore. - }; + let requestPayload: RequestPayload | DalleRequestPayload; + + const isDalle3 = _isDalle3(options.config.model); + if (isDalle3) { + const prompt = getMessageTextContent(options.messages.slice(-1)?.pop()); + requestPayload = { + model: options.config.model, + prompt, + n: 1, + size: options.config?.size ?? "1024x1024", + }; + } else { + const visionModel = isVisionModel(options.config.model); + const messages: ChatOptions["messages"] = []; + for (const v of options.messages) { + const content = visionModel + ? await preProcessImageContent(v.content) + : getMessageTextContent(v); + messages.push({ role: v.role, content }); + } - // add max_tokens to vision model - if (visionModel && modelConfig.model.includes("preview")) { - requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000); + requestPayload = { + messages, + stream: options.config.stream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + // max_tokens: Math.max(modelConfig.max_tokens, 1024), + // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore. + }; + + // add max_tokens to vision model + if (visionModel && modelConfig.model.includes("preview")) { + requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000); + } } console.log("[Request] openai payload: ", requestPayload); - const shouldStream = !!options.config.stream; + const shouldStream = !isDalle3 && !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); @@ -168,13 +204,15 @@ export class ChatGPTApi implements LLMApi { model?.provider?.providerName === ServiceProvider.Azure, ); chatPath = this.path( - Azure.ChatPath( + (isDalle3 ? Azure.ImagePath : Azure.ChatPath)( (model?.displayName ?? model?.name) as string, useCustomConfig ? useAccessStore.getState().azureApiVersion : "", ), ); } else { - chatPath = this.path(OpenaiPath.ChatPath); + chatPath = this.path( + isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath, + ); } const chatPayload = { method: "POST", diff --git a/app/components/chat.tsx b/app/components/chat.tsx index bb4b611ad79..b95e85d45df 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg"; import BottomIcon from "../icons/bottom.svg"; import StopIcon from "../icons/pause.svg"; import RobotIcon from "../icons/robot.svg"; +import SizeIcon from "../icons/size.svg"; import PluginIcon from "../icons/plugin.svg"; import { @@ -60,6 +61,7 @@ import { getMessageTextContent, getMessageImages, isVisionModel, + isDalle3, } from "../utils"; import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; @@ -481,6 +483,11 @@ export function ChatActions(props: { const [showPluginSelector, setShowPluginSelector] = useState(false); const [showUploadImage, setShowUploadImage] = useState(false); + const [showSizeSelector, setShowSizeSelector] = useState(false); + const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"]; + const currentSize = + chatStore.currentSession().mask.modelConfig?.size || "1024x1024"; + useEffect(() => { const show = isVisionModel(currentModel); setShowUploadImage(show); @@ -624,6 +631,33 @@ export function ChatActions(props: { /> )} + {isDalle3(currentModel) && ( + setShowSizeSelector(true)} + text={currentSize} + icon={} + /> + )} + + {showSizeSelector && ( + ({ + title: m, + value: m, + }))} + onClose={() => setShowSizeSelector(false)} + onSelection={(s) => { + if (s.length === 0) return; + const size = s[0]; + chatStore.updateCurrentSession((session) => { + session.mask.modelConfig.size = size; + }); + showToast(size); + }} + /> + )} + setShowPluginSelector(true)} text={Locale.Plugin.Name} diff --git a/app/constant.ts b/app/constant.ts index 5251b5b4fc9..b777872c8e0 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -146,6 +146,7 @@ export const Anthropic = { export const OpenaiPath = { ChatPath: "v1/chat/completions", + ImagePath: "v1/images/generations", UsagePath: "dashboard/billing/usage", SubsPath: "dashboard/billing/subscription", ListModelPath: "v1/models", @@ -154,7 +155,10 @@ export const OpenaiPath = { export const Azure = { ChatPath: (deployName: string, apiVersion: string) => `deployments/${deployName}/chat/completions?api-version=${apiVersion}`, - ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}", + // https://.openai.azure.com/openai/deployments//images/generations?api-version= + ImagePath: (deployName: string, apiVersion: string) => + `deployments/${deployName}/images/generations?api-version=${apiVersion}`, + ExampleEndpoint: "https://{resource-url}/openai", }; export const Google = { @@ -256,6 +260,7 @@ const openaiModels = [ "gpt-4-vision-preview", "gpt-4-turbo-2024-04-09", "gpt-4-1106-preview", + "dall-e-3", ]; const googleModels = [ diff --git a/app/store/chat.ts b/app/store/chat.ts index 5892ef0c8c6..7b47f3ec629 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -26,6 +26,7 @@ import { nanoid } from "nanoid"; import { createPersistStore } from "../utils/store"; import { collectModelsWithDefaultModel } from "../utils/model"; import { useAccessStore } from "./access"; +import { isDalle3 } from "../utils"; export type ChatMessage = RequestMessage & { date: string; @@ -541,6 +542,10 @@ export const useChatStore = createPersistStore( const config = useAppConfig.getState(); const session = get().currentSession(); const modelConfig = session.mask.modelConfig; + // skip summarize when using dalle3? + if (isDalle3(modelConfig.model)) { + return; + } const api: ClientApi = getClientApi(modelConfig.providerName); diff --git a/app/utils.ts b/app/utils.ts index 2f2c8ae95ab..a3c329b8239 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -265,3 +265,7 @@ export function isVisionModel(model: string) { visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo ); } + +export function isDalle3(model: string) { + return "dall-e-3" === model; +} From 1c24ca58c784775fb0d2cf9daa07949d329bd36a Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 2 Aug 2024 18:03:19 +0800 Subject: [PATCH 2/5] add dalle3 model --- app/icons/size.svg | 1 + 1 file changed, 1 insertion(+) create mode 100644 app/icons/size.svg diff --git a/app/icons/size.svg b/app/icons/size.svg new file mode 100644 index 00000000000..3da4fadfec6 --- /dev/null +++ b/app/icons/size.svg @@ -0,0 +1 @@ + From 46cb48023e6b2ffa52a44775b58a83a97dcffac2 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 2 Aug 2024 18:50:48 +0800 Subject: [PATCH 3/5] fix typescript error --- app/client/api.ts | 3 ++- app/client/platforms/openai.ts | 7 +++++-- app/components/chat.tsx | 5 +++-- app/store/config.ts | 2 ++ app/typing.ts | 2 ++ 5 files changed, 14 insertions(+), 5 deletions(-) diff --git a/app/client/api.ts b/app/client/api.ts index f10e4761887..88157e79cc7 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -6,7 +6,7 @@ import { ServiceProvider, } from "../constant"; import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; -import { ChatGPTApi } from "./platforms/openai"; +import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai"; import { GeminiProApi } from "./platforms/google"; import { ClaudeApi } from "./platforms/anthropic"; import { ErnieApi } from "./platforms/baidu"; @@ -42,6 +42,7 @@ export interface LLMConfig { stream?: boolean; presence_penalty?: number; frequency_penalty?: number; + size?: DalleRequestPayload["size"]; } export interface ChatOptions { diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 28de30051ea..54309e29f7e 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -13,6 +13,7 @@ import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { collectModelsWithDefaultModel } from "@/app/utils/model"; import { preProcessImageContent } from "@/app/utils/chat"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; +import { DalleSize } from "@/app/typing"; import { ChatOptions, @@ -63,7 +64,7 @@ export interface DalleRequestPayload { model: string; prompt: string; n: number; - size: "1024x1024" | "1792x1024" | "1024x1792"; + size: DalleSize; } export class ChatGPTApi implements LLMApi { @@ -141,7 +142,9 @@ export class ChatGPTApi implements LLMApi { const isDalle3 = _isDalle3(options.config.model); if (isDalle3) { - const prompt = getMessageTextContent(options.messages.slice(-1)?.pop()); + const prompt = getMessageTextContent( + options.messages.slice(-1)?.pop() as any, + ); requestPayload = { model: options.config.model, prompt, diff --git a/app/components/chat.tsx b/app/components/chat.tsx index b95e85d45df..67ea80c4a85 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -69,6 +69,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; import dynamic from "next/dynamic"; import { ChatControllerPool } from "../client/controller"; +import { DalleSize } from "../typing"; import { Prompt, usePromptStore } from "../store/prompt"; import Locale from "../locales"; @@ -484,9 +485,9 @@ export function ChatActions(props: { const [showUploadImage, setShowUploadImage] = useState(false); const [showSizeSelector, setShowSizeSelector] = useState(false); - const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"]; + const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"]; const currentSize = - chatStore.currentSession().mask.modelConfig?.size || "1024x1024"; + chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024"; useEffect(() => { const show = isVisionModel(currentModel); diff --git a/app/store/config.ts b/app/store/config.ts index 1eaafe12b1d..705a9d87c40 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -1,4 +1,5 @@ import { LLMModel } from "../client/api"; +import { DalleSize } from "../typing"; import { getClientConfig } from "../config/client"; import { DEFAULT_INPUT_TEMPLATE, @@ -60,6 +61,7 @@ export const DEFAULT_CONFIG = { compressMessageLengthThreshold: 1000, enableInjectSystemPrompts: true, template: config?.template ?? DEFAULT_INPUT_TEMPLATE, + size: "1024x1024" as DalleSize, }, }; diff --git a/app/typing.ts b/app/typing.ts index b09722ab902..86320358157 100644 --- a/app/typing.ts +++ b/app/typing.ts @@ -7,3 +7,5 @@ export interface RequestMessage { role: MessageRole; content: string; } + +export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792"; From 8c83fe23a1661d37644626e8d71130d96ce413f9 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Fri, 2 Aug 2024 20:58:21 +0800 Subject: [PATCH 4/5] using b64_json for dall-e-3 --- app/client/platforms/openai.ts | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 54309e29f7e..ee9a70913bd 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -11,7 +11,11 @@ import { } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { collectModelsWithDefaultModel } from "@/app/utils/model"; -import { preProcessImageContent } from "@/app/utils/chat"; +import { + preProcessImageContent, + uploadImage, + base64Image2Blob, +} from "@/app/utils/chat"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; import { DalleSize } from "@/app/typing"; @@ -63,6 +67,7 @@ export interface RequestPayload { export interface DalleRequestPayload { model: string; prompt: string; + response_format: "url" | "b64_json"; n: number; size: DalleSize; } @@ -109,13 +114,18 @@ export class ChatGPTApi implements LLMApi { return cloudflareAIGatewayUrl([baseUrl, path].join("/")); } - extractMessage(res: any) { + async extractMessage(res: any) { if (res.error) { return "```\n" + JSON.stringify(res, null, 4) + "\n```"; } - // dalle3 model return url, just return + // dalle3 model return url, using url create image message if (res.data) { - const url = res.data?.at(0)?.url ?? ""; + let url = res.data?.at(0)?.url ?? ""; + const b64_json = res.data?.at(0)?.b64_json ?? ""; + if (!url && b64_json) { + // uploadImage + url = await uploadImage(base64Image2Blob(b64_json, "image/png")); + } return [ { type: "image_url", @@ -148,6 +158,8 @@ export class ChatGPTApi implements LLMApi { requestPayload = { model: options.config.model, prompt, + // URLs are only valid for 60 minutes after the image has been generated. + response_format: "b64_json", // using b64_json, and save image in CacheStorage n: 1, size: options.config?.size ?? "1024x1024", }; @@ -227,7 +239,7 @@ export class ChatGPTApi implements LLMApi { // make a fetch request const requestTimeoutId = setTimeout( () => controller.abort(), - REQUEST_TIMEOUT_MS, + isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow. ); if (shouldStream) { @@ -358,7 +370,7 @@ export class ChatGPTApi implements LLMApi { clearTimeout(requestTimeoutId); const resJson = await res.json(); - const message = this.extractMessage(resJson); + const message = await this.extractMessage(resJson); options.onFinish(message); } } catch (e) { From 4a8e85c28a293c765ce73af6afb34aaa4840290e Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Fri, 2 Aug 2024 22:16:08 +0800 Subject: [PATCH 5/5] fix: empty response --- app/client/platforms/openai.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index ee9a70913bd..8b03d1397e6 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -135,7 +135,7 @@ export class ChatGPTApi implements LLMApi { }, ]; } - return res.choices?.at(0)?.message?.content ?? ""; + return res.choices?.at(0)?.message?.content ?? res; } async chat(options: ChatOptions) {