Skip to content

Commit

Permalink
refactor: 调试语音输入和语音播放
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruibin committed Sep 6, 2024
1 parent 408a57a commit 5ba7296
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 20 deletions.
26 changes: 18 additions & 8 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.models_provider.tools import get_model_instance_by_model_user_id
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR

Expand Down Expand Up @@ -856,15 +857,24 @@ def get_work_flow_model(instance):
instance['tts_model_enable'] = node['properties']['node_data']['tts_model_enable']
break

def speech_to_text(self, filelist):
# todo 找到模型 mp3转text
print(self.application_id)
print(filelist)
def speech_to_text(self, file, with_valid=True):
    """Transcribe an audio file with the application's speech-to-text model.

    :param file: audio file object, passed unchanged to the model's
        ``speech_to_text`` method.
    :param with_valid: when true, run ``self.is_valid(raise_exception=True)``
        before doing anything else.
    :return: the value returned by the model's ``speech_to_text``, or ``None``
        when the application cannot be found or has speech-to-text disabled.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    application_id = self.data.get('application_id')
    application = QuerySet(Application).filter(id=application_id).first()
    # ``first()`` yields None when no row matches; without this guard the
    # attribute access below would raise AttributeError.
    if application is None or not application.stt_model_enable:
        return None
    model = get_model_instance_by_model_user_id(application.stt_model_id, application.user_id)
    return model.speech_to_text(file)

def text_to_speech(self, text):
# todo 找到模型 text转bytes
print(self.application_id)
print(text)
def text_to_speech(self, text, with_valid=True):
    """Synthesize speech for ``text`` with the application's text-to-speech model.

    :param text: text passed unchanged to the model's ``text_to_speech`` method.
    :param with_valid: when true, run ``self.is_valid(raise_exception=True)``
        before doing anything else.
    :return: the value returned by the model's ``text_to_speech``, or ``None``
        when the application cannot be found or has text-to-speech disabled.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    application_id = self.data.get('application_id')
    application = QuerySet(Application).filter(id=application_id).first()
    # ``first()`` yields None when no row matches; without this guard the
    # attribute access below would raise AttributeError.
    if application is None or not application.tts_model_enable:
        return None
    model = get_model_instance_by_model_user_id(application.tts_model_id, application.user_id)
    return model.text_to_speech(text)

class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:
Expand Down
4 changes: 2 additions & 2 deletions apps/application/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve/<str:paragraph_id>',
views.ChatView.ChatRecord.Improve.Operate.as_view(),
name=''),
path('application/<str:application_id>/<str:model_id>/speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'),
path('application/<str:application_id>/<str:model_id>/text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'),
path('application/<str:application_id>/speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'),
path('application/<str:application_id>/text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'),

]
16 changes: 8 additions & 8 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,10 +545,9 @@ class SpeechToText(APIView):
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def post(self, request: Request, application_id: str, model_id: str):
def post(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id})
ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id})
.speech_to_text(request.FILES.getlist('file')[0]))

class TextToSpeech(APIView):
Expand All @@ -561,8 +560,9 @@ class TextToSpeech(APIView):
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def post(self, request: Request, application_id: str, model_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id})
.text_to_speech(request.data.get('text')))
def post(self, request: Request, application_id: str):
    """Return synthesized speech for the ``text`` field of the request body.

    The audio bytes produced by ``ApplicationSerializer.Operate.text_to_speech``
    are sent back as an ``audio/mp3`` attachment.  When that call yields
    ``None`` an empty 404 response is returned instead.
    """
    operate = ApplicationSerializer.Operate(
        data={'application_id': application_id, 'user_id': request.user.id})
    byte_data = operate.text_to_speech(request.data.get('text'))
    if byte_data is None:
        # HttpResponse(None) would be serialised as the literal bytes b'None'
        # and delivered as a 200 "mp3" download; answer with an empty 404
        # rather than a bogus audio file.
        # NOTE(review): 404 is a judgement call - confirm the status code the
        # client should receive when no audio is available.
        return HttpResponse(status=404)
    return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3',
                                                        'Content-Disposition': 'attachment; filename="abc.mp3"'})
3 changes: 2 additions & 1 deletion ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
"vue-clipboard3": "^2.0.0",
"vue-codemirror": "^6.1.1",
"vue-i18n": "^9.13.1",
"vue-router": "^4.2.4"
"vue-router": "^4.2.4",
"recorder-core": "^1.3.24040900"
},
"devDependencies": {
"@rushstack/eslint-patch": "^1.3.2",
Expand Down
27 changes: 26 additions & 1 deletion ui/src/api/application.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, postStream, del, put } from '@/request/index'
import { get, post, postStream, del, put, request, download } from '@/request/index'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue'
Expand Down Expand Up @@ -330,6 +330,29 @@ const getModelParamsForm: (
) => Promise<Result<Array<FormField>>> = (application_id, model_id, loading) => {
return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading)
}

/**
 * Speech to text: POST the given payload to the application's
 * `speech_to_text` endpoint.
 * @param application_id id of the application, interpolated into the URL
 * @param data request body forwarded to `post`
 * @param loading optional loading flag forwarded to `post`
 */
const postSpeechToText: (
  application_id: String,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
  const url = `${prefix}/${application_id}/speech_to_text`
  return post(url, data, undefined, loading)
}

/**
 * Text to speech: POST the given payload to the application's
 * `text_to_speech` endpoint through the `download` helper.
 * @param application_id id of the application, interpolated into the URL
 * @param data request body forwarded to `download`
 * @param loading optional loading flag forwarded to `download`
 */
const postTextToSpeech: (
  application_id: String,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
  const url = `${prefix}/${application_id}/text_to_speech`
  return download(url, 'post', data, undefined, loading)
}

export default {
getAllAppilcation,
getApplication,
Expand All @@ -356,4 +379,6 @@ export default {
getApplicationRerankerModel,
getApplicationSTTModel,
getApplicationTTSModel,
postSpeechToText,
postTextToSpeech,
}
123 changes: 123 additions & 0 deletions ui/src/components/ai-chat/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@
@regeneration="regenerationChart(item)"
/>
</div>
<div style="float: right;" v-if="props.data.tts_model_enable">
<el-button :disabled="!item.write_ed" @click="playAnswerText(item.answer_text)">
<el-icon><VideoPlay /></el-icon>
</el-button>
</div>
</div>
</div>
</template>
Expand All @@ -131,6 +136,20 @@
:maxlength="100000"
@keydown.enter="sendChatHandle($event)"
/>
<div class="operate" v-if="props.data.stt_model_enable">
<el-button
v-if="mediaRecorderStatus"
@click="startRecording"
>
<el-icon><Microphone /></el-icon>
</el-button>
<el-button
v-else
@click="stopRecording"
>
<el-icon><VideoPause /></el-icon>
</el-button>
</div>
<div class="operate">
<el-button
text
Expand All @@ -149,6 +168,8 @@
</div>
</div>
</div>
<!-- 先渲染,不然不能播放 -->
<audio ref="audioPlayer" controls hidden="hidden"></audio>
</div>
</template>
<script setup lang="ts">
Expand All @@ -165,6 +186,10 @@ import useStore from '@/stores'
import MdRenderer from '@/components/markdown/MdRenderer.vue'
import { isWorkFlow } from '@/utils/application'
import { debounce } from 'lodash'
import Recorder from 'recorder-core'
import 'recorder-core/src/engine/mp3'
import 'recorder-core/src/engine/mp3-engine'
defineOptions({ name: 'AiChat' })
const route = useRoute()
const {
Expand Down Expand Up @@ -592,6 +617,104 @@ const handleScroll = () => {
}
}
// Reactive references used by the voice recording and audio playback handlers
// Holds the recorder-core Recorder instance created by startRecording
const mediaRecorder= ref<any>(null)
// Template ref bound to the hidden <audio> element used for playback
const audioPlayer= ref<HTMLAudioElement | null>(null)
// true: the microphone (start) button is shown; false: the stop button is shown
const mediaRecorderStatus = ref(true)
// Start recording: create an mp3 Recorder, open the microphone and begin
// capturing once it is available.
//
// Recorder.open() acquires the microphone itself, so no separate
// navigator.mediaDevices.getUserMedia() call is made here: the stream from
// such a call was never used or stopped, which left a live microphone track
// behind on every recording.
const startRecording = async () => {
  // Put the button back into the idle state when recording cannot start,
  // otherwise the UI stays stuck on the "stop" button.
  const resetToIdle = () => {
    mediaRecorderStatus.value = true
    mediaRecorder.value = null
  }
  try {
    // Switch to the "stop" button right away so a second click cannot start
    // another recorder while the first is still opening.
    mediaRecorderStatus.value = false
    const recorder = new Recorder({
      type: 'mp3',
      bitRate: 128,
      sampleRate: 44100
    })
    mediaRecorder.value = recorder
    recorder.open(
      () => {
        recorder.start()
      },
      (err: any) => {
        resetToIdle()
        console.error(err)
      }
    )
  } catch (error) {
    resetToIdle()
    console.error('无法获取音频权限:', error)
  }
}
// Stop recording: finish the current capture, release the microphone and
// upload the resulting mp3 blob for transcription.
const stopRecording = () => {
  const recorder = mediaRecorder.value
  if (!recorder) {
    return
  }
  mediaRecorderStatus.value = true
  // The recorder must be closed whichever way stop() ends; otherwise the
  // microphone stays held after the recording is over.
  const release = () => {
    recorder.close()
    mediaRecorder.value = null
  }
  recorder.stop(
    (blob: Blob) => {
      release()
      uploadRecording(blob) // upload the recorded file
    },
    (err: any) => {
      release()
      console.error('录音失败:', err)
    }
  )
}
// Upload the recorded audio to the speech-to-text endpoint and put the
// response data into the chat input box.
const uploadRecording = async (audioBlob: Blob) => {
  if (!id) {
    return
  }
  const formData = new FormData()
  formData.append('file', audioBlob, 'recording.mp3')
  try {
    // Await the request so that a rejected promise lands in the catch below;
    // a try/catch around an un-awaited .then() chain never sees the failure.
    const response = await applicationApi.postSpeechToText(id as string, formData, loading)
    console.log('上传成功:', response.data)
    inputValue.value = response.data
  } catch (error) {
    console.error('上传失败:', error)
  }
}
// Request synthesized speech for `text` and play it through the hidden
// <audio> element.
const playAnswerText = (text: string) => {
  if (!id) {
    return
  }
  applicationApi
    .postTextToSpeech(id as string, { 'text': text }, loading)
    .then((res: any) => {
      const player = audioPlayer.value
      // Make sure the template ref is bound to the <audio> element
      if (!(player instanceof HTMLAudioElement)) {
        console.error("audioPlayer.value is not an instance of HTMLAudioElement");
        return
      }
      // Wrap the response in an mp3 Blob and expose it through an object URL
      const blob = new Blob([res], { type: 'audio/mp3' })
      const url = URL.createObjectURL(blob)
      const previousUrl = player.src
      player.src = url
      // Object URLs stay alive until revoked; free the one from the previous
      // playback so repeated plays do not accumulate blobs in memory.
      if (previousUrl.startsWith('blob:')) {
        URL.revokeObjectURL(previousUrl)
      }
      // play() returns a promise that rejects when playback cannot start
      // (for example undecodable data or a blocked autoplay)
      return player.play()
    })
    .catch((err) => {
      console.log('err: ', err)
    })
}
function setScrollBottom() {
// 将滚动条滚动到最下面
scrollDiv.value.setScrollTop(getMaxHeight())
Expand Down
17 changes: 17 additions & 0 deletions ui/src/request/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,23 @@ export const exportExcel: (
.catch((e) => {})
}


/**
 * Send a request whose response is read as a Blob.
 * @param url request path
 * @param method HTTP method name
 * @param data optional request body
 * @param params optional query parameters
 * @param loading optional loading indicator forwarded to `promise`
 */
export const download: (
  url: string,
  method: string,
  data?: any,
  params?: any,
  loading?: NProgress | Ref<boolean>
) => Promise<any> = (url, method, data, params, loading) => {
  const blobRequest = request({ url, method, data, params, responseType: 'blob' })
  return promise(blobRequest, loading)
}

/**
* 与服务器建立ws链接
* @param url websocket路径
Expand Down

0 comments on commit 5ba7296

Please sign in to comment.