Skip to content

Commit

Permalink
feat: 支持用户输入变量
Browse files Browse the repository at this point in the history
--story=1016155 --user=刘瑞斌 【应用编排】-支持设置用户输入变量 https://www.tapd.cn/57709429/s/1576480
  • Loading branch information
liuruibin committed Sep 10, 2024
1 parent 689e74a commit bbe280f
Show file tree
Hide file tree
Showing 9 changed files with 438 additions and 45 deletions.
4 changes: 2 additions & 2 deletions apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ def is_valid_base_node(self):

class WorkflowManage:
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse()):
base_to_response: BaseToResponse = SystemToResponse(), form_data = {}):
self.params = params
self.flow = flow
self.context = {}
self.context = form_data
self.node_context = []
self.work_flow_post_handler = work_flow_post_handler
self.current_node = None
Expand Down
14 changes: 10 additions & 4 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,7 @@ def profile(self, with_valid=True):
'tts_model_id': application.tts_model_id,
'stt_model_enable': application.stt_model_enable,
'tts_model_enable': application.tts_model_enable,
'work_flow': application.work_flow,
'show_source': application_access_token.show_source})

@transaction.atomic
Expand Down Expand Up @@ -855,10 +856,15 @@ def get_work_flow_model(instance):
nodes = instance.get('work_flow')['nodes']
for node in nodes:
if node['id'] == 'base-node':
instance['stt_model_id'] = node['properties']['node_data']['stt_model_id']
instance['tts_model_id'] = node['properties']['node_data']['tts_model_id']
instance['stt_model_enable'] = node['properties']['node_data']['stt_model_enable']
instance['tts_model_enable'] = node['properties']['node_data']['tts_model_enable']
node_data = node['properties']['node_data']
if 'stt_model_id' in node_data:
instance['stt_model_id'] = node_data['stt_model_id']
if 'tts_model_id' in node_data:
instance['tts_model_id'] = node_data['tts_model_id']
if 'stt_model_enable' in node_data:
instance['stt_model_enable'] = node_data['stt_model_enable']
if 'tts_model_enable' in node_data:
instance['tts_model_enable'] = node_data['tts_model_enable']
break

def speech_to_text(self, file, with_valid=True):
Expand Down
4 changes: 3 additions & 1 deletion apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class ChatMessageSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))

def is_valid_application_workflow(self, *, raise_exception=False):
self.is_valid_intraday_access_num()
Expand Down Expand Up @@ -284,14 +285,15 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
stream = self.data.get('stream')
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
form_data = self.data.get('form_data')
user_id = chat_info.application.user_id
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
{'history_chat_record': chat_info.chat_record_list, 'question': message,
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream,
're_chat': re_chat,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response)
base_to_response, form_data)
r = work_flow_manage.run()
return r

Expand Down
1 change: 1 addition & 0 deletions apps/application/views/chat_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def post(self, request: Request, chat_id: str):
'application_id': (request.auth.keywords.get(
'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None),
'client_id': request.auth.client_id,
'form_data': (request.data.get('form_data') if 'form_data' in request.data else []),
'client_type': request.auth.client_type}).chat()

@action(methods=['GET'], detail=False)
Expand Down
Loading

0 comments on commit bbe280f

Please sign in to comment.