Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 支持重排模型 #1121

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion apps/application/flow/step_node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from .direct_reply_node import *
from .function_lib_node import *
from .function_node import *
from .reranker_node import *

node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode]
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode]


def get_node(node_type):
Expand Down
9 changes: 9 additions & 0 deletions apps/application/flow/step_node/reranker_node/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py
@date:2024/9/4 11:37
@desc:
"""
from .impl import *
59 changes: 59 additions & 0 deletions apps/application/flow/step_node/reranker_node/i_reranker_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: i_reranker_node.py
@date:2024/9/4 10:40
@desc:
"""
from typing import Type

from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage


class RerankerSettingSerializer(serializers.Serializer):
    """Validates the reranker node's tuning settings.

    Fields mirror the keys consumed by BaseRerankerNode.execute via
    reranker_setting.get(...): top_n, similarity, max_paragraph_char_number.
    """
    # Number of top-ranked paragraphs to keep after reranking.
    top_n = serializers.IntegerField(required=True,
                                     error_messages=ErrMessage.integer("引用分段数"))
    # Minimum relevance score to keep a paragraph.
    # NOTE(review): the comment in the original says 0-1, but max_value is 2 —
    # kept as-is for backward compatibility; confirm the intended upper bound.
    # Fixed: error message previously said "引用分段数" (copy-paste from top_n).
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
                                        error_messages=ErrMessage.float("相似度"))
    # Character budget across all kept paragraphs.
    # Fixed: this is an IntegerField, so use ErrMessage.integer (was .float).
    max_paragraph_char_number = serializers.IntegerField(required=True,
                                                         error_messages=ErrMessage.integer("最大引用分段字数"))


class RerankerStepNodeSerializer(serializers.Serializer):
    """Validates the parameters of a reranker workflow node.

    question_reference_address / reranker_reference_list hold reference paths
    ([node_id, field, ...]) resolved later by workflow_manage.get_reference_field.
    """
    reranker_setting = RerankerSettingSerializer(required=True)

    # Path to the question value: [source_node_id, *field_path].
    question_reference_address = serializers.ListField(required=True)
    # UUID of the reranker model configured in settings.
    reranker_model_id = serializers.UUIDField(required=True)
    # One reference path per candidate document list to be merged and reranked.
    reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))

    def is_valid(self, *, raise_exception=False):
        # Always raise on invalid params regardless of the caller's flag
        # (project convention for node serializers); fixed to propagate the
        # base class's boolean result instead of discarding it.
        return super().is_valid(raise_exception=True)


class IRerankerNode(INode):
    """Abstract workflow node that reranks referenced document lists
    against a referenced question using a configured reranker model.
    """
    type = 'reranker-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return RerankerStepNodeSerializer

    def _run(self):
        params = self.node_params_serializer.data
        # Resolve the question from its [node_id, *field_path] reference.
        question_ref = params.get('question_reference_address')
        question = self.workflow_manage.get_reference_field(question_ref[0], question_ref[1:])
        # Resolve each candidate document list the same way.
        reranker_list = []
        for ref in params.get('reranker_reference_list'):
            reranker_list.append(self.workflow_manage.get_reference_field(ref[0], ref[1:]))
        return self.execute(**params, question=str(question), reranker_list=reranker_list)

    def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
                **kwargs) -> NodeResult:
        # Implemented by concrete subclasses (see BaseRerankerNode).
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py
@date:2024/9/4 11:39
@desc:
"""
from .base_reranker_node import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: base_reranker_node.py
@date:2024/9/4 11:41
@desc:
"""
from typing import List

from langchain_core.documents import Document

from application.flow.i_step_node import NodeResult
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
from setting.models_provider.tools import get_model_instance_by_model_user_id


def merge_reranker_list(reranker_list, result=None):
    """Flatten an arbitrarily nested list of candidates into a flat list of strings.

    Dicts contribute their 'title' + 'content' text (falling back to str(dict)
    when both are empty); nested lists are flattened recursively; everything
    else is stringified. Appends into *result* (shared across recursion) and
    returns it.
    """
    if result is None:
        result = []
    for item in reranker_list:
        if isinstance(item, list):
            # Recurse into nested lists, accumulating into the same result.
            merge_reranker_list(item, result)
        elif isinstance(item, dict):
            text = item.get('title', '') + item.get('content', '')
            result.append(text if len(text) > 0 else str(item))
        else:
            result.append(str(item))
    return result


def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity):
    """Take leading documents from an already-ranked list, subject to limits.

    Iterates in order and stops at the first document that would exceed the
    character budget, the top_n count, or falls below the similarity
    threshold (metadata['relevance_score']). Kept documents are truncated to
    the remaining character budget and returned as plain dicts.
    """
    consumed = 0
    kept = []
    for rank, doc in enumerate(document_list):
        # Same stop conditions (and order) as checking budget, count, score.
        if consumed >= max_paragraph_char_number or rank >= top_n \
                or doc.metadata.get('relevance_score') < similarity:
            break
        snippet = doc.page_content[0:max_paragraph_char_number - consumed]
        consumed += len(snippet)
        kept.append({'page_content': snippet, 'metadata': doc.metadata})
    return kept


class BaseRerankerNode(IRerankerNode):
    """Concrete reranker node: merges referenced candidates, scores them with
    the configured reranker model, and keeps the top results within limits.
    """

    def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
                **kwargs) -> NodeResult:
        # Flatten nested references into plain candidate strings.
        candidates = merge_reranker_list(reranker_list)
        reranker_model = get_model_instance_by_model_user_id(
            reranker_model_id, self.flow_params_serializer.data.get('user_id'))
        # Skip empty/None candidates before scoring.
        docs = [Document(page_content=text)
                for text in candidates if text is not None and len(text) > 0]
        ranked = reranker_model.compress_documents(docs, question)
        filtered = filter_result(ranked,
                                 reranker_setting.get('max_paragraph_char_number', 5000),
                                 reranker_setting.get('top_n', 3),
                                 reranker_setting.get('similarity', 0.6))
        merged_text = ''.join([item.get('page_content') for item in filtered])
        return NodeResult({'result_list': filtered, 'result': merged_text}, {})

    def get_details(self, index: int, **kwargs):
        # Snapshot of this node's run for the execution-details panel.
        return {
            'name': self.node.properties.get('stepName'),
            'index': index,
            'question': self.node_params_serializer.data.get('question'),
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'reranker_setting': self.node_params_serializer.data.get('reranker_setting'),
            'result_list': self.context.get('result_list'),
            'result': self.context.get('result'),
            'status': self.status,
            'err_message': self.err_message
        }
1 change: 1 addition & 0 deletions apps/setting/models_provider/base_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class ModelTypeConst(Enum):
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
STT = {'code': 'STT', 'message': '语音识别'}
TTS = {'code': 'TTS', 'message': '语音合成'}
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}


class ModelInfo:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: reranker.py
@date:2024/9/3 14:33
@desc:
"""
from typing import Dict

from langchain_core.documents import Document

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker


class LocalRerankerCredential(BaseForm, BaseModelCredential):
    """Credential form for locally hosted reranker models: validates the
    required cache_dir and smoke-tests the model with a one-document rerank.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        # Wrong model type is always a hard error, regardless of raise_exception.
        if model_type != 'RERANKER':
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        required_keys = ['cache_dir']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            # Smoke test: load the model and rerank a trivial document.
            model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential)
            model.compress_documents([Document(page_content='你好')], '你好')
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # No secrets in these credentials; nothing to mask.
        return model

    cache_dir = forms.TextInputField('模型目录', required=True)
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
from smartdoc.conf import PROJECT_DIR

embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)
bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER,
LocalRerankerCredential(), LocalReranker)

model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
.append_default_model_info(embedding_text2vec_base_chinese)
.append_model_info(bge_reranker_v2_m3)
.append_default_model_info(bge_reranker_v2_m3)
.build())


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: reranker.py.py
@date:2024/9/2 16:42
@desc:
"""
from typing import Sequence, Optional, Dict, Any

import requests
import torch
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from setting.models_provider.base_model_provider import MaxKBBaseModel
from smartdoc.const import CONFIG


class LocalReranker(MaxKBBaseModel):
    """Factory entry point for local reranker models.

    new_instance dispatches to an in-process model (LocalBaseReranker) or to
    the local-model HTTP service (WebLocalBaseReranker).
    """

    def __init__(self, model_name, top_n=3, cache_dir=None):
        super().__init__()
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.top_n = top_n

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        cache_dir = model_credential.get('cache_dir')
        if model_kwargs.get('use_local', True):
            # In-process inference; device defaults to CPU.
            return LocalBaseReranker(model_name=model_name, cache_dir=cache_dir,
                                     model_kwargs={'device': model_credential.get('device', 'cpu')})
        # Remote inference via the local-model web service.
        return WebLocalBaseReranker(model_name=model_name, cache_dir=cache_dir,
                                    model_kwargs={'device': model_credential.get('device')},
                                    **model_kwargs)


class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Reranker client that delegates scoring to the local-model HTTP service.

    Posts documents + query to /api/model/<model_id>/compress_documents and
    rebuilds Documents from the JSON response.
    """

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Instances are created via LocalReranker.new_instance instead.
        pass

    # UUID of the remote model; filled from kwargs in __init__.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Rerank *documents* against *query* via the remote service.

        Raises Exception with the service's message on any non-200 code.
        NOTE(review): assumes LOCAL_MODEL_HOST/PORT/PROTOCOL are configured —
        no timeout is set on the request; confirm that is acceptable.
        """
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/compress_documents',
            json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in
                                documents], 'query': query}, headers={'Content-Type': 'application/json'})
        result = res.json()
        # Service envelope: {'code': int, 'data': [...], 'msg': str}.
        if result.get('code', 500) == 200:
            return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document
                    in result.get('data')]
        raise Exception(result.get('msg'))


class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """In-process reranker backed by a HuggingFace sequence-classification
    model; scores each (query, document) pair and sorts by relevance.
    """

    # Loaded AutoModelForSequenceClassification instance.
    client: Any = None
    # Matching AutoTokenizer instance.
    tokenizer: Any = None
    # HuggingFace model name/path.
    model: Optional[str] = None
    # Local cache directory for model downloads.
    cache_dir: Optional[str] = None
    # NOTE(review): mutable class-level default shared across instances;
    # __init__ reassigns it per instance, but confirm no code reads the
    # class attribute directly.
    model_kwargs = {}

    def __init__(self, model_name, cache_dir=None, **model_kwargs):
        super().__init__()
        self.model = model_name
        self.cache_dir = cache_dir
        self.model_kwargs = model_kwargs
        # Eagerly load model + tokenizer, move to the requested device
        # (default CPU), and switch to inference mode.
        self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.client = self.client.to(self.model_kwargs.get('device', 'cpu'))
        self.client.eval()

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Score every document against *query* and return them sorted by
        descending metadata['relevance_score'] (sigmoid of the model logit,
        so scores lie in (0, 1)).
        """
        with torch.no_grad():
            # One (query, document) pair per row; truncated to 512 tokens.
            inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
                                    truncation=True, return_tensors='pt', max_length=512)
            scores = [torch.sigmoid(s).float().item() for s in
                      self.client(**inputs, return_dict=True).logits.view(-1, ).float()]
            result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]})
                      for index
                      in range(len(documents))]
            result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
            return result
20 changes: 20 additions & 0 deletions apps/setting/serializers/model_apply_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@desc:
"""
from django.db.models import QuerySet
from langchain_core.documents import Document
from rest_framework import serializers

from common.config.embedding_config import ModelManage
Expand All @@ -33,6 +34,16 @@ class EmbedQuery(serializers.Serializer):
text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本"))


class CompressDocument(serializers.Serializer):
    """One document in a rerank request: text plus optional metadata."""
    page_content = serializers.CharField(required=True, error_messages=ErrMessage.char("文本"))
    metadata = serializers.DictField(required=False, error_messages=ErrMessage.dict("元数据"))


class CompressDocuments(serializers.Serializer):
    """Payload for the compress_documents API: documents to rerank + query."""
    documents = CompressDocument(required=True, many=True)
    query = serializers.CharField(required=True, error_messages=ErrMessage.char("查询query"))


class ModelApplySerializers(serializers.Serializer):
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))

Expand All @@ -51,3 +62,12 @@ def embed_query(self, instance, with_valid=True):

model = get_embedding_model(self.data.get('model_id'))
return model.embed_query(instance.get('text'))

def compress_documents(self, instance, with_valid=True):
    """Rerank instance['documents'] against instance['query'] with the model
    identified by self.data['model_id'], returning plain dicts.

    NOTE(review): the loader is called get_embedding_model but must yield a
    reranker here — presumably a generic model accessor; confirm.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        CompressDocuments(data=instance).is_valid(raise_exception=True)
    model = get_embedding_model(self.data.get('model_id'))
    # Wrap incoming dicts as langchain Documents for the model API.
    docs = [Document(page_content=item.get('page_content'), metadata=item.get('metadata'))
            for item in instance.get('documents')]
    ranked = model.compress_documents(docs, instance.get('query'))
    # Serialize back to JSON-friendly dicts.
    return [{'page_content': d.page_content, 'metadata': d.metadata} for d in ranked]
5 changes: 4 additions & 1 deletion apps/setting/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
path('provider/model_form', views.Provide.ModelForm.as_view(),
name="provider/model_form"),
path('model', views.Model.as_view(), name='model'),
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(), name='model/model_params_form'),
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(),
name='model/model_params_form'),
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
Expand All @@ -31,4 +32,6 @@
name='model/embed_documents'),
path('model/<str:model_id>/embed_query', views.ModelApply.EmbedQuery.as_view(),
name='model/embed_query'),
# Fixed: route name was 'model/embed_query', a copy-paste duplicate of the
# embed_query route above, which makes reverse('model/embed_query') ambiguous.
path('model/<str:model_id>/compress_documents', views.ModelApply.CompressDocuments.as_view(),
     name='model/compress_documents'),
]
Loading
Loading