From aed72672539e9d23ab33ddba2a00509a6009c569 Mon Sep 17 00:00:00 2001 From: jialin Date: Mon, 30 Dec 2024 16:06:45 +0800 Subject: [PATCH] chore: model meta data --- src/components/auto-tooltip/index.tsx | 18 ++- src/components/image-editor/index.tsx | 1 - src/config/global.d.ts | 1 + src/config/route-cachekey.ts | 2 +- src/layouts/index.tsx | 18 ++- src/locales/en-US/playground.ts | 8 +- src/locales/zh-CN/playground.ts | 5 +- .../llmodels/components/advance-config.tsx | 1 - src/pages/llmodels/components/data-form.tsx | 56 ++++--- .../llmodels/components/instance-item.tsx | 1 + .../llmodels/components/update-modal.tsx | 96 +++++++----- src/pages/llmodels/config/index.ts | 7 +- src/pages/llmodels/config/types.ts | 2 +- .../playground/components/dynamic-params.tsx | 30 ++-- .../playground/components/ground-images.tsx | 137 ++++++++++++++---- .../playground/components/image-edit.tsx | 111 ++++++++++++-- .../playground/components/params-settings.tsx | 67 +++++---- src/pages/playground/config/params-config.ts | 49 +++++-- src/pages/playground/images.tsx | 6 +- src/pages/playground/index.tsx | 6 +- 20 files changed, 445 insertions(+), 177 deletions(-) diff --git a/src/components/auto-tooltip/index.tsx b/src/components/auto-tooltip/index.tsx index c6d06637..c72c44ed 100644 --- a/src/components/auto-tooltip/index.tsx +++ b/src/components/auto-tooltip/index.tsx @@ -119,13 +119,17 @@ const AutoTooltip: React.FC = ({ borderRadius: 12 }} closeIcon={ - + tagProps.closable ? ( + + ) : ( + false + ) } > {children} diff --git a/src/components/image-editor/index.tsx b/src/components/image-editor/index.tsx index eb988c57..92da538c 100644 --- a/src/components/image-editor/index.tsx +++ b/src/components/image-editor/index.tsx @@ -199,7 +199,6 @@ const CanvasImageEditor: React.FC = ({ ctx.beginPath(); stroke.forEach((point, i) => { - console.log('Drawing Point:', point); if (i === 0) { ctx.moveTo(point.x, point.y); } else { diff --git a/src/config/global.d.ts b/src/config/global.d.ts index f9d5fef2..14086252 100644 --- a/src/config/global.d.ts +++ b/src/config/global.d.ts @@ -32,6 +32,7 @@ declare namespace Global { label: string; locale?: boolean; value: T; + meta?: Record; } interface HintOptions { diff --git a/src/config/route-cachekey.ts b/src/config/route-cachekey.ts index cb8847e1..6106839e 100644 --- a/src/config/route-cachekey.ts +++ b/src/config/route-cachekey.ts @@ -1,3 +1,3 @@ export default { - playgroundTextToImage: '/playground/text-to-image' + '/playground/text-to-image': '/playground/text-to-image' }; diff --git a/src/layouts/index.tsx b/src/layouts/index.tsx index f07a978e..d5023681 100644 --- a/src/layouts/index.tsx +++ b/src/layouts/index.tsx @@ -1,11 +1,12 @@ // @ts-nocheck -import { routeCacheAtom } from '@/atoms/route-cache'; +import { routeCacheAtom, setRouteCache } from '@/atoms/route-cache'; import { GPUStackVersionAtom, UpdateCheckAtom, userAtom } from '@/atoms/user'; import ShortCuts, { modalConfig as ShortCutsConfig } from '@/components/short-cuts'; import VersionInfo, { modalConfig } from '@/components/version-info'; +import routeCachekey from '@/config/route-cachekey'; import useOverlayScroller from '@/hooks/use-overlay-scroller'; import { logout } from '@/pages/login/apis'; import { useAccessMarkedRoutes } from '@@/plugin-access'; @@ -106,8 +107,6 @@ export default (props: any) => { const [collapsed, setCollapsed] = useState(false); const [collapseValue, setCollapseValue] = useState(false); - console.log('routeCache========', routeCache); - const initialInfo = (useModel && useModel('@@initialState')) || { initialState: undefined, loading: false, @@ -140,10 +139,15 @@ export default (props: any) => { }); }; + const initRouteCacheValue = (pathname) => { + if (routeCache.get(pathname) === undefined && routeCachekey[pathname]) { + setRouteCache(pathname, false); + } + }; + const dropRouteCache = (pathname) => { - console.log('routeCache.keys()========', routeCache.keys()); for (let key of routeCache.keys()) { - if (key !== pathname && !routeCache.get(key)) { + if (key !== pathname && !routeCache.get(key) && routeCachekey[key]) { dropByCacheKey(key); routeCache.delete(key); } @@ -302,6 +306,9 @@ export default (props: any) => { const { location } = history; const { pathname } = location; + initRouteCacheValue(pathname); + dropRouteCache(pathname); + // if user is not change password, redirect to change password page if ( location.pathname !== loginPath && @@ -321,7 +328,6 @@ export default (props: any) => { : '/playground'; history.push(pathname); } - dropRouteCache(pathname); }} formatMessage={formatMessage} menu={{ diff --git a/src/locales/en-US/playground.ts b/src/locales/en-US/playground.ts index 493382da..5b320e04 100644 --- a/src/locales/en-US/playground.ts +++ b/src/locales/en-US/playground.ts @@ -124,5 +124,11 @@ export default { 'playground.params.size.description': 'The maximum size of the generated image is controlled by the deployment parameters of the model. Refer to', 'playground.documents.verify.embedding': 'At least add two pieces of text.', - 'playground.documents.verify.rerank': 'The documents cannot be empty.' + 'playground.documents.verify.rerank': 'The documents cannot be empty.', + 'playground.image.guidance.tip': + 'The lower the value, the higher the diversity, and the lower the adherence to the prompt.', + 'playground.image.cfg_scale.tip': + 'The lower the value, the higher the diversity.', + 'playground.image.strength.tip': + 'The higher the value, the greater the modification to the original image.' }; diff --git a/src/locales/zh-CN/playground.ts b/src/locales/zh-CN/playground.ts index 05619232..9f56e418 100644 --- a/src/locales/zh-CN/playground.ts +++ b/src/locales/zh-CN/playground.ts @@ -121,5 +121,8 @@ export default { 'playground.params.size.description': '图片生成的最大尺寸受控于模型的部署参数。参考文档', 'playground.documents.verify.embedding': '至少输入两条文本', - 'playground.documents.verify.rerank': '文档不能为空' + 'playground.documents.verify.rerank': '文档不能为空', + 'playground.image.guidance.tip': '值越低,多样性越高,对提示词的贴合度越低', + 'playground.image.cfg_scale.tip': '值越低,多样性越高', + 'playground.image.strength.tip': '值越高,它对原图的修改越大' }; diff --git a/src/pages/llmodels/components/advance-config.tsx b/src/pages/llmodels/components/advance-config.tsx index 3c566dd8..3b62057b 100644 --- a/src/pages/llmodels/components/advance-config.tsx +++ b/src/pages/llmodels/components/advance-config.tsx @@ -156,7 +156,6 @@ const AdvanceConfig: React.FC = (props) => { name="categories"> = forwardRef((props, ref) => { const { action, isGGUF, onOk } = props; const [form] = Form.useForm(); @@ -57,28 +82,7 @@ const DataForm: React.FC = forwardRef((props, ref) => { speech2text: false }); - const sourceOptions = [ - { - label: 'Hugging Face', - value: modelSourceMap.huggingface_value, - key: 'huggingface' - }, - { - label: 'Ollama Library', - value: modelSourceMap.ollama_library_value, - key: 'ollama_library' - }, - { - label: 'ModelScope', - value: modelSourceMap.modelscope_value, - key: 'model_scope' - }, - { - label: intl.formatMessage({ id: 'models.form.localPath' }), - value: modelSourceMap.local_path_value, - key: 'local_path' - } - ]; + const localPathCache = useRef(''); const getGPUList = async () => { const data = await queryGPUList(); @@ -186,8 +190,15 @@ const DataForm: React.FC = forwardRef((props, ref) => { } }; + const handleOnFocus = () => { + localPathCache.current = form.getFieldValue('local_path'); + }; + const handleLocalPathBlur = (e: any) => { const value = e.target.value; + if (value === localPathCache.current && value) { + return; + } const isEndwithGGUF = _.endsWith(value, '.gguf'); let backend = backendOptionsMap.llamaBox; if (!isEndwithGGUF) { @@ -344,6 +355,7 @@ const DataForm: React.FC = forwardRef((props, ref) => { > diff --git a/src/pages/llmodels/components/instance-item.tsx b/src/pages/llmodels/components/instance-item.tsx index 851e5a29..41311393 100644 --- a/src/pages/llmodels/components/instance-item.tsx +++ b/src/pages/llmodels/components/instance-item.tsx @@ -69,6 +69,7 @@ const InstanceItem: React.FC = ({ InstanceStatusMap.Initializing, InstanceStatusMap.Running, InstanceStatusMap.Error, + InstanceStatusMap.Starting, InstanceStatusMap.Downloading ], icon: diff --git a/src/pages/llmodels/components/update-modal.tsx b/src/pages/llmodels/components/update-modal.tsx index bbe8558e..b1d44071 100644 --- a/src/pages/llmodels/components/update-modal.tsx +++ b/src/pages/llmodels/components/update-modal.tsx @@ -8,7 +8,14 @@ import { PageActionType } from '@/config/types'; import { useIntl } from '@umijs/max'; import { Form, Modal, Tooltip, Typography } from 'antd'; import _ from 'lodash'; -import React, { memo, useCallback, useEffect, useMemo, useState } from 'react'; +import React, { + memo, + useCallback, + useEffect, + useMemo, + useRef, + useState +} from 'react'; import SimpleBar from 'simplebar-react'; import 'simplebar-react/dist/simplebar.min.css'; import { queryGPUList } from '../apis'; @@ -35,6 +42,29 @@ const SEARCH_SOURCE = [ modelSourceMap.modelscope_value ]; +const sourceOptions = [ + { + label: 'Hugging Face', + value: modelSourceMap.huggingface_value, + key: 'huggingface' + }, + { + label: 'Ollama Library', + value: modelSourceMap.ollama_library_value, + key: 'ollama_library' + }, + { + label: 'ModelScope', + value: modelSourceMap.modelscope_value, + key: 'model_scope' + }, + { + label: 'models.form.localPath', + value: modelSourceMap.local_path_value, + key: 'local_path' + } +]; + const UpdateModal: React.FC = (props) => { const { title, action, open, onOk, onCancel } = props || {}; const [form] = Form.useForm(); @@ -42,6 +72,7 @@ const UpdateModal: React.FC = (props) => { const [gpuOptions, setGpuOptions] = useState([]); const [isGGUF, setIsGGUF] = useState(false); const [loading, setLoading] = useState(false); + const localPathCache = useRef(''); const getGPUList = async () => { const data = await queryGPUList(); @@ -57,29 +88,6 @@ const UpdateModal: React.FC = (props) => { setGpuOptions(list); }; - const sourceOptions = [ - { - label: 'Hugging Face', - value: modelSourceMap.huggingface_value, - key: 'huggingface' - }, - { - label: 'Ollama Library', - value: modelSourceMap.ollama_library_value, - key: 'ollama_library' - }, - { - label: 'ModelScope', - value: modelSourceMap.modelscope_value, - key: 'model_scope' - }, - { - label: intl.formatMessage({ id: 'models.form.localPath' }), - value: modelSourceMap.local_path_value, - key: 'local_path' - } - ]; - useEffect(() => { if (action === PageAction.EDIT && open) { const result = setSourceRepoConfigValue( @@ -106,6 +114,34 @@ const UpdateModal: React.FC = (props) => { setIsGGUF(props.data?.backend === backendOptionsMap.llamaBox); }, [props.data?.backend]); + const handleBackendChange = useCallback((val: string) => { + if (val === backendOptionsMap.llamaBox) { + form.setFieldsValue({ + distributed_inference_across_workers: true, + cpu_offloading: true + }); + } + form.setFieldValue('backend_version', ''); + }, []); + + const handleOnFocus = () => { + localPathCache.current = form.getFieldValue('local_path'); + }; + + const handleLocalPathBlur = (e: any) => { + const value = e.target.value; + if (value === localPathCache.current && value) { + return; + } + const isEndwithGGUF = _.endsWith(value, '.gguf'); + let backend = backendOptionsMap.llamaBox; + if (!isEndwithGGUF) { + backend = backendOptionsMap.vllm; + } + handleBackendChange?.(backend); + form.setFieldValue('backend', backend); + }; + const renderHuggingfaceFields = () => { return ( <> @@ -250,6 +286,8 @@ const UpdateModal: React.FC = (props) => { ]} > = (props) => { form.submit(); }; - const handleBackendChange = useCallback((val: string) => { - if (val === backendOptionsMap.llamaBox) { - form.setFieldsValue({ - distributed_inference_across_workers: true, - cpu_offloading: true - }); - } - form.setFieldValue('backend_version', ''); - }, []); - const handleOk = (formdata: FormData) => { let obj = {}; if (formdata.backend === backendOptionsMap.vllm) { diff --git a/src/pages/llmodels/config/index.ts b/src/pages/llmodels/config/index.ts index 0dad917c..46f6bf3f 100644 --- a/src/pages/llmodels/config/index.ts +++ b/src/pages/llmodels/config/index.ts @@ -161,6 +161,7 @@ export const modelSourceValueMap = { export const InstanceStatusMap = { Initializing: 'initializing', + Starting: 'starting', Pending: 'pending', Running: 'running', Scheduled: 'scheduled', @@ -183,7 +184,8 @@ export const InstanceStatusMapValue = { [InstanceStatusMap.Error]: 'Error', [InstanceStatusMap.Downloading]: 'Downloading', [InstanceStatusMap.Unknown]: 'Unknown', - [InstanceStatusMap.Analyzing]: 'Analyzing' + [InstanceStatusMap.Analyzing]: 'Analyzing', + [InstanceStatusMap.Starting]: 'Starting' }; export const status: any = { @@ -194,7 +196,8 @@ export const status: any = { [InstanceStatusMap.Error]: StatusMaps.error, [InstanceStatusMap.Downloading]: StatusMaps.transitioning, [InstanceStatusMap.Unknown]: StatusMaps.inactive, - [InstanceStatusMap.Analyzing]: StatusMaps.transitioning + [InstanceStatusMap.Analyzing]: StatusMaps.transitioning, + [InstanceStatusMap.Starting]: StatusMaps.transitioning }; export const ActionList = [ diff --git a/src/pages/llmodels/config/types.ts b/src/pages/llmodels/config/types.ts index 31ef9c9e..44a5b519 100644 --- a/src/pages/llmodels/config/types.ts +++ b/src/pages/llmodels/config/types.ts @@ -31,7 +31,7 @@ export interface ListItem { export interface FormData { backend?: string; - categories?: string; + categories?: string[]; backend_parameters?: string[]; backend_version?: string; source: string; diff --git a/src/pages/playground/components/dynamic-params.tsx b/src/pages/playground/components/dynamic-params.tsx index dcab7da2..f32e309a 100644 --- a/src/pages/playground/components/dynamic-params.tsx +++ b/src/pages/playground/components/dynamic-params.tsx @@ -67,25 +67,21 @@ const ParamsSettings: React.FC = forwardRef( })); useEffect(() => { + let model = selectedModel || ''; + if (showModelSelector) { - form.setFieldsValue({ - model: selectedModel || _.get(modelList, '[0].value'), - ...initialValues - }); - setParams({ - model: selectedModel || _.get(modelList, '[0].value'), - ...initialValues - }); - } else { - form.setFieldsValue({ - model: selectedModel || '', - ...initialValues - }); - setParams({ - model: selectedModel || '', - ...initialValues - }); + model = model || _.get(modelList, '[0].value'); } + + form.setFieldsValue({ + model: model, + ...initialValues + }); + setParams({ + model: model, + ...initialValues + }); + onModelChange?.(model); }, [modelList, showModelSelector, selectedModel, initialValues]); const handleModelChange = useCallback( diff --git a/src/pages/playground/components/ground-images.tsx b/src/pages/playground/components/ground-images.tsx index 713b7749..8dbf20da 100644 --- a/src/pages/playground/components/ground-images.tsx +++ b/src/pages/playground/components/ground-images.tsx @@ -13,7 +13,7 @@ import { } from '@/utils/fetch-chunk-data'; import { FileImageOutlined, SwapOutlined } from '@ant-design/icons'; import { useIntl, useSearchParams } from '@umijs/max'; -import { Button, Checkbox, Form, Tooltip } from 'antd'; +import { Button, Form, Tooltip } from 'antd'; import classNames from 'classnames'; import _ from 'lodash'; import 'overlayscrollbars/overlayscrollbars.css'; @@ -32,8 +32,9 @@ import { promptList } from '../config'; import { ImageAdvancedParamsConfig, ImageCustomSizeConfig, + ImageParamsConfig, ImageconstExtraConfig, - ImageParamsConfig as paramsConfig + imageSizeOptions } from '../config/params-config'; import { MessageItem, ParamsSchema } from '../config/types'; import '../style/ground-left.less'; @@ -48,6 +49,17 @@ interface MessageProps { loaded?: boolean; ref?: any; } + +// for advanced fields +const METAKEYS = [ + 'sample_method', + 'sampling_steps', + 'schedule_method', + 'cfg_scale', + 'guidance', + 'negative_prompt' +]; + const advancedFieldsDefaultValus = { seed: null, sample_method: 'euler_a', @@ -55,7 +67,7 @@ const advancedFieldsDefaultValus = { guidance: 3.5, sampling_steps: 10, negative_prompt: null, - schedule_method: 'discrete', + schedule_method: 'default', preview: null }; @@ -101,12 +113,10 @@ const GroundImages: React.FC = forwardRef((props, ref) => { const messageListLengthCache = useRef(0); const requestToken = useRef(null); const [currentPrompt, setCurrentPrompt] = useState(''); + const [modelMeta, setModelMeta] = useState({}); const form = useRef(null); const inputRef = useRef(null); - const previewRef = useRef({ - preview: false, - preview_faster: false - }); + const cacheFormData = useRef>({}); const size = Form.useWatch('size', form.current?.form); @@ -125,10 +135,52 @@ const GroundImages: React.FC = forwardRef((props, ref) => { }; }); + const paramsConfig = useMemo(() => { + const { max_height, max_width } = modelMeta || {}; + if ( + !max_height || + !max_width || + (max_height === 1024 && max_width === 1024) + ) { + return ImageParamsConfig; + } + const newImageSizeOptions = imageSizeOptions.filter((item) => { + return item.width <= max_width && item.height <= max_height; + }); + if ( + !newImageSizeOptions.find( + (item) => item.width === max_width && item.height === max_height + ) + ) { + newImageSizeOptions.push({ + width: max_width, + height: max_height, + label: `${max_width}x${max_height}`, + value: `${max_width}x${max_height}` + }); + } + return ImageParamsConfig.map((item) => { + if (item.name === 'size') { + return { + ...item, + options: newImageSizeOptions + }; + } + return item; + }); + }, [modelMeta]); + const generateNumber = (min: number, max: number) => { return Math.floor(Math.random() * (max - min + 1) + min); }; + const updateCacheFormData = (values: Record) => { + cacheFormData.current = { + ...cacheFormData.current, + ...values + }; + }; + const handleRandomPrompt = useCallback(() => { const randomIndex = generateNumber(0, promptList.length - 1); const randomPrompt = promptList[randomIndex]; @@ -211,8 +263,10 @@ const GroundImages: React.FC = forwardRef((props, ref) => { setMessageId(); setTokenResult(null); setCurrentPrompt(current?.content || ''); - setRouteCache(routeCachekey.playgroundTextToImage, true); - const imgSize = _.split(finalParameters.size, 'x'); + setRouteCache(routeCachekey['/playground/text-to-image'], true); + const imgSize = _.split(finalParameters.size, 'x').map((item: number) => + _.toNumber(item) + ); // preview let stream_options: Record = { @@ -320,7 +374,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { setImageList([]); } finally { setLoading(false); - setRouteCache(routeCachekey.playgroundTextToImage, false); + setRouteCache(routeCachekey['/playground/text-to-image'], false); } }; const handleClear = () => { @@ -345,12 +399,14 @@ const GroundImages: React.FC = forwardRef((props, ref) => { const handleToggleParamsStyle = () => { if (isOpenaiCompatible) { form.current?.form?.setFieldsValue({ - ...advancedFieldsDefaultValus + ...advancedFieldsDefaultValus, + ..._.pick(cacheFormData.current, _.keys(advancedFieldsDefaultValus)) }); setParams((pre: object) => { return { ..._.omit(pre, _.keys(openaiCompatibleFieldsDefaultValus)), - ...advancedFieldsDefaultValus + ...advancedFieldsDefaultValus, + ..._.pick(cacheFormData.current, _.keys(advancedFieldsDefaultValus)) }; }); } else { @@ -365,6 +421,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { }); } setIsOpenaiCompatible(!isOpenaiCompatible); + updateCacheFormData(parameters); }; const renderExtra = useMemo(() => { @@ -428,7 +485,12 @@ const GroundImages: React.FC = forwardRef((props, ref) => { : item.description?.text } onChange={item.name === 'random_seed' ? handleFieldChange : null} - {..._.omit(item, ['name', 'rules', 'disabledConfig'])} + {..._.omit(item, [ + 'name', + 'rules', + 'disabledConfig', + 'description' + ])} > ); @@ -470,34 +532,44 @@ const GroundImages: React.FC = forwardRef((props, ref) => { 'rules', 'disabledConfig' ])} + max={ + item.name === 'height' + ? modelMeta.max_height || item.attrs?.max + : modelMeta.max_width || item.attrs?.max + } > ); }); } return null; - }, [size, intl]); + }, [size, intl, modelMeta]); - const hanldeOnPreview = (e: any) => { - previewRef.current.preview = e.target.checked; - }; + const handleOnModelChange = useCallback( + (val: string) => { + if (!val) return; - const hanldeOnPreviewFaster = (e: any) => { - previewRef.current.preview_faster = e.target.checked; - }; + const model = modelList.find((item) => item.value === val); - const renderPreview = useMemo(() => { - return ( - <> - - Preview - - - Preview Faster - - - ); - }, []); + setModelMeta(model?.meta || {}); + + if (!isOpenaiCompatible) { + setParams((pre: object) => { + return { + ...pre, + ..._.pick(model?.meta, METAKEYS, {}) + }; + }); + form.current?.form?.setFieldsValue({ + ..._.pick(model?.meta, METAKEYS, {}) + }); + } + updateCacheFormData({ + ..._.pick(model?.meta, METAKEYS, {}) + }); + }, + [modelList, isOpenaiCompatible] + ); useEffect(() => { return () => { @@ -672,6 +744,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { } + onModelChange={handleOnModelChange} setParams={setParams} paramsConfig={paramsConfig} initialValues={initialValues} diff --git a/src/pages/playground/components/image-edit.tsx b/src/pages/playground/components/image-edit.tsx index 06f070bd..3d921e8f 100644 --- a/src/pages/playground/components/image-edit.tsx +++ b/src/pages/playground/components/image-edit.tsx @@ -33,8 +33,9 @@ import { EDIT_IMAGE_API } from '../apis'; import { ImageAdvancedParamsConfig, ImageCustomSizeConfig, + ImageParamsConfig, ImageconstExtraConfig, - ImageEidtParamsConfig as paramsConfig + imageSizeOptions } from '../config/params-config'; import { MessageItem, ParamsSchema } from '../config/types'; import '../style/ground-left.less'; @@ -49,6 +50,18 @@ interface MessageProps { loaded?: boolean; ref?: any; } + +// for advanced fields +const METAKEYS = [ + 'sample_method', + 'sampling_steps', + 'schedule_method', + 'cfg_scale', + 'guidance', + 'negative_prompt', + 'strength' +]; + const advancedFieldsDefaultValus = { seed: 1, sample_method: 'euler_a', @@ -58,7 +71,7 @@ const advancedFieldsDefaultValus = { sampling_steps: 10, negative_prompt: null, preview: null, - schedule_method: 'discrete' + schedule_method: 'default' }; const openaiCompatibleFieldsDefaultValus = { @@ -109,6 +122,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { const [image, setImage] = useState(''); const [mask, setMask] = useState(''); const [uploadList, setUploadList] = useState([]); + const [modelMeta, setModelMeta] = useState({}); const [imageStatus, setImageStatus] = useState<{ isOriginal: boolean; isResetNeeded: boolean; @@ -116,7 +130,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { isOriginal: false, isResetNeeded: false }); - + const cacheFormData = useRef({}); const size = Form.useWatch('size', form.current?.form); const { initialize, updateScrollerPosition } = useOverlayScroller(); @@ -134,6 +148,48 @@ const GroundImages: React.FC = forwardRef((props, ref) => { }; }); + const updateCacheFormData = (values: Record) => { + cacheFormData.current = { + ...cacheFormData.current, + ...values + }; + }; + + const paramsConfig = useMemo(() => { + const { max_height, max_width } = modelMeta || {}; + if ( + !max_height || + !max_width || + (max_height === 1024 && max_width === 1024) + ) { + return ImageParamsConfig; + } + const newImageSizeOptions = imageSizeOptions.filter((item) => { + return item.width <= max_width && item.height <= max_height; + }); + if ( + !newImageSizeOptions.find( + (item) => item.width === max_width && item.height === max_height + ) + ) { + newImageSizeOptions.push({ + width: max_width, + height: max_height, + label: `${max_width}x${max_height}`, + value: `${max_width}x${max_height}` + }); + } + return ImageParamsConfig.map((item) => { + if (item.name === 'size') { + return { + ...item, + options: newImageSizeOptions + }; + } + return item; + }); + }, [modelMeta]); + const setImageSize = useCallback(() => { let size: Record = { span: 12 @@ -154,12 +210,10 @@ const GroundImages: React.FC = forwardRef((props, ref) => { }, [parameters.n]); const imageFile = useMemo(() => { - console.log('image:', image); return base64ToFile(image, 'image'); }, [image]); const maskFile = useMemo(() => { - console.log('mask:', mask); return base64ToFile(mask, 'mask'); }, [mask]); @@ -224,7 +278,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { setMessageId(); setTokenResult(null); setCurrentPrompt(current?.content || ''); - setRouteCache(routeCachekey.playgroundTextToImage, true); + setRouteCache(routeCachekey['/playground/text-to-image'], true); const imgSize = _.split(finalParameters.size, 'x').map((item: string) => _.toNumber(item) @@ -284,7 +338,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { const result: any = await fetchChunkedData({ data: params, - // url: `http://192.168.50.174:40935/v1/images/edits?t=${Date.now()}`, + // url: `http:///v1/images/edits?t=${Date.now()}`, url: `${EDIT_IMAGE_API}?t=${Date.now()}`, signal: requestToken.current.signal }); @@ -339,7 +393,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { setImageList([]); } finally { setLoading(false); - setRouteCache(routeCachekey.playgroundTextToImage, false); + setRouteCache(routeCachekey['/playground/text-to-image'], false); } }; const handleClear = () => { @@ -447,7 +501,12 @@ const GroundImages: React.FC = forwardRef((props, ref) => { : item.description?.text || '' } onChange={item.name === 'random_seed' ? handleFieldChange : null} - {..._.omit(item, ['name', 'rules', 'disabledConfig'])} + {..._.omit(item, [ + 'name', + 'rules', + 'disabledConfig', + 'description' + ])} > ); @@ -489,13 +548,44 @@ const GroundImages: React.FC = forwardRef((props, ref) => { 'rules', 'disabledConfig' ])} + max={ + item.name === 'height' + ? modelMeta.max_height || item.attrs?.max + : modelMeta.max_width || item.attrs?.max + } > ); }); } return null; - }, [size, intl]); + }, [size, intl, modelMeta]); + + const handleOnModelChange = useCallback( + (val: string) => { + if (!val) return; + + const model = modelList.find((item) => item.value === val); + + setModelMeta(model?.meta || {}); + + if (!isOpenaiCompatible) { + setParams((pre: object) => { + return { + ...pre, + ..._.pick(model?.meta, METAKEYS, {}) + }; + }); + form.current?.form?.setFieldsValue({ + ..._.pick(model?.meta, METAKEYS, {}) + }); + } + updateCacheFormData({ + ..._.pick(model?.meta, METAKEYS, {}) + }); + }, + [modelList, isOpenaiCompatible] + ); const handleUpdateImageList = useCallback((base64List: any) => { const img = _.get(base64List, '[0].dataUrl', ''); @@ -758,6 +848,7 @@ const GroundImages: React.FC = forwardRef((props, ref) => { } + onModelChange={handleOnModelChange} setParams={setParams} paramsConfig={paramsConfig} initialValues={initialValues} diff --git a/src/pages/playground/components/params-settings.tsx b/src/pages/playground/components/params-settings.tsx index 2408b17e..7ff6f7cf 100644 --- a/src/pages/playground/components/params-settings.tsx +++ b/src/pages/playground/components/params-settings.tsx @@ -29,6 +29,14 @@ type ParamsSettingsProps = { globalParams?: ParamsSettingsFormProps; }; +const METAKEYS: Record = { + seed: 'seed', + stop: 'stop', + temperature: 'temperature', + top_p: 'top_p', + max_tokens: 'n_ctx' +}; + const ParamsSettings: React.FC = ({ selectedModel, setParams, @@ -49,28 +57,6 @@ const ParamsSettings: React.FC = ({ const [form] = Form.useForm(); const formId = useId(); - useEffect(() => { - if (showModelSelector) { - form.setFieldsValue({ - ...initialValues, - model: selectedModel || _.get(modelList, '[0].value') - }); - setParams({ - ...initialValues, - model: selectedModel || _.get(modelList, '[0].value') - }); - } else { - form.setFieldsValue({ - ...initialValues, - model: selectedModel || '' - }); - setParams({ - ...initialValues, - model: selectedModel || '' - }); - } - }, [modelList, showModelSelector, selectedModel]); - const handleOnFinish = (values: any) => { console.log('handleOnFinish', values); }; @@ -108,11 +94,42 @@ const ParamsSettings: React.FC = ({ [form, setParams, onValuesChange] ); - const handleResetParams = () => { - form.setFieldsValue(initialValues); - setParams(initialValues); + const handleModelChange = (val: string) => { + const model = _.find(modelList, { value: val }); + const modelMeta = model?.meta || {}; + const keys = Object.keys(METAKEYS).map((k: string) => { + return METAKEYS[k]; + }); + const modelMetaKeys = _.pick(modelMeta, keys); + const obj = _.reduce( + METAKEYS, + (result: any, value: any, key: string) => { + result[key] = modelMetaKeys[value]; + return result; + }, + {} + ); + return obj; }; + useEffect(() => { + let model = selectedModel || ''; + if (showModelSelector) { + model = model || _.get(modelList, '[0].value'); + } + const modelMetaData = handleModelChange(model); + form.setFieldsValue({ + ...initialValues, + ...modelMetaData, + model: model + }); + setParams({ + ...initialValues, + ...modelMetaData, + model: model + }); + }, [modelList, showModelSelector, selectedModel]); + useEffect(() => { form.setFieldsValue(globalParams); }, [globalParams]); diff --git a/src/pages/playground/config/params-config.ts b/src/pages/playground/config/params-config.ts index 680825f7..1505157a 100644 --- a/src/pages/playground/config/params-config.ts +++ b/src/pages/playground/config/params-config.ts @@ -1,5 +1,25 @@ import { ParamsSchema } from './types'; +export const imageSizeOptions: { + label: string; + value: string; + width: number; + height: number; + locale?: boolean; +}[] = [ + { + label: 'playground.params.custom', + value: 'custom', + locale: true, + width: 0, + height: 0 + }, + { label: '512x512', value: '512x512', width: 512, height: 512 }, + { label: '768x1024', value: '768x1024', width: 768, height: 1024 }, + { label: '1024x768', value: '1024x768', width: 1024, height: 768 }, + { label: '1024x1024', value: '1024x1024', width: 1024, height: 1024 } +]; + export const TTSParamsConfig: ParamsSchema[] = [ { type: 'Select', @@ -104,13 +124,7 @@ export const ImageParamsConfig: ParamsSchema[] = [ { type: 'Select', name: 'size', - options: [ - { label: 'playground.params.custom', value: 'custom', locale: true }, - { label: '512x512', value: '512x512' }, - { label: '768x1024', value: '768x1024' }, - { label: '1024x768', value: '1024x768' }, - { label: '1024x1024', value: '1024x1024' } - ], + options: imageSizeOptions, description: { text: 'playground.params.size.description', html: true, @@ -264,6 +278,7 @@ export const ImageAdvancedParamsConfig: ParamsSchema[] = [ type: 'Select', name: 'schedule_method', options: [ + { label: 'default', value: 'default' }, { label: 'discrete', value: 'discrete' }, { label: 'karras', value: 'karras' }, { label: 'exponential', value: 'exponential' }, @@ -304,6 +319,11 @@ export const ImageAdvancedParamsConfig: ParamsSchema[] = [ text: 'Guidance', isLocalized: false }, + description: { + text: 'playground.image.guidance.tip', + html: false, + isLocalized: true + }, attrs: { min: 1.0, max: 10, @@ -322,11 +342,11 @@ export const ImageAdvancedParamsConfig: ParamsSchema[] = [ text: 'Strength', isLocalized: false }, - // description: { - // text: '值越高,它对原图的修改越大,更多变化', - // html: false, - // isLocalized: false - // }, + description: { + text: 'playground.image.strength.tip', + html: false, + isLocalized: true + }, attrs: { min: 0, max: 1, @@ -345,6 +365,11 @@ export const ImageAdvancedParamsConfig: ParamsSchema[] = [ text: 'CFG Scale', isLocalized: false }, + description: { + text: 'playground.image.cfg_scale.tip', + html: false, + isLocalized: true + }, attrs: { min: 1.0, max: 10, diff --git a/src/pages/playground/images.tsx b/src/pages/playground/images.tsx index 101ba617..9927370e 100644 --- a/src/pages/playground/images.tsx +++ b/src/pages/playground/images.tsx @@ -85,13 +85,15 @@ const TextToImages: React.FC = () => { const getModelList = async () => { try { const params = { - categories: 'image' + categories: 'image', + with_meta: true }; const res = await queryModelsList(params); const list = _.map(res.data || [], (item: any) => { return { value: item.id, - label: item.id + label: item.id, + meta: item.meta }; }) as Global.BaseOption[]; return list; diff --git a/src/pages/playground/index.tsx b/src/pages/playground/index.tsx index 6911de86..227e8a20 100644 --- a/src/pages/playground/index.tsx +++ b/src/pages/playground/index.tsx @@ -79,13 +79,15 @@ const Playground: React.FC = () => { const getModelList = async () => { try { const params = { - categories: '' + categories: '', + with_meta: true }; const res = await queryModelsList(params); const list = _.map(res.data || [], (item: any) => { return { value: item.id, - label: item.id + label: item.id, + meta: item.meta }; }) as Global.BaseOption[]; return list;