Skip to content

Commit

Permalink
refactor: 应用设置中配置语音输入和播放
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruibin committed Sep 5, 2024
1 parent 3faba75 commit 5b00c16
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 8 deletions.
4 changes: 4 additions & 0 deletions apps/application/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class Application(AppModelMixin):
work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices,
default=ApplicationTypeChoices.SIMPLE, max_length=256)
tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
tts_model_enable = models.BooleanField(verbose_name="语音合成模型是否启用", default=False)
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)

@staticmethod
def get_default_model_prompt():
Expand Down
38 changes: 32 additions & 6 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def list(self, with_valid=True):
@staticmethod
def reset_application(application: Dict):
application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False
del application['dialogue_number']

if 'dataset_setting' in application:
application['dataset_setting'] = {'search_mode': 'embedding', 'no_references_setting': {
'status': 'ai_questioning',
Expand Down Expand Up @@ -710,21 +710,37 @@ def edit(self, instance: Dict, with_valid=True):
raise AppApiException(500, "模型不存在")
if not model.is_permission(application.user_id):
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
if instance.get('stt_model_id') is None or len(instance.get('stt_model_id')) == 0:
application.stt_model_id = None
else:
model = QuerySet(Model).filter(
id=instance.get('stt_model_id')).first()
if model is None:
raise AppApiException(500, "模型不存在")
if not model.is_permission(application.user_id):
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
if instance.get('tts_model_id') is None or len(instance.get('tts_model_id')) == 0:
application.tts_model_id = None
else:
model = QuerySet(Model).filter(
id=instance.get('tts_model_id')).first()
if model is None:
raise AppApiException(500, "模型不存在")
if not model.is_permission(application.user_id):
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
if 'work_flow' in instance:
# 当前用户可修改关联的知识库列表
application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in
self.list_dataset(with_valid=False)]
self.update_reverse_search_node(instance.get('work_flow'), application_dataset_id_list)

update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization',
'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable',
'api_key_is_active', 'icon', 'work_flow', 'model_params_setting']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue':
application.__setattr__('dialogue_number', 0 if not instance.get(update_key) else 3)
else:
application.__setattr__(update_key, instance.get(update_key))
application.__setattr__(update_key, instance.get(update_key))
application.save()

if 'dataset_id_list' in instance:
Expand Down Expand Up @@ -823,6 +839,16 @@ def save_other_config(self, data):

application.save()

def speech_to_text(self, filelist):
    """Placeholder for converting uploaded audio to text (not implemented yet).

    Args:
        filelist: The audio input to transcribe.
            NOTE(review): despite the plural name, the view visible in this
            change passes a single uploaded file (the first ``file`` part of
            the request) - confirm and consider renaming once implemented.

    Returns:
        None. No conversion is performed yet.
    """
    # Function-scope import so this block does not depend on a module-level
    # ``logging`` import being present in the file.
    import logging

    # TODO: look up the configured model and convert the MP3 audio to text.
    # Debug-level logging replaces the former print() calls so the request
    # payload is no longer written to stdout on every call.
    logging.getLogger(__name__).debug(
        'speech_to_text called: application_id=%s, file=%s',
        self.application_id, filelist)

def text_to_speech(self, text):
    """Placeholder for converting text to audio bytes (not implemented yet).

    Args:
        text: The text to synthesize. May be ``None``: the view visible in
            this change passes ``request.data.get('text')`` unchecked.

    Returns:
        None. No synthesis is performed yet.
    """
    # Function-scope import so this block does not depend on a module-level
    # ``logging`` import being present in the file.
    import logging

    # TODO: look up the configured model and convert the text to audio bytes.
    # Debug-level logging replaces the former print() calls so the request
    # payload is no longer written to stdout on every call.
    logging.getLogger(__name__).debug(
        'text_to_speech called: application_id=%s, text=%s',
        self.application_id, text)

class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:
model = ApplicationApiKey
Expand Down
5 changes: 4 additions & 1 deletion apps/application/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,8 @@
path(
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve/<str:paragraph_id>',
views.ChatView.ChatRecord.Improve.Operate.as_view(),
name='')
name=''),
path('application/<str:application_id>/<str:model_id>/speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'),
path('application/<str:application_id>/<str:model_id>/text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'),

]
32 changes: 32 additions & 0 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,35 @@ def get(self, request: Request, current_page: int, page_size: int):
ApplicationSerializer.Query(
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).page(
current_page, page_size))

class SpeechToText(APIView):
    # POST endpoint that forwards an uploaded audio file to
    # ApplicationSerializer.Operate.speech_to_text and wraps the outcome in
    # result.success. Requests are authenticated with TokenAuth.
    # (Comments rather than docstrings are used so the view's __doc__ is unchanged.)
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @has_permissions(
        ViewPermission(
            [RoleConstants.ADMIN, RoleConstants.USER],
            [
                # The permission is tagged with the application id taken from the URL keywords.
                lambda r, keywords: Permission(
                    group=Group.APPLICATION,
                    operate=Operate.USE,
                    dynamic_tag=keywords.get('application_id'),
                )
            ],
            compare=CompareConstants.AND,
        )
    )
    def post(self, request: Request, application_id: str, model_id: str):
        # application_id and model_id come from the URL; user_id from the authenticated request.
        operate_data = {
            'application_id': application_id,
            'user_id': request.user.id,
            'model_id': model_id,
        }
        transcribe = ApplicationSerializer.Operate(data=operate_data).speech_to_text
        # Only the first uploaded part named 'file' is used.
        # NOTE(review): getlist() yields an empty list when the field is absent,
        # so a request without a 'file' part raises IndexError here.
        audio_file = request.FILES.getlist('file')[0]
        return result.success(transcribe(audio_file))

class TextToSpeech(APIView):
    # POST endpoint that forwards the request's 'text' value to
    # ApplicationSerializer.Operate.text_to_speech and wraps the outcome in
    # result.success. Requests are authenticated with TokenAuth.
    # (Comments rather than docstrings are used so the view's __doc__ is unchanged.)
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @has_permissions(
        ViewPermission(
            [RoleConstants.ADMIN, RoleConstants.USER],
            [
                # The permission is tagged with the application id taken from the URL keywords.
                lambda r, keywords: Permission(
                    group=Group.APPLICATION,
                    operate=Operate.USE,
                    dynamic_tag=keywords.get('application_id'),
                )
            ],
            compare=CompareConstants.AND,
        )
    )
    def post(self, request: Request, application_id: str, model_id: str):
        # application_id and model_id come from the URL; user_id from the authenticated request.
        operate_data = {
            'application_id': application_id,
            'user_id': request.user.id,
            'model_id': model_id,
        }
        synthesize = ApplicationSerializer.Operate(data=operate_data).text_to_speech
        # 'text' is read from the request body; it is None when the key is missing.
        text = request.data.get('text')
        return result.success(synthesize(text))
33 changes: 32 additions & 1 deletion ui/src/api/application.ts
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,35 @@ const getApplicationRerankerModel: (
) => Promise<Result<Array<any>>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/model`, { model_type: 'RERANKER' }, loading)
}

/**
 * List the speech-to-text (STT) models for an application.
 * Calls the `get` request helper on `${prefix}/${application_id}/model`.
 * @param application_id id of the application whose models are requested
 * @param loading optional loading-state ref, forwarded to `get`
 * @query { model_type: 'STT' }
 * @returns the promise returned by `get`, typed as Result<Array<any>>
 */
const getApplicationSTTModel: (
  application_id: string,
  loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
  const url = `${prefix}/${application_id}/model`
  const query = { model_type: 'STT' }
  return get(url, query, loading)
}

/**
 * List the text-to-speech (TTS) models for an application.
 * Calls the `get` request helper on `${prefix}/${application_id}/model`.
 * @param application_id id of the application whose models are requested
 * @param loading optional loading-state ref, forwarded to `get`
 * @query { model_type: 'TTS' }
 * @returns the promise returned by `get`, typed as Result<Array<any>>
 */
const getApplicationTTSModel: (
  application_id: string,
  loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
  const url = `${prefix}/${application_id}/model`
  const query = { model_type: 'TTS' }
  return get(url, query, loading)
}

/**
* 发布应用
* @param 参数
Expand Down Expand Up @@ -324,5 +353,7 @@ export default {
listFunctionLib,
getFunctionLib,
getModelParamsForm,
getApplicationRerankerModel
getApplicationRerankerModel,
getApplicationSTTModel,
getApplicationTTSModel,
}
4 changes: 4 additions & 0 deletions ui/src/api/type/application.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ interface ApplicationFormType {
type?: string
work_flow?: any
model_params_setting?: any
stt_model_id?: string
tts_model_id?: string
stt_model_enable?: boolean
tts_model_enable?: boolean
}
interface chatType {
id: string
Expand Down
Loading

0 comments on commit 5b00c16

Please sign in to comment.