Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: 处理历史会话中图片的问题 #1636

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def execute(self, document, **kwargs):

self.context['document_list'] = document
content = ''
spliter = '\n-----------------------------------\n'
splitter = '\n-----------------------------------\n'
if document is None:
return NodeResult({'content': content}, {})

Expand All @@ -29,7 +29,7 @@ def execute(self, document, **kwargs):
# 回到文件头
buffer.seek(0)
file_content = split_handle.get_content(buffer)
content += spliter + '## ' + doc['name'] + '\n' + file_content
content += splitter + '## ' + doc['name'] + '\n' + file_content
break

return NodeResult({'content': content}, {})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))

is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
Expand All @@ -32,7 +34,7 @@ def _run(self):
self.node_params_serializer.data.get('image_list')[1:])
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
chat_record_id,
image,
**kwargs) -> NodeResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,18 @@ def save_context(self, details, workflow_manage):
self.context['question'] = details.get('question')
self.answer_text = details.get('answer')

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, chat_record_id,
image,
**kwargs) -> NodeResult:
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question.content
# todo 处理上传图片
message_list = self.generate_message_list(image_model, system, prompt, history_message, image)
self.context['message_list'] = message_list
self.context['image_list'] = image
self.context['dialogue_type'] = dialogue_type
if stream:
r = image_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
Expand All @@ -86,15 +86,31 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context)

@staticmethod
def get_history_message(history_chat_record, dialogue_number):
def get_history_message(self, history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
[self.generate_history_human_message(history_chat_record[index]), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
return history_message

def generate_history_human_message(self, chat_record):
    """Build the human-side message for one historical chat record.

    If this node produced a detail entry for the record that carries an
    image list and the dialogue is stored per-node, re-attach the first
    uploaded image as a base64 data URL so the model sees it again on the
    next turn; otherwise fall back to the plain question text.

    :param chat_record: history chat record with `details` and `problem_text`
    :return: HumanMessage — either plain text or text + image_url content
    """
    for data in chat_record.details.values():
        # Only details written by this very node, and only those that
        # recorded an image list, are eligible for image re-attachment.
        if self.node.id == data['node_id'] and 'image_list' in data:
            image_list = data['image_list']
            # Records created before 'dialogue_type' existed lack the key;
            # default to 'NODE' (the UI default) instead of raising KeyError.
            # TODO(review): confirm 'NODE' is the right default for legacy records.
            if len(image_list) == 0 or data.get('dialogue_type', 'NODE') == 'WORKFLOW':
                return HumanMessage(content=chat_record.problem_text)
            file_id = image_list[0]['file_id']
            file = QuerySet(File).filter(id=file_id).first()
            # The referenced file may have been cleaned up (e.g. by the
            # debug-file cleanup job); degrade to plain text instead of
            # crashing on a None file.
            if file is None:
                return HumanMessage(content=chat_record.problem_text)
            base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
            return HumanMessage(
                content=[
                    {'type': 'text', 'text': data['question']},
                    {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
                ])
    # No detail from this node matched: plain text question only.
    return HumanMessage(content=chat_record.problem_text)

def generate_prompt_question(self, prompt):
    # Render the prompt template through the workflow manager (substituting
    # workflow variables) and wrap the result as the user-turn message.
    return HumanMessage(self.workflow_manage.generate_prompt(prompt))

Expand Down Expand Up @@ -148,5 +164,6 @@ def get_details(self, index: int, **kwargs):
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message,
'image_list': self.context.get('image_list')
'image_list': self.context.get('image_list'),
'dialogue_type': self.context.get('dialogue_type')
}
liuruibin marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion apps/application/views/chat_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ class UploadFile(APIView):
def post(self, request: Request, application_id: str, chat_id: str):
files = request.FILES.getlist('file')
file_ids = []
meta = {'application_id': application_id, 'chat_id': chat_id}
debug = request.data.get("debug", "false").lower() == "true"
meta = {'application_id': application_id, 'chat_id': chat_id, 'debug': debug}
for file in files:
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]})
Expand Down
2 changes: 2 additions & 0 deletions apps/common/job/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
"""
from .client_access_num_job import *
from .clean_chat_job import *
from .clean_debug_file_job import *


def run():
    """Start every registered background job: client access counting,
    chat cleanup, and debug-file cleanup."""
    client_access_num_job.run()
    clean_chat_job.run()
    clean_debug_file_job.run()
36 changes: 36 additions & 0 deletions apps/common/job/clean_debug_file_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# coding=utf-8

import logging
from datetime import timedelta

from apscheduler.schedulers.background import BackgroundScheduler
from django.db.models import Q
from django.utils import timezone
from django_apscheduler.jobstores import DjangoJobStore

from common.lock.impl.file_lock import FileLock
from dataset.models import File

# Module-level scheduler persisted through the Django job store; one
# instance per process.
scheduler = BackgroundScheduler()
scheduler.add_jobstore(DjangoJobStore(), "default")
# File-based lock used to serialize job registration across processes.
lock = FileLock()


def clean_debug_file():
    """Delete debug-upload files that are more than two hours old."""
    log = logging.getLogger("max_kb")
    log.info('开始清理debug文件')
    cutoff = timezone.now() - timedelta(hours=2)
    # Remove files flagged as debug uploads whose creation time is past the cutoff.
    File.objects.filter(Q(create_time__lt=cutoff) & Q(meta__debug=True)).delete()
    log.info('结束清理debug文件')


def run():
    """Register the 02:00:00 daily cron job that purges stale debug files.

    A file lock guards registration so that only one process (re)schedules
    the job; the lock is always released afterwards.
    """
    if not lock.try_lock('clean_debug_file', 30 * 30):
        return
    try:
        scheduler.start()
        # Drop any previously registered job so the schedule below wins.
        existing = scheduler.get_job(job_id='clean_debug_file')
        if existing is not None:
            existing.remove()
        scheduler.add_job(clean_debug_file, 'cron',
                          hour='2', minute='0', second='0',
                          id='clean_debug_file')
    finally:
        lock.un_lock('clean_debug_file')
liuruibin marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 9 additions & 2 deletions ui/src/components/ai-chat/component/chat-input-operate/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ const props = withDefaults(
appId?: string
chatId: string
sendMessage: (question: string, other_params_data?: any, chat?: chatType) => void
openChatId: () => Promise<string>
}>(),
{
applicationDetails: () => ({}),
Expand Down Expand Up @@ -165,8 +166,14 @@ const uploadFile = async (file: any, fileList: any) => {
}

if (!chatId_context.value) {
const res = await applicationApi.getChatOpen(props.applicationDetails.id as string)
chatId_context.value = res.data
const res = await props.openChatId()
chatId_context.value = res
}

if (props.type === 'debug-ai-chat') {
formData.append('debug', 'true')
} else {
formData.append('debug', 'false')
}

applicationApi
Expand Down
1 change: 1 addition & 0 deletions ui/src/components/ai-chat/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
:is-mobile="isMobile"
:type="type"
:send-message="sendMessage"
:open-chat-id="openChatId"
v-model:chat-id="chartOpenId"
v-model:loading="loading"
v-if="type !== 'log'"
Expand Down
12 changes: 11 additions & 1 deletion ui/src/workflow/nodes/image-understand/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,16 @@
@submitDialog="submitDialog"
/>
</el-form-item>
<el-form-item label="历史聊天记录">
<el-form-item>
<template #label>
<div class="flex-between">
<div>历史聊天记录</div>
<el-select v-model="form_data.dialogue_type" class="w-120">
<el-option label="节点" value="NODE"/>
<el-option label="工作流" value="WORKFLOW"/>
</el-select>
</div>
</template>
<el-input-number
v-model="form_data.dialogue_number"
:min="0"
Expand Down Expand Up @@ -213,6 +222,7 @@ const form = {
system: '',
prompt: defaultPrompt,
dialogue_number: 0,
dialogue_type: 'NODE',
is_result: true,
temperature: null,
max_tokens: null,
liuruibin marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading