feat: Support xinference Rerank models
shaohuzhang1 committed Sep 10, 2024
1 parent 277ed17 commit 791c6b4
Showing 3 changed files with 126 additions and 1 deletion.
setting/models_provider/impl/xinference_model_provider/credential/reranker.py (new file)
@@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: reranker.py
@date:2024/9/10 9:46
@desc:
"""
from typing import Dict

from langchain_core.documents import Document

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential):
    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        # Only RERANKER models are accepted by this credential form.
        if not model_type == 'RERANKER':
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        # server_url is the only required credential field.
        for key in ['server_url']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # Probe the model with a minimal rerank call to verify the credential actually works.
            model = provider.get_model(model_type, model_name, model_credential)
            model.compress_documents([Document(page_content='你好')], '你好')
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        return model_info

    server_url = forms.TextInputField('API 域名', required=True)

    api_key = forms.PasswordInputField('API Key', required=False)
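
For reference, a minimal sketch (not part of this commit) of how the credential form might be exercised. The server URL is a placeholder, and provider is assumed to be the Xinference model provider instance that MaxKB supplies when it calls is_valid:

# Hypothetical check; all values below are placeholders.
provider = ...  # assumed: the Xinference model provider instance passed in by MaxKB
credential = XInferenceRerankerModelCredential()
ok = credential.is_valid(
    'RERANKER',                               # any other model type is rejected
    'bce-reranker-base_v1',                   # UID of a rerank model launched on xinference
    {'server_url': 'http://127.0.0.1:9997'},  # 'server_url' is the only required field
    provider,
    raise_exception=False,                    # return False instead of raising on failure
)

On success, is_valid performs a one-document test rerank against the server and returns True.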
setting/models_provider/impl/xinference_model_provider/model/reranker.py (new file)
@@ -0,0 +1,73 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: reranker.py
@date:2024/9/10 9:45
@desc:
"""
from typing import Sequence, Optional, Any, Dict

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle

from setting.models_provider.base_model_provider import MaxKBBaseModel


class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
    client: Any
    server_url: Optional[str]
    """URL of the xinference server"""
    model_uid: Optional[str]
    """UID of the launched model"""
    api_key: Optional[str]

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return XInferenceReranker(server_url=model_credential.get('server_url'), model_uid=model_name,
                                  api_key=model_credential.get('api_key'))

    top_n: Optional[int] = 3

    def __init__(
            self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3,
            api_key: Optional[str] = None
    ):
        # Prefer the full xinference package; fall back to the lightweight xinference_client.
        try:
            from xinference.client import RESTfulClient
        except ImportError:
            try:
                from xinference_client import RESTfulClient
            except ImportError as e:
                raise ImportError(
                    "Could not import RESTfulClient from xinference. Please install it"
                    " with `pip install xinference` or `pip install xinference_client`."
                ) from e

        super().__init__()

        if server_url is None:
            raise ValueError("Please provide server URL")

        if model_uid is None:
            raise ValueError("Please provide the model UID")

        self.server_url = server_url

        self.model_uid = model_uid

        self.api_key = api_key

        self.client = RESTfulClient(server_url, api_key)

        self.top_n = top_n

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        # Rerank the documents against the query and keep the top_n results with their scores.
        if documents is None or len(documents) == 0:
            return []
        model: RESTfulRerankModelHandle = self.client.get_model(self.model_uid)
        res = model.rerank([document.page_content for document in documents], query, self.top_n, return_documents=True)
        return [Document(page_content=d.get('document', {}).get('text'),
                         metadata={'relevance_score': d.get('relevance_score')}) for d in res.get('results', [])]
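
Below is a minimal usage sketch (not part of this commit) of the new reranker class. The server URL and model UID are placeholders and assume a rerank model has already been launched on the xinference server:

# Hypothetical values; adjust to your own xinference deployment.
from langchain_core.documents import Document

reranker = XInferenceReranker(
    server_url='http://127.0.0.1:9997',  # placeholder xinference endpoint
    model_uid='bce-reranker-base_v1',    # UID of the launched rerank model
    top_n=2,                             # keep the two highest-scoring documents
)
ranked = reranker.compress_documents(
    documents=[Document(page_content='MaxKB 是一个开源的知识库问答系统'),
               Document(page_content='今天天气不错'),
               Document(page_content='xinference 支持部署 rerank 模型')],
    query='什么是 MaxKB?',
)
for doc in ranked:
    print(doc.metadata['relevance_score'], doc.page_content)

compress_documents returns the documents ordered by relevance, with each score stored in the document metadata under relevance_score.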
@@ -10,8 +10,10 @@
from setting.models_provider.impl.xinference_model_provider.credential.embedding import \
    XinferenceEmbeddingModelCredential
from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential
from setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential
from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding
from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel
from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker
from smartdoc.conf import PROJECT_DIR

xinference_llm_model_credential = XinferenceLLMModelCredential()
@@ -480,7 +482,9 @@
    ModelInfo('text2vec-large-chinese', 'Text2Vec 的中文大型版本嵌入模型。', ModelTypeConst.EMBEDDING,
              xinference_embedding_model_credential, XinferenceEmbedding),
]

rerank_list = [ModelInfo('bce-reranker-base_v1',
                         '发布新的重新排名器,建立在强大的 M3 和LLM (GEMMA 和 MiniCPM,实际上没那么大)骨干上,支持多语言处理和更大的输入,大幅提高 BEIR、C-MTEB/Retrieval 的排名性能、MIRACL、LlamaIndex 评估',
                         ModelTypeConst.RERANKER, XInferenceRerankerModelCredential(), XInferenceReranker)]
model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
    ModelInfo(
        'phi3',
@@ -492,6 +496,7 @@
        '',
        '',
        ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, XinferenceEmbedding))
                     .append_model_info_list(rerank_list).append_default_model_info(rerank_list[0])
                     .build())



