Skip to content
This repository has been archived by the owner on Sep 15, 2024. It is now read-only.

Commit

Permalink
Add vision support (ChatGPTNextWeb#4076)
Browse files Browse the repository at this point in the history
  • Loading branch information
TheRamU authored and H0llyW00dzZ committed Feb 20, 2024
1 parent fc51375 commit 4424cac
Show file tree
Hide file tree
Showing 16 changed files with 635 additions and 106 deletions.
10 changes: 9 additions & 1 deletion app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@ export type MessageRole = (typeof ROLES)[number];
export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
export type ChatModel = ModelType;

export interface MultimodalContent {
type: "text" | "image_url";
text?: string;
image_url?: {
url: string;
};
}

export interface RequestMessage {
role: MessageRole;
content: string;
content: string | MultimodalContent[];
}

export interface LLMConfig {
Expand Down
76 changes: 60 additions & 16 deletions app/client/platforms/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { getClientConfig } from "@/app/config/client";
import Locale from "../../locales";
import { getServerSideConfig } from "@/app/config/server";
import { getProviderFromState } from "@/app/utils";
import {
getProviderFromState,
getMessageTextContent,
getMessageImages,
isVisionModel,
} from "@/app/utils";
import { getNewStuff } from './NewStuffLLMs';


Expand Down Expand Up @@ -104,10 +109,33 @@ export class GeminiProApi implements LLMApi {
const provider = getProviderFromState();
const cfgspeed_animation = useAppConfig.getState().speed_animation; // Get the animation speed from the app config
// const apiClient = this;
const messages: Message[] = options.messages.map((v) => ({
role: v.role.replace("assistant", "model").replace("system", "user"),
parts: [{ text: v.content }],
}));
const visionModel = isVisionModel(options.config.model);
let multimodal = false;
const messages = options.messages.map((v) => {
let parts: any[] = [{ text: getMessageTextContent(v) }];
if (visionModel) {
const images = getMessageImages(v);
if (images.length > 0) {
multimodal = true;
parts = parts.concat(
images.map((image) => {
const imageType = image.split(";")[0].split(":")[1];
const imageData = image.split(",")[1];
return {
inline_data: {
mime_type: imageType,
data: imageData,
},
};
}),
);
}
}
return {
role: v.role.replace("assistant", "model").replace("system", "user"),
parts: parts,
};
});

// google requires that role in neighboring messages must not be the same
for (let i = 0; i < messages.length - 1;) {
Expand All @@ -118,8 +146,6 @@ export class GeminiProApi implements LLMApi {
i++;
}
}

const appConfig = useAppConfig.getState().modelConfig;
const chatConfig = useChatStore.getState().currentSession().mask.modelConfig;

// Call getNewStuff to determine the max_tokens and other configurations
Expand All @@ -130,11 +156,15 @@ export class GeminiProApi implements LLMApi {
chatConfig.useMaxTokens,
);

const modelConfig: ModelConfig = {
...appConfig,
...chatConfig,
// Use max_tokens from getNewStuff if defined, otherwise use the existing value
max_tokens: max_tokens !== undefined ? max_tokens : options.config.max_tokens,
// if (visionModel && messages.length > 1) {
// options.onError?.(new Error("Multiturn chat is not enabled for models/gemini-pro-vision"));
// }
const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.config.model,
},
};

const requestPayload = {
Expand Down Expand Up @@ -177,15 +207,16 @@ export class GeminiProApi implements LLMApi {
const controller = new AbortController();
options.onController?.(controller);
try {
// Note: With this refactoring, it's now possible to use `v1`, `v1beta` in the settings.
// However, this is just temporary and might need to be changed in the future.
let chatPath = this.path(accessStore.googleApiVersion + Google.ChatPath);
let googleChatPath = visionModel
? Google.VisionChatPath
: Google.ChatPath;
let chatPath = this.path(googleChatPath);

// let baseUrl = accessStore.googleUrl;

if (!baseUrl) {
baseUrl = isApp
? DEFAULT_API_HOST + "/api/proxy/google/" + accessStore.googleApiVersion + Google.ChatPath
? DEFAULT_API_HOST + "/api/proxy/google/" + googleChatPath
: chatPath;
}

Expand Down Expand Up @@ -252,6 +283,19 @@ export class GeminiProApi implements LLMApi {
value,
}): Promise<any> {
if (done) {
if (response.status !== 200) {
try {
let data = JSON.parse(ensureProperEnding(partialData));
if (data && data[0].error) {
options.onError?.(new Error(data[0].error.message));
} else {
options.onError?.(new Error("Request failed"));
}
} catch (_) {
options.onError?.(new Error("Request failed"));
}
}

console.log("[Streaming] Stream complete");
// options.onFinish(responseText + remainText);
finished = true;
Expand Down
41 changes: 15 additions & 26 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@ import {
} from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";

import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
import {
ChatOptions,
getHeaders,
LLMApi,
LLMModel,
LLMUsage,
MultimodalContent,
} from "../api";
import Locale from "../../locales";
import {
EventStreamContentType,
Expand All @@ -24,6 +31,11 @@ import { prettyObject } from "@/app/utils/format";
import { getClientConfig } from "@/app/config/client";
import { getProviderFromState } from "@/app/utils";
import { makeAzurePath } from "@/app/azure";
import {
getMessageTextContent,
getMessageImages,
isVisionModel,
} from "@/app/utils";

export interface OpenAIListModelResponse {
object: string;
Expand Down Expand Up @@ -88,33 +100,10 @@ export class ChatGPTApi implements LLMApi {
*
*/
async chat(options: ChatOptions) {
/**
* The text moderation configuration.
* @remarks
* This variable stores the text moderation settings obtained from the app configuration.
* @author H0llyW00dzZ
*/
const textmoderation = useAppConfig.getState().textmoderation;
const checkprovider = getProviderFromState();
const userMessageS = options.messages.filter((msg) => msg.role === "user");
const lastUserMessage = userMessageS[userMessageS.length - 1]?.content;
const moderationPath = this.path(OpenaiPath.ModerationPath);
// Check if text moderation is enabled and required
if (textmoderation !== false
&& options.whitelist !== true
// Skip text moderation for Azure provider since azure already have text-moderation, and its enabled by default on their service
&& checkprovider !== ServiceProvider.Azure) {
// Call the moderateText method and handle the result
const moderationResult = await moderateText(moderationPath, lastUserMessage, OpenaiPath.TextModerationModels.latest);
if (moderationResult) {
options.onFinish(moderationResult); // Finish early if moderationResult is not null
return;
}
}

const visionModel = isVisionModel(options.config.model);
const messages = options.messages.map((v) => ({
role: v.role,
content: v.content,
content: visionModel ? v.content : getMessageTextContent(v),
}));

const modelConfig = {
Expand Down
Loading

0 comments on commit 4424cac

Please sign in to comment.