diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 9faabd4ef7..629d118a91 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -47,6 +47,7 @@ from setting.models.model_management import Model from setting.models_provider import get_model_credential from setting.models_provider.constants.model_provider_constants import ModelProvideConstants +from setting.models_provider.tools import get_model_instance_by_model_user_id from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR @@ -856,15 +857,24 @@ def get_work_flow_model(instance): instance['tts_model_enable'] = node['properties']['node_data']['tts_model_enable'] break - def speech_to_text(self, filelist): - # todo 找到模型 mp3转text - print(self.application_id) - print(filelist) + def speech_to_text(self, file, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application_id = self.data.get('application_id') + application = QuerySet(Application).filter(id=application_id).first() + if application.stt_model_enable: + model = get_model_instance_by_model_user_id(application.stt_model_id, application.user_id) + text = model.speech_to_text(file) + return text - def text_to_speech(self, text): - # todo 找到模型 text转bytes - print(self.application_id) - print(text) + def text_to_speech(self, text, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application_id = self.data.get('application_id') + application = QuerySet(Application).filter(id=application_id).first() + if application.tts_model_enable: + model = get_model_instance_by_model_user_id(application.tts_model_id, application.user_id) + return model.text_to_speech(text) class ApplicationKeySerializerModel(serializers.ModelSerializer): class Meta: diff --git a/apps/application/urls.py b/apps/application/urls.py index 2b7e779fae..5a41bc59ce 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -64,7 +64,7 @@ 'application//chat//chat_record//dataset//document_id//improve/', views.ChatView.ChatRecord.Improve.Operate.as_view(), name=''), - path('application///speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'), - path('application///text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'), + path('application//speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'), + path('application//text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'), ] diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index db0dae034c..a9dd2fe4dd 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -545,10 +545,9 @@ class SpeechToText(APIView): dynamic_tag=keywords.get( 'application_id'))], compare=CompareConstants.AND)) - def post(self, request: Request, application_id: str, model_id: str): + def post(self, request: Request, application_id: str): return result.success( - ApplicationSerializer.Operate( - data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id}) + ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}) .speech_to_text(request.FILES.getlist('file')[0])) class TextToSpeech(APIView): @@ -561,8 +560,9 @@ class TextToSpeech(APIView): dynamic_tag=keywords.get( 'application_id'))], compare=CompareConstants.AND)) - def post(self, request: Request, application_id: str, model_id: str): - return result.success( - ApplicationSerializer.Operate( - data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id}) - .text_to_speech(request.data.get('text'))) + def post(self, request: Request, application_id: str): + byte_data = ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).text_to_speech( + request.data.get('text')) + return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3', + 'Content-Disposition': 'attachment; filename="abc.mp3"'}) diff --git a/ui/package.json b/ui/package.json index 0a3635c206..3e1582bae3 100644 --- a/ui/package.json +++ b/ui/package.json @@ -43,7 +43,8 @@ "vue-clipboard3": "^2.0.0", "vue-codemirror": "^6.1.1", "vue-i18n": "^9.13.1", - "vue-router": "^4.2.4" + "vue-router": "^4.2.4", + "recorder-core": "^1.3.24040900" }, "devDependencies": { "@rushstack/eslint-patch": "^1.3.2", diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 2cf3b413eb..f21c465b0f 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -1,5 +1,5 @@ import { Result } from '@/request/Result' -import { get, post, postStream, del, put } from '@/request/index' +import { get, post, postStream, del, put, request, download } from '@/request/index' import type { pageRequest } from '@/api/type/common' import type { ApplicationFormType } from '@/api/type/application' import { type Ref } from 'vue' @@ -330,6 +330,29 @@ const getModelParamsForm: ( ) => Promise>> = (application_id, model_id, loading) => { return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading) } + +/** + * 语音转文本 + */ +const postSpeechToText: ( + application_id: String, + data: any, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return post(`${prefix}/${application_id}/speech_to_text`, data, undefined, loading) +} + +/** + * 语音转文本 + */ +const postTextToSpeech: ( + application_id: String, + data: any, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return download(`${prefix}/${application_id}/text_to_speech`, 'post', data, undefined, loading) +} + export default { getAllAppilcation, getApplication, @@ -356,4 +379,6 @@ export default { getApplicationRerankerModel, getApplicationSTTModel, getApplicationTTSModel, + postSpeechToText, + postTextToSpeech, } diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index d72d13e793..074cb86d2b 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -114,6 +114,11 @@ @regeneration="regenerationChart(item)" /> +
+ + + +
@@ -131,6 +136,20 @@ :maxlength="100000" @keydown.enter="sendChatHandle($event)" /> +
+ + + + + + +
+ +