From 1708525d0125651f748f30652f89684036a8436f Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 5 Sep 2024 11:28:21 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E9=87=8D=E6=8E=92?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/step_node/__init__.py | 3 +- .../flow/step_node/reranker_node/__init__.py | 9 + .../reranker_node/i_reranker_node.py | 59 ++++ .../step_node/reranker_node/impl/__init__.py | 9 + .../reranker_node/impl/base_reranker_node.py | 73 +++++ .../models_provider/base_model_provider.py | 1 + .../credential/reranker.py | 47 +++ .../local_model_provider.py | 6 + .../local_model_provider/model/reranker.py | 97 ++++++ .../serializers/model_apply_serializers.py | 20 ++ apps/setting/urls.py | 5 +- apps/setting/views/model_apply.py | 10 + ui/src/api/application.ts | 16 +- ui/src/assets/icon_reranker.svg | 16 + ui/src/enums/workflow.ts | 3 +- ui/src/workflow/common/data.ts | 28 +- ui/src/workflow/icons/reranker-node-icon.vue | 6 + .../reranker-node/ParamSettingDialog.vue | 266 +++++++++++++++ ui/src/workflow/nodes/reranker-node/index.ts | 12 + ui/src/workflow/nodes/reranker-node/index.vue | 303 ++++++++++++++++++ 20 files changed, 983 insertions(+), 6 deletions(-) create mode 100644 apps/application/flow/step_node/reranker_node/__init__.py create mode 100644 apps/application/flow/step_node/reranker_node/i_reranker_node.py create mode 100644 apps/application/flow/step_node/reranker_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py create mode 100644 apps/setting/models_provider/impl/local_model_provider/credential/reranker.py create mode 100644 apps/setting/models_provider/impl/local_model_provider/model/reranker.py create mode 100644 ui/src/assets/icon_reranker.svg create mode 100644 ui/src/workflow/icons/reranker-node-icon.vue create mode 100644 ui/src/workflow/nodes/reranker-node/ParamSettingDialog.vue create mode 100644 ui/src/workflow/nodes/reranker-node/index.ts create mode 100644 ui/src/workflow/nodes/reranker-node/index.vue diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index 1d5af03ca1..62273818cb 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -14,9 +14,10 @@ from .direct_reply_node import * from .function_lib_node import * from .function_node import * +from .reranker_node import * node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, - BaseFunctionNodeNode, BaseFunctionLibNodeNode] + BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/reranker_node/__init__.py b/apps/application/flow/step_node/reranker_node/__init__.py new file mode 100644 index 0000000000..881d0f8a39 --- /dev/null +++ b/apps/application/flow/step_node/reranker_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/9/4 11:37 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/reranker_node/i_reranker_node.py b/apps/application/flow/step_node/reranker_node/i_reranker_node.py new file mode 100644 index 0000000000..fec3ec0217 --- /dev/null +++ b/apps/application/flow/step_node/reranker_node/i_reranker_node.py @@ -0,0 +1,59 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_reranker_node.py + @date:2024/9/4 10:40 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class RerankerSettingSerializer(serializers.Serializer): + # 需要查询的条数 + top_n = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer("引用分段数")) + # 相似度 0-1之间 + similarity = serializers.FloatField(required=True, max_value=2, min_value=0, + error_messages=ErrMessage.float("引用分段数")) + max_paragraph_char_number = serializers.IntegerField(required=True, + error_messages=ErrMessage.float("最大引用分段字数")) + + +class RerankerStepNodeSerializer(serializers.Serializer): + reranker_setting = RerankerSettingSerializer(required=True) + + question_reference_address = serializers.ListField(required=True) + reranker_model_id = serializers.UUIDField(required=True) + reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True)) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + +class IRerankerNode(INode): + type = 'reranker-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return RerankerStepNodeSerializer + + def _run(self): + question = self.workflow_manage.get_reference_field( + self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) + reranker_list = [self.workflow_manage.get_reference_field( + reference[0], + reference[1:]) for reference in + self.node_params_serializer.data.get('reranker_reference_list')] + return self.execute(**self.node_params_serializer.data, question=str(question), + + reranker_list=reranker_list) + + def execute(self, question, reranker_setting, reranker_list, reranker_model_id, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/reranker_node/impl/__init__.py b/apps/application/flow/step_node/reranker_node/impl/__init__.py new file mode 100644 index 0000000000..ef5ca80585 --- /dev/null +++ b/apps/application/flow/step_node/reranker_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/9/4 11:39 + @desc: +""" +from .base_reranker_node import * diff --git a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py new file mode 100644 index 0000000000..129fe3ff95 --- /dev/null +++ b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py @@ -0,0 +1,73 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_reranker_node.py + @date:2024/9/4 11:41 + @desc: +""" +from typing import List + +from langchain_core.documents import Document + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +def merge_reranker_list(reranker_list, result=None): + if result is None: + result = [] + for document in reranker_list: + if isinstance(document, list): + merge_reranker_list(document, result) + elif isinstance(document, dict): + content = document.get('title', '') + document.get('content', '') + result.append(str(document) if len(content) == 0 else content) + else: + result.append(str(document)) + return result + + +def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity): + use_len = 0 + result = [] + for index in range(len(document_list)): + document = document_list[index] + if use_len >= max_paragraph_char_number or index >= top_n or document.metadata.get( + 'relevance_score') < similarity: + break + content = document.page_content[0:max_paragraph_char_number - use_len] + use_len = use_len + len(content) + result.append({'page_content': content, 'metadata': document.metadata}) + return result + + +class BaseRerankerNode(IRerankerNode): + def execute(self, question, reranker_setting, reranker_list, reranker_model_id, + **kwargs) -> NodeResult: + documents = merge_reranker_list(reranker_list) + reranker_model = get_model_instance_by_model_user_id(reranker_model_id, + self.flow_params_serializer.data.get('user_id')) + result = reranker_model.compress_documents( + [Document(page_content=document) for document in documents if document is not None and len(document) > 0], + question) + top_n = reranker_setting.get('top_n', 3) + similarity = reranker_setting.get('similarity', 0.6) + max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000) + r = filter_result(result, max_paragraph_char_number, top_n, similarity) + return NodeResult({'result_list': r, 'result': ''.join([item.get('page_content') for item in r])}, {}) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "question": self.node_params_serializer.data.get('question'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'reranker_setting': self.node_params_serializer.data.get('reranker_setting'), + 'result_list': self.context.get('result_list'), + 'result': self.context.get('result'), + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 4de2f00674..9171e036cf 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -141,6 +141,7 @@ class ModelTypeConst(Enum): EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'} STT = {'code': 'STT', 'message': '语音识别'} TTS = {'code': 'TTS', 'message': '语音合成'} + RERANKER = {'code': 'RERANKER', 'message': '重排模型'} class ModelInfo: diff --git a/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py b/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py new file mode 100644 index 0000000000..0048fcedb6 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py + @date:2024/9/3 14:33 + @desc: +""" +from typing import Dict + +from langchain_core.documents import Document + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker + + +class LocalRerankerCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + if not model_type == 'RERANKER': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['cache_dir']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential) + model.compress_documents([Document(page_content='你好')], '你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_dir = forms.TextInputField('模型目录', required=True) diff --git a/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py index 65cc57322b..2c92bbbfb2 100644 --- a/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py +++ b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py @@ -16,14 +16,20 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential +from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding +from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker from smartdoc.conf import PROJECT_DIR embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING, LocalEmbeddingCredential(), LocalEmbedding) +bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, + LocalRerankerCredential(), LocalReranker) model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese) .append_default_model_info(embedding_text2vec_base_chinese) + .append_model_info(bge_reranker_v2_m3) + .append_default_model_info(bge_reranker_v2_m3) .build()) diff --git a/apps/setting/models_provider/impl/local_model_provider/model/reranker.py b/apps/setting/models_provider/impl/local_model_provider/model/reranker.py new file mode 100644 index 0000000000..a2f9323290 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/model/reranker.py @@ -0,0 +1,97 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py.py + @date:2024/9/2 16:42 + @desc: +""" +from typing import Sequence, Optional, Dict, Any + +import requests +import torch +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from smartdoc.const import CONFIG + + +class LocalReranker(MaxKBBaseModel): + def __init__(self, model_name, top_n=3, cache_dir=None): + super().__init__() + self.model_name = model_name + self.cache_dir = cache_dir + self.top_n = top_n + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + if model_kwargs.get('use_local', True): + return LocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'), + model_kwargs={'device': model_credential.get('device', 'cpu')} + + ) + return WebLocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'), + model_kwargs={'device': model_credential.get('device')}, + **model_kwargs) + + +class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + pass + + model_id: str = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_id = kwargs.get('model_id', None) + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/compress_documents', + json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in + documents], 'query': query}, headers={'Content-Type': 'application/json'}) + result = res.json() + if result.get('code', 500) == 200: + return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document + in result.get('data')] + raise Exception(result.get('msg')) + + +class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor): + client: Any = None + tokenizer: Any = None + model: Optional[str] = None + cache_dir: Optional[str] = None + model_kwargs = {} + + def __init__(self, model_name, cache_dir=None, **model_kwargs): + super().__init__() + self.model = model_name + self.cache_dir = cache_dir + self.model_kwargs = model_kwargs + self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir) + self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir) + self.client = self.client.to(self.model_kwargs.get('device', 'cpu')) + self.client.eval() + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs) + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + with torch.no_grad(): + inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True, + truncation=True, return_tensors='pt', max_length=512) + scores = [torch.sigmoid(s).float().item() for s in + self.client(**inputs, return_dict=True).logits.view(-1, ).float()] + result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]}) + for index + in range(len(documents))] + result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True) + return result diff --git a/apps/setting/serializers/model_apply_serializers.py b/apps/setting/serializers/model_apply_serializers.py index 2177b5fe6d..fd41869870 100644 --- a/apps/setting/serializers/model_apply_serializers.py +++ b/apps/setting/serializers/model_apply_serializers.py @@ -7,6 +7,7 @@ @desc: """ from django.db.models import QuerySet +from langchain_core.documents import Document from rest_framework import serializers from common.config.embedding_config import ModelManage @@ -33,6 +34,16 @@ class EmbedQuery(serializers.Serializer): text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本")) +class CompressDocument(serializers.Serializer): + page_content = serializers.CharField(required=True, error_messages=ErrMessage.char("文本")) + metadata = serializers.DictField(required=False, error_messages=ErrMessage.dict("元数据")) + + +class CompressDocuments(serializers.Serializer): + documents = CompressDocument(required=True, many=True) + query = serializers.CharField(required=True, error_messages=ErrMessage.char("查询query")) + + class ModelApplySerializers(serializers.Serializer): model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) @@ -51,3 +62,12 @@ def embed_query(self, instance, with_valid=True): model = get_embedding_model(self.data.get('model_id')) return model.embed_query(instance.get('text')) + + def compress_documents(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + CompressDocuments(data=instance).is_valid(raise_exception=True) + model = get_embedding_model(self.data.get('model_id')) + return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( + [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in + instance.get('documents')], instance.get('query'))] diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 557edcceb4..42e80592cc 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -17,7 +17,8 @@ path('provider/model_form', views.Provide.ModelForm.as_view(), name="provider/model_form"), path('model', views.Model.as_view(), name='model'), - path('model//model_params_form', views.Model.ModelParamsForm.as_view(), name='model/model_params_form'), + path('model//model_params_form', views.Model.ModelParamsForm.as_view(), + name='model/model_params_form'), path('model/', views.Model.Operate.as_view(), name='model/operate'), path('model//pause_download', views.Model.PauseDownload.as_view(), name='model/operate'), path('model//meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'), @@ -31,4 +32,6 @@ name='model/embed_documents'), path('model//embed_query', views.ModelApply.EmbedQuery.as_view(), name='model/embed_query'), + path('model//compress_documents', views.ModelApply.CompressDocuments.as_view(), + name='model/embed_query'), ] diff --git a/apps/setting/views/model_apply.py b/apps/setting/views/model_apply.py index 4a4e6139cb..6bd0b548ee 100644 --- a/apps/setting/views/model_apply.py +++ b/apps/setting/views/model_apply.py @@ -36,3 +36,13 @@ class EmbedQuery(APIView): def post(self, request: Request, model_id): return result.success( ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data)) + + class CompressDocuments(APIView): + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="重排序文档", + operation_id="重排序文档", + responses=result.get_default_response(), + tags=["模型"]) + def post(self, request: Request, model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 76c20f3774..5711b8ac7e 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -237,6 +237,19 @@ const getApplicationModel: ( return get(`${prefix}/${application_id}/model`, loading) } +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationRerankerModel: ( + application_id: string, + loading?: Ref +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/model`, { model_type: 'RERANKER' }, loading) +} /** * 发布应用 * @param 参数 @@ -310,5 +323,6 @@ export default { postWorkflowChatOpen, listFunctionLib, getFunctionLib, - getModelParamsForm + getModelParamsForm, + getApplicationRerankerModel } diff --git a/ui/src/assets/icon_reranker.svg b/ui/src/assets/icon_reranker.svg new file mode 100644 index 0000000000..fa62f6d0eb --- /dev/null +++ b/ui/src/assets/icon_reranker.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/ui/src/enums/workflow.ts b/ui/src/enums/workflow.ts index 88c854966f..be2571a5c3 100644 --- a/ui/src/enums/workflow.ts +++ b/ui/src/enums/workflow.ts @@ -7,5 +7,6 @@ export enum WorkflowType { Condition = 'condition-node', Reply = 'reply-node', FunctionLib = 'function-lib-node', - FunctionLibCustom = 'function-node' + FunctionLibCustom = 'function-node', + RrerankerNode = 'reranker-node' } diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index ca40cf1ac0..3cc284c327 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -139,7 +139,30 @@ export const replyNode = { } } } -export const menuNodes = [aiChatNode, searchDatasetNode, questionNode, conditionNode, replyNode] +export const rerankerNode = { + type: WorkflowType.RrerankerNode, + text: '使用重排模型对多个知识库的检索结果进行二次召回', + label: '多路召回', + properties: { + stepName: '多路召回', + config: { + fields: [ + { + label: '结果', + value: 'result' + } + ] + } + } +} +export const menuNodes = [ + aiChatNode, + searchDatasetNode, + questionNode, + conditionNode, + replyNode, + rerankerNode +] /** * 自定义函数配置数据 @@ -203,7 +226,8 @@ export const nodeDict: any = { [WorkflowType.Start]: startNode, [WorkflowType.Reply]: replyNode, [WorkflowType.FunctionLib]: functionLibNode, - [WorkflowType.FunctionLibCustom]: functionNode + [WorkflowType.FunctionLibCustom]: functionNode, + [WorkflowType.RrerankerNode]: rerankerNode } export function isWorkFlow(type: string | undefined) { return type === 'WORK_FLOW' diff --git a/ui/src/workflow/icons/reranker-node-icon.vue b/ui/src/workflow/icons/reranker-node-icon.vue new file mode 100644 index 0000000000..70517d1ba8 --- /dev/null +++ b/ui/src/workflow/icons/reranker-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/ui/src/workflow/nodes/reranker-node/ParamSettingDialog.vue b/ui/src/workflow/nodes/reranker-node/ParamSettingDialog.vue new file mode 100644 index 0000000000..35a6cd150c --- /dev/null +++ b/ui/src/workflow/nodes/reranker-node/ParamSettingDialog.vue @@ -0,0 +1,266 @@ + + + diff --git a/ui/src/workflow/nodes/reranker-node/index.ts b/ui/src/workflow/nodes/reranker-node/index.ts new file mode 100644 index 0000000000..9b3afc5c69 --- /dev/null +++ b/ui/src/workflow/nodes/reranker-node/index.ts @@ -0,0 +1,12 @@ +import RerankerNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class RerankerNode extends AppNode { + constructor(props: any) { + super(props, RerankerNodeVue) + } +} +export default { + type: 'reranker-node', + model: AppNodeModel, + view: RerankerNode +} diff --git a/ui/src/workflow/nodes/reranker-node/index.vue b/ui/src/workflow/nodes/reranker-node/index.vue new file mode 100644 index 0000000000..08d4be5f58 --- /dev/null +++ b/ui/src/workflow/nodes/reranker-node/index.vue @@ -0,0 +1,303 @@ + + +