diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py new file mode 100644 index 0000000000..87f27971e6 --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py + @date:2024/9/10 9:46 + @desc: +""" +from typing import Dict + +from langchain_core.documents import Document + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=True): + if not model_type == 'RERANKER': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['server_url']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.compress_documents([Document(page_content='你好')], '你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return model_info + + server_url = forms.TextInputField('API 域名', required=True) + + api_key = forms.PasswordInputField('API Key', required=False) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py new file mode 100644 index 0000000000..f32e1ee947 --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py @@ -0,0 +1,73 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py + @date:2024/9/10 9:45 + @desc: +""" +from typing import Sequence, Optional, Any, Dict + +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor): + client: Any + server_url: Optional[str] + """URL of the xinference server""" + model_uid: Optional[str] + """UID of the launched model""" + api_key: Optional[str] + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return XInferenceReranker(server_url=model_credential.get('server_url'), model_uid=model_name, + api_key=model_credential.get('api_key')) + + top_n: Optional[int] = 3 + + def __init__( + self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3, + api_key: Optional[str] = None + ): + try: + from xinference.client import RESTfulClient + except ImportError: + try: + from xinference_client import RESTfulClient + except ImportError as e: + raise ImportError( + "Could not import RESTfulClient from xinference. Please install it" + " with `pip install xinference` or `pip install xinference_client`." + ) from e + + super().__init__() + + if server_url is None: + raise ValueError("Please provide server URL") + + if model_uid is None: + raise ValueError("Please provide the model UID") + + self.server_url = server_url + + self.model_uid = model_uid + + self.api_key = api_key + + self.client = RESTfulClient(server_url, api_key) + + self.top_n = top_n + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + if documents is None or len(documents) == 0: + return [] + model: RESTfulRerankModelHandle = self.client.get_model(self.model_uid) + res = model.rerank([document.page_content for document in documents], query, self.top_n, return_documents=True) + return [Document(page_content=d.get('document', {}).get('text'), + metadata={'relevance_score': d.get('relevance_score')}) for d in res.get('results', [])] diff --git a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py index d751fa7aaf..d8e8166034 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py @@ -10,8 +10,10 @@ from setting.models_provider.impl.xinference_model_provider.credential.embedding import \ XinferenceEmbeddingModelCredential from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel +from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker from smartdoc.conf import PROJECT_DIR xinference_llm_model_credential = XinferenceLLMModelCredential() @@ -480,7 +482,9 @@ ModelInfo('text2vec-large-chinese', 'Text2Vec 的中文大型版本嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, XinferenceEmbedding), ] - +rerank_list = [ModelInfo('bce-reranker-base_v1', + '发布新的重新排名器,建立在强大的 M3 和LLM (GEMMA 和 MiniCPM,实际上没那么大)骨干上,支持多语言处理和更大的输入,大幅提高 BEIR、C-MTEB/Retrieval 的排名性能、MIRACL、LlamaIndex 评估', + ModelTypeConst.RERANKER, XInferenceRerankerModelCredential(), XInferenceReranker)] model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( ModelInfo( 'phi3', @@ -492,6 +496,7 @@ '', '', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, XinferenceEmbedding)) + .append_model_info_list(rerank_list).append_default_model_info(rerank_list[0]) .build())