Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
liqiang-fit2cloud committed Sep 18, 2024
2 parents 5eebf74 + f07f5ae commit 47e24d8
Show file tree
Hide file tree
Showing 68 changed files with 873 additions and 453 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class InstanceSerializer(serializers.Serializer):
"最大携带知识库段落长度"))
# 模板
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
system = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("系统提示词(角色)"))
# 补齐问题
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
# 未查询到引用分段
Expand All @@ -59,6 +61,7 @@ def execute(self,
prompt: str,
padding_problem_text: str = None,
no_references_setting=None,
system=None,
**kwargs) -> List[BaseMessage]:
"""
Expand All @@ -71,6 +74,7 @@ def execute(self,
:param padding_problem_text 用户修改文本
:param kwargs: 其他参数
:param no_references_setting: 无引用分段设置
:param system 系统提示称
:return:
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Dict

from langchain.schema import BaseMessage, HumanMessage
from langchain_core.messages import SystemMessage

from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
Expand All @@ -27,6 +28,7 @@ def execute(self, problem_text: str,
prompt: str,
padding_problem_text: str = None,
no_references_setting=None,
system=None,
**kwargs) -> List[BaseMessage]:
prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
'value')
Expand All @@ -35,6 +37,11 @@ def execute(self, problem_text: str,
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
if system is not None and len(system) > 0:
return [SystemMessage(system), *flat_map(history_message),
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
no_references_setting)]

return [*flat_map(history_message),
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
no_references_setting)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class InstanceSerializer(serializers.Serializer):
error_messages=ErrMessage.list("历史对答"))
# 大语言模型
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
error_messages=ErrMessage.char("问题补全提示词"))

def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
Expand All @@ -47,5 +49,6 @@ def _run(self, manage: PipelineManage):

@abstractmethod
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
problem_optimization_prompt=None,
**kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
problem_optimization_prompt=None,
**kwargs) -> str:
if chat_model is None:
self.context['message_tokens'] = 0
Expand All @@ -30,15 +31,19 @@ def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = Non
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt
message_list = [*flat_map(history_message),
HumanMessage(content=prompt.format(**{'question': problem_text}))]
HumanMessage(content=reset_prompt.replace('{question}', problem_text))]
response = chat_model.invoke(message_list)
padding_problem = problem_text
if response.content.__contains__("<data>") and response.content.__contains__('</data>'):
padding_problem_data = response.content[
response.content.index('<data>') + 6:response.content.index('</data>')]
if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
padding_problem = padding_problem_data
elif len(response.content) > 0:
padding_problem = response.content

try:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(padding_problem)
Expand Down
3 changes: 0 additions & 3 deletions apps/application/flow/step_node/start_node/i_start_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
class IStarNode(INode):
type = 'start-node'

def get_node_params_serializer_class(self) -> Type[serializers.Serializer] | None:
return None

def _run(self):
return self.execute(**self.flow_params_serializer.data)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@

class BaseStartStepNode(IStarNode):
def execute(self, question, **kwargs) -> NodeResult:
history_chat_record = self.flow_params_serializer.data.get('history_chat_record', [])
history_context = [{'question': chat_record.problem_text, 'answer': chat_record.answer_text} for chat_record in
history_chat_record]
chat_id = self.flow_params_serializer.data.get('chat_id')
"""
开始节点 初始化全局变量
"""
return NodeResult({'question': question},
{'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time()})
{'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time(),
'history_context': history_context, 'chat_id': str(chat_id)})

def get_details(self, index: int, **kwargs):
global_fields = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.15 on 2024-09-13 18:57

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('application', '0013_application_tts_type'),
]

operations = [
migrations.AddField(
model_name='application',
name='problem_optimization_prompt',
field=models.CharField(blank=True, default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中', max_length=102400, null=True, verbose_name='问题优化提示词'),
),
]
11 changes: 8 additions & 3 deletions apps/application/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_dataset_setting_dict():


def get_model_setting_dict():
return {'prompt': Application.get_default_model_prompt()}
return {'prompt': Application.get_default_model_prompt(), 'no_references_prompt': '{question}'}


class Application(AppModelMixin):
Expand All @@ -54,8 +54,13 @@ class Application(AppModelMixin):
work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices,
default=ApplicationTypeChoices.SIMPLE, max_length=256)
tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
problem_optimization_prompt = models.CharField(verbose_name="问题优化提示词", max_length=102400, blank=True,
null=True,
default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中")
tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False,
blank=True, null=True)
stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False,
blank=True, null=True)
tts_model_enable = models.BooleanField(verbose_name="语音合成模型是否启用", default=False)
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)
tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER")
Expand Down
21 changes: 15 additions & 6 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ class DatasetSettingSerializer(serializers.Serializer):


class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
prompt = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
error_messages=ErrMessage.char("提示词"))
system = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
error_messages=ErrMessage.char("角色提示词"))
no_references_prompt = serializers.CharField(required=True, max_length=102400, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("无引用分段提示词"))


class ApplicationWorkflowSerializer(serializers.Serializer):
Expand Down Expand Up @@ -174,7 +179,7 @@ class ApplicationSerializer(serializers.Serializer):
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
dialogue_number = serializers.BooleanField(required=True, error_messages=ErrMessage.char("会话次数"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
error_messages=ErrMessage.char("开场白"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
Expand All @@ -185,6 +190,8 @@ class ApplicationSerializer(serializers.Serializer):
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
error_messages=ErrMessage.char("问题补全提示词"))
# 应用类型
type = serializers.CharField(required=True, error_messages=ErrMessage.char("应用类型"),
validators=[
Expand Down Expand Up @@ -364,8 +371,8 @@ class Edit(serializers.Serializer):
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("多轮会话"))
dialogue_number = serializers.IntegerField(required=False,
error_messages=ErrMessage.boolean("多轮会话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
error_messages=ErrMessage.char("开场白"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
Expand Down Expand Up @@ -430,13 +437,14 @@ def insert_simple(self, application: Dict):
def to_application_model(user_id: str, application: Dict):
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
prologue=application.get('prologue'),
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
dialogue_number=application.get('dialogue_number', 0),
user_id=user_id, model_id=application.get('model_id'),
dataset_setting=application.get('dataset_setting'),
model_setting=application.get('model_setting'),
problem_optimization=application.get('problem_optimization'),
type=ApplicationTypeChoices.SIMPLE,
model_params_setting=application.get('model_params_setting', {}),
problem_optimization_prompt=application.get('problem_optimization_prompt', None),
work_flow={}
)

Expand Down Expand Up @@ -601,7 +609,8 @@ def list_function_lib(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
return FunctionLibSerializer.Query(data={'user_id': application.user_id}).list(with_valid=True)
return FunctionLibSerializer.Query(data={'user_id': application.user_id, 'is_active': True}).list(
with_valid=True)

def get_function_lib(self, function_lib_id, with_valid=True):
if with_valid:
Expand Down
24 changes: 18 additions & 6 deletions apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ def __init__(self,
self.chat_record_list: List[ChatRecord] = []
self.work_flow_version = work_flow_version

@staticmethod
def get_no_references_setting(dataset_setting, model_setting):
no_references_setting = dataset_setting.get(
'no_references_setting', {
'status': 'ai_questioning',
'value': '{question}'})
if no_references_setting.get('status') == 'ai_questioning':
no_references_prompt = model_setting.get('no_references_prompt', '{question}')
no_references_setting['value'] = no_references_prompt if len(no_references_prompt) > 0 else "{question}"
return no_references_setting

def to_base_pipeline_manage_params(self):
dataset_setting = self.application.dataset_setting
model_setting = self.application.model_setting
Expand All @@ -80,20 +91,21 @@ def to_base_pipeline_manage_params(self):
'history_chat_record': self.chat_record_list,
'chat_id': self.chat_id,
'dialogue_number': self.application.dialogue_number,
'problem_optimization_prompt': self.application.problem_optimization_prompt if self.application.problem_optimization_prompt is not None and len(
self.application.problem_optimization_prompt) > 0 else '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中',
'prompt': model_setting.get(
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
'prompt') if 'prompt' in model_setting and len(model_setting.get(
'prompt')) > 0 else Application.get_default_model_prompt(),
'system': model_setting.get(
'system', None),
'model_id': model_id,
'problem_optimization': self.application.problem_optimization,
'stream': True,
'model_params_setting': model_params_setting if self.application.model_params_setting is None or len(
self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
'search_mode': self.application.dataset_setting.get(
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
'no_references_setting': self.application.dataset_setting.get(
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
'status': 'ai_questioning',
'value': '{question}',
},
'no_references_setting': self.get_no_references_setting(self.application.dataset_setting, model_setting),
'user_id': self.application.user_id
}

Expand Down
Loading

0 comments on commit 47e24d8

Please sign in to comment.