diff --git a/src/components/seal-form/components/wrapper.less b/src/components/seal-form/components/wrapper.less index 3c8dc148..53a986d5 100644 --- a/src/components/seal-form/components/wrapper.less +++ b/src/components/seal-form/components/wrapper.less @@ -163,6 +163,7 @@ } :global(.ant-select-selector) { + flex: 1; border: none !important; padding-block: 5px; padding-inline: @input-inner-padding !important; diff --git a/src/locales/en-US/common.ts b/src/locales/en-US/common.ts index edcbf12d..5213c45a 100644 --- a/src/locales/en-US/common.ts +++ b/src/locales/en-US/common.ts @@ -219,5 +219,6 @@ export default { 'common.button.recreate': 'Recreate', 'common.button.delrecreate': 'Delete (Recreate)', 'common.options.all': 'All', - 'common.options.none': 'None' + 'common.options.none': 'None', + 'common.options.auto': 'Auto' }; diff --git a/src/locales/en-US/models.ts b/src/locales/en-US/models.ts index 71b29e93..97f253db 100644 --- a/src/locales/en-US/models.ts +++ b/src/locales/en-US/models.ts @@ -44,6 +44,7 @@ export default { 'models.search.unsupport': 'This model is not supported and may be unusable after deployment.', 'models.form.scheduletype': 'Schedule Type', + 'models.form.categories': 'Model Category', 'models.form.scheduletype.auto': 'Auto', 'models.form.scheduletype.manual': 'Manual', 'models.form.scheduletype.auto.tips': diff --git a/src/locales/zh-CN/common.ts b/src/locales/zh-CN/common.ts index 585b8259..c3c969d2 100644 --- a/src/locales/zh-CN/common.ts +++ b/src/locales/zh-CN/common.ts @@ -212,5 +212,6 @@ export default { 'common.button.recreate': '重新创建', 'common.button.delrecreate': '删除(重建)', 'common.options.all': '全部', - 'common.options.none': '无' + 'common.options.none': '无', + 'common.options.auto': '自动' }; diff --git a/src/locales/zh-CN/models.ts b/src/locales/zh-CN/models.ts index e32bf7cd..c9861c1d 100644 --- a/src/locales/zh-CN/models.ts +++ b/src/locales/zh-CN/models.ts @@ -42,6 +42,7 @@ export default { 'models.search.networkerror': '网络连接异常!', 'models.search.hfvisit': '请确保您可以访问', 'models.search.unsupport': '暂不支持该模型,部署后可能无法使用', + 'models.form.categories': '模型类别', 'models.form.scheduletype': '调度方式', 'models.form.scheduletype.auto': '自动', 'models.form.scheduletype.manual': '手动', diff --git a/src/pages/llmodels/components/advance-config.tsx b/src/pages/llmodels/components/advance-config.tsx index 147e1ba7..bc77c3c9 100644 --- a/src/pages/llmodels/components/advance-config.tsx +++ b/src/pages/llmodels/components/advance-config.tsx @@ -20,6 +20,7 @@ import React, { useCallback, useMemo } from 'react'; import { backendOptionsMap, backendParamsHolderTips, + modelCategories, placementStrategyOptions } from '../config'; import llamaConfig from '../config/llama-config'; @@ -153,21 +154,16 @@ const AdvanceConfig: React.FC = (props) => { const collapseItems = useMemo(() => { const children = ( <> - {/* name="labels"> - - */} + name="categories"> + + = (props) => { )} {scheduleType === 'manual' && ( - name="gpu_selector" + name={['gpu_selector', 'gpu_ids']} rules={[ { required: true, @@ -265,6 +261,8 @@ const AdvanceConfig: React.FC = (props) => { {gpuOptions.map((item) => ( diff --git a/src/pages/llmodels/components/data-form.tsx b/src/pages/llmodels/components/data-form.tsx index 4b9fdfa1..516c7ea9 100644 --- a/src/pages/llmodels/components/data-form.tsx +++ b/src/pages/llmodels/components/data-form.tsx @@ -86,7 +86,7 @@ const DataForm: React.FC = forwardRef((props, ref) => { return { ...item, label: item.name, - value: `${item.worker_name}-${item.name}-${item.index}` + value: item.id }; }); console.log('queryGPUList========', list); @@ -381,23 +381,15 @@ const DataForm: React.FC = forwardRef((props, ref) => { }, []); const handleOk = (formdata: FormData) => { - const gpu = _.find(gpuOptions, (item: any) => { - return item.value === formdata.gpu_selector; - }); - if (gpu) { - onOk({ - ..._.omit(formdata, ['scheduleType']), - gpu_selector: { - gpu_name: gpu.name, - gpu_index: gpu.index, - worker_name: gpu.worker_name - } - }); + let data = _.cloneDeep(formdata); + if (data.categories) { + data.categories = [data.categories]; } else { - onOk({ - ..._.omit(formdata, ['scheduleType']) - }); + data.categories = []; } + onOk({ + ..._.omit(data, ['scheduleType']) + }); }; useEffect(() => { @@ -430,6 +422,7 @@ const DataForm: React.FC = forwardRef((props, ref) => { placement_strategy: 'spread', cpu_offloading: true, scheduleType: 'auto', + categories: null, distributed_inference_across_workers: true }} > diff --git a/src/pages/llmodels/components/table-list.tsx b/src/pages/llmodels/components/table-list.tsx index 3f27bbe1..d164c4f8 100644 --- a/src/pages/llmodels/components/table-list.tsx +++ b/src/pages/llmodels/components/table-list.tsx @@ -44,6 +44,7 @@ import { import { InstanceRealLogStatus, getSourceRepoConfigValue, + modelCategoriesMap, modelSourceMap } from '../config'; import { FormData, ListItem, ModelInstanceListItem } from '../config/types'; @@ -556,7 +557,7 @@ const Models: React.FC = ({ const renderModelTags = useCallback( (record: ListItem) => { - if (record.reranker) { + if (record.categories?.includes(modelCategoriesMap.reranker)) { return ( } @@ -574,7 +575,7 @@ const Models: React.FC = ({ ); } - if (record.embedding_only && !record.reranker) { + if (record.categories?.includes(modelCategoriesMap.embedding)) { return ( } @@ -591,7 +592,7 @@ const Models: React.FC = ({ ); } - if (record.text_to_speech) { + if (record.categories?.includes(modelCategoriesMap.text_to_speech)) { return ( } @@ -608,7 +609,7 @@ const Models: React.FC = ({ ); } - if (record.speech_to_text) { + if (record.categories?.includes(modelCategoriesMap.speech_to_text)) { return ( } @@ -625,7 +626,7 @@ const Models: React.FC = ({ ); } - if (record.image_only) { + if (record.categories?.includes(modelCategoriesMap.image)) { return ( } diff --git a/src/pages/llmodels/components/update-modal.tsx b/src/pages/llmodels/components/update-modal.tsx index 68bb7249..bdf9a004 100644 --- a/src/pages/llmodels/components/update-modal.tsx +++ b/src/pages/llmodels/components/update-modal.tsx @@ -88,6 +88,9 @@ const UpdateModal: React.FC = (props) => { const formData = { ...result.values, ..._.omit(props.data, result.omits), + categories: props.data?.categories?.length + ? props.data.categories[0] + : null, scheduleType: props.data?.gpu_selector ? 'manual' : 'auto', gpu_selector: props.data?.gpu_selector ? `${props.data?.gpu_selector.worker_name}-${props.data?.gpu_selector.gpu_name}-${props.data?.gpu_selector.gpu_index}` @@ -303,6 +306,7 @@ const UpdateModal: React.FC = (props) => { onOk({ ..._.omit(formdata, ['scheduleType']), + categories: formdata.categories ? [formdata.categories] : [], worker_selector: null, gpu_selector: gpu ? { @@ -316,6 +320,7 @@ const UpdateModal: React.FC = (props) => { } else { onOk({ ..._.omit(formdata, ['scheduleType']), + categories: formdata.categories ? [formdata.categories] : [], gpu_selector: null, ...obj }); diff --git a/src/pages/llmodels/config/index.ts b/src/pages/llmodels/config/index.ts index 46c8c444..0dad917c 100644 --- a/src/pages/llmodels/config/index.ts +++ b/src/pages/llmodels/config/index.ts @@ -240,6 +240,23 @@ export const placementStrategyOptions = [ } ]; +export const modelCategoriesMap = { + image: 'image', + text_to_speech: 'text_to_speech', + speech_to_text: 'speech_to_text', + embedding: 'embedding', + reranker: 'reranker' +}; + +export const modelCategories = [ + { label: 'common.options.auto', value: null, locale: true }, + { label: 'Image', value: 'image' }, + { label: 'Text-to-speech', value: 'text_to_speech' }, + { label: 'Speech-to-text', value: 'speech_to_text' }, + { label: 'Embedding', value: 'embedding' }, + { label: 'Reranker', value: 'reranker' } +]; + export const sourceRepoConfig = { [modelSourceMap.huggingface_value]: { repo_id: 'huggingface_repo_id', diff --git a/src/pages/llmodels/config/types.ts b/src/pages/llmodels/config/types.ts index 46a74f20..31ef9c9e 100644 --- a/src/pages/llmodels/config/types.ts +++ b/src/pages/llmodels/config/types.ts @@ -1,6 +1,7 @@ export interface ListItem { source: string; backend: string; + categories?: string[]; reranker: boolean; image_only?: boolean; huggingface_repo_id: string; @@ -23,15 +24,14 @@ export interface ListItem { created_at: string; updated_at: string; gpu_selector?: { - worker_name: string; - gpu_index: number; - gpu_name: string; + gpu_ids: string[]; }; worker_selector?: object; } export interface FormData { backend?: string; + categories?: string; backend_parameters?: string[]; backend_version?: string; source: string; @@ -46,9 +46,7 @@ export interface FormData { model_scope_model_id?: string; model_scope_file_path?: string; gpu_selector?: { - worker_name: string; - gpu_index: number; - gpu_name: string; + gpu_ids: string[]; }; placement_strategy?: string; cpu_offloading?: boolean; diff --git a/src/pages/playground/embedding.tsx b/src/pages/playground/embedding.tsx index 4f1c1116..124388b9 100644 --- a/src/pages/playground/embedding.tsx +++ b/src/pages/playground/embedding.tsx @@ -30,8 +30,7 @@ const PlaygroundEmbedding: React.FC = () => { const getModelListByEmbedding = async () => { try { const params = { - embedding_only: true, - reranker: false + categories: 'embedding' }; const res = await queryModelsList(params); const list = _.map(res.data || [], (item: any) => { diff --git a/src/pages/playground/images.tsx b/src/pages/playground/images.tsx index e2c85c9f..fac2d840 100644 --- a/src/pages/playground/images.tsx +++ b/src/pages/playground/images.tsx @@ -85,7 +85,7 @@ const TextToImages: React.FC = () => { const getModelList = async () => { try { const params = { - image_only: true + categories: 'image' }; const res = await queryModelsList(params); const list = _.map(res.data || [], (item: any) => { diff --git a/src/pages/playground/index.tsx b/src/pages/playground/index.tsx index 819b8d71..6911de86 100644 --- a/src/pages/playground/index.tsx +++ b/src/pages/playground/index.tsx @@ -79,11 +79,7 @@ const Playground: React.FC = () => { const getModelList = async () => { try { const params = { - embedding_only: false, - image_only: false, - reranker: false, - text_to_speech: false, - speech_to_text: false + categories: '' }; const res = await queryModelsList(params); const list = _.map(res.data || [], (item: any) => { diff --git a/src/pages/playground/rerank.tsx b/src/pages/playground/rerank.tsx index 1fb8cf6a..77b3effe 100644 --- a/src/pages/playground/rerank.tsx +++ b/src/pages/playground/rerank.tsx @@ -32,7 +32,7 @@ const PlaygroundRerank: React.FC = () => { const getModelListByReranker = async () => { try { const params = { - reranker: true + categories: 'reranker' }; const res = await queryModelsList(params); const list = _.map(res.data || [], (item: any) => { diff --git a/src/pages/playground/speech.tsx b/src/pages/playground/speech.tsx index e4c6488e..7c5efc7e 100644 --- a/src/pages/playground/speech.tsx +++ b/src/pages/playground/speech.tsx @@ -102,7 +102,7 @@ const Playground: React.FC = () => { const getTextToSpeechModels = async () => { try { const params = { - text_to_speech: true + categories: 'text_to_speech' }; const res = await queryModelsList(params); const list = _.map(res.data || [], (item: any) => { @@ -120,7 +120,7 @@ const Playground: React.FC = () => { const getSpeechToText = async () => { try { const params = { - speech_to_text: true + categories: 'speech_to_text' }; const res = await queryModelsList(params); const list = _.map(res.data || [], (item: any) => {