Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: 调试语音输入和语音播放 #1127

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.models_provider.tools import get_model_instance_by_model_user_id
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR

Expand Down Expand Up @@ -856,15 +857,24 @@ def get_work_flow_model(instance):
instance['tts_model_enable'] = node['properties']['node_data']['tts_model_enable']
break

def speech_to_text(self, filelist):
# TODO: locate the model and convert the mp3 audio to text.
# Placeholder: for now this only prints the application id and the received files.
print(self.application_id)
print(filelist)
def speech_to_text(self, file, with_valid=True):
    """Transcribe an audio file with the application's speech-to-text model.

    :param file: audio payload, passed unchanged to the model's ``speech_to_text``.
    :param with_valid: when true, run ``self.is_valid(raise_exception=True)`` first.
    :return: the value returned by the model's ``speech_to_text``, or ``None`` when
        the application cannot be found or its ``stt_model_enable`` flag is off.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    application_id = self.data.get('application_id')
    application = QuerySet(Application).filter(id=application_id).first()
    # ``first()`` returns None when no row matches; handle that the same way as a
    # disabled model instead of failing with AttributeError on the flag lookup.
    if application is None or not application.stt_model_enable:
        return None
    model = get_model_instance_by_model_user_id(application.stt_model_id, application.user_id)
    return model.speech_to_text(file)

def text_to_speech(self, text):
# TODO: locate the model and convert the text to audio bytes.
# Placeholder: for now this only prints the application id and the received text.
print(self.application_id)
print(text)
def text_to_speech(self, text, with_valid=True):
    """Synthesize speech for ``text`` with the application's text-to-speech model.

    :param text: text passed unchanged to the model's ``text_to_speech``.
    :param with_valid: when true, run ``self.is_valid(raise_exception=True)`` first.
    :return: the value returned by the model's ``text_to_speech``, or ``None`` when
        the application cannot be found or its ``tts_model_enable`` flag is off.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    application_id = self.data.get('application_id')
    application = QuerySet(Application).filter(id=application_id).first()
    # ``first()`` returns None when no row matches; handle that the same way as a
    # disabled model instead of failing with AttributeError on the flag lookup.
    if application is None or not application.tts_model_enable:
        return None
    model = get_model_instance_by_model_user_id(application.tts_model_id, application.user_id)
    return model.text_to_speech(text)

class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:
Expand Down
4 changes: 2 additions & 2 deletions apps/application/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve/<str:paragraph_id>',
views.ChatView.ChatRecord.Improve.Operate.as_view(),
name=''),
path('application/<str:application_id>/<str:model_id>/speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'),
path('application/<str:application_id>/<str:model_id>/text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'),
path('application/<str:application_id>/speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'),
path('application/<str:application_id>/text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'),

]
16 changes: 8 additions & 8 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,10 +545,9 @@ class SpeechToText(APIView):
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def post(self, request: Request, application_id: str, model_id: str):
def post(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id})
ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id})
.speech_to_text(request.FILES.getlist('file')[0]))

class TextToSpeech(APIView):
Expand All @@ -561,8 +560,9 @@ class TextToSpeech(APIView):
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def post(self, request: Request, application_id: str, model_id: str):
    """Run text-to-speech on the posted ``text`` and wrap the outcome in ``result.success``."""
    operate = ApplicationSerializer.Operate(
        data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id})
    speech = operate.text_to_speech(request.data.get('text'))
    return result.success(speech)
def post(self, request: Request, application_id: str):
    """Run text-to-speech on the posted ``text`` and return the result as an mp3 attachment."""
    operate = ApplicationSerializer.Operate(
        data={'application_id': application_id, 'user_id': request.user.id})
    audio_bytes = operate.text_to_speech(request.data.get('text'))
    # NOTE(review): text_to_speech has a path that returns no value (model disabled);
    # confirm what HttpResponse should carry in that case.
    response_headers = {'Content-Type': 'audio/mp3',
                        'Content-Disposition': 'attachment; filename="abc.mp3"'}
    return HttpResponse(audio_bytes, status=200, headers=response_headers)
3 changes: 2 additions & 1 deletion ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
"vue-clipboard3": "^2.0.0",
"vue-codemirror": "^6.1.1",
"vue-i18n": "^9.13.1",
"vue-router": "^4.2.4"
"vue-router": "^4.2.4",
"recorder-core": "^1.3.24040900"
},
"devDependencies": {
"@rushstack/eslint-patch": "^1.3.2",
Expand Down
27 changes: 26 additions & 1 deletion ui/src/api/application.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, postStream, del, put } from '@/request/index'
import { get, post, postStream, del, put, request, download } from '@/request/index'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue'
Expand Down Expand Up @@ -330,6 +330,29 @@ const getModelParamsForm: (
) => Promise<Result<Array<FormField>>> = (application_id, model_id, loading) => {
return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading)
}

/**
 * Speech to text: POST the given payload to the application's
 * `speech_to_text` endpoint.
 * @param application_id id of the application, interpolated into the URL
 * @param data request body (the chat component sends a FormData with a `file` entry)
 * @param loading optional loading flag forwarded to `post`
 */
const postSpeechToText: (
  application_id: String,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
  const url = `${prefix}/${application_id}/speech_to_text`
  return post(url, data, undefined, loading)
}

/**
 * Text to speech: POST the given payload to the application's
 * `text_to_speech` endpoint through `download`.
 * @param application_id id of the application, interpolated into the URL
 * @param data request body (the chat component sends `{ text }`)
 * @param loading optional loading flag forwarded to `download`
 */
const postTextToSpeech: (
  application_id: String,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
  return download(`${prefix}/${application_id}/text_to_speech`, 'post', data, undefined, loading)
}

export default {
getAllAppilcation,
getApplication,
Expand All @@ -356,4 +379,6 @@ export default {
getApplicationRerankerModel,
getApplicationSTTModel,
getApplicationTTSModel,
postSpeechToText,
postTextToSpeech,
}
122 changes: 122 additions & 0 deletions ui/src/components/ai-chat/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@
@regeneration="regenerationChart(item)"
/>
</div>
<div style="float: right;" v-if="props.data.tts_model_enable">
<el-button :disabled="!item.write_ed" @click="playAnswerText(item.answer_text)">
<el-icon><VideoPlay /></el-icon>
</el-button>
</div>
</div>
</div>
</template>
Expand All @@ -131,6 +136,20 @@
:maxlength="100000"
@keydown.enter="sendChatHandle($event)"
/>
<div class="operate" v-if="props.data.stt_model_enable">
<el-button
v-if="mediaRecorderStatus"
@click="startRecording"
>
<el-icon><Microphone /></el-icon>
</el-button>
<el-button
v-else
@click="stopRecording"
>
<el-icon><VideoPause /></el-icon>
</el-button>
</div>
<div class="operate">
<el-button
text
Expand All @@ -149,6 +168,8 @@
</div>
</div>
</div>
<!-- 先渲染,不然不能播放 -->
<audio ref="audioPlayer" controls hidden="hidden"></audio>
</div>
</template>
<script setup lang="ts">
Expand All @@ -165,6 +186,10 @@ import useStore from '@/stores'
import MdRenderer from '@/components/markdown/MdRenderer.vue'
import { isWorkFlow } from '@/utils/application'
import { debounce } from 'lodash'
import Recorder from 'recorder-core'
import 'recorder-core/src/engine/mp3'
import 'recorder-core/src/engine/mp3-engine'

defineOptions({ name: 'AiChat' })
const route = useRoute()
const {
Expand Down Expand Up @@ -592,6 +617,103 @@ const handleScroll = () => {
}
}

// Reactive references
// Recorder instance created by startRecording; null until a recording is started.
const mediaRecorder= ref<any>(null)
// Template ref bound to the hidden <audio> element used to play synthesized answers.
const audioPlayer= ref<HTMLAudioElement | null>(null)
// true while idle (microphone button is rendered), false once startRecording has run.
const mediaRecorderStatus = ref(true)


// Start recording
/**
 * Create an mp3 Recorder, open it, and start capturing once the open succeeds.
 * The UI switches to the "recording" state immediately; if the recorder cannot
 * be opened, or constructing it throws, the state is rolled back so the
 * microphone button is shown again instead of a stop button with nothing to stop.
 */
const startRecording = async () => {
  mediaRecorderStatus.value = false
  const resetToIdle = () => {
    mediaRecorderStatus.value = true
  }
  try {
    const recorder = new Recorder({
      type: 'mp3',
      bitRate: 128,
      sampleRate: 44100
    })
    mediaRecorder.value = recorder

    recorder.open(
      () => {
        recorder.start()
      },
      (err: any) => {
        // open() reports failure through this callback, not by throwing,
        // so the rollback has to happen here as well as in the catch below.
        console.error(err)
        resetToIdle()
      }
    )
  } catch (error) {
    console.error('无法获取音频权限:', error)
    resetToIdle()
  }
}

// Stop recording
/**
 * Stop the active recorder, release it, and upload the recorded mp3 blob.
 * Does nothing when no recorder has been created.
 */
const stopRecording = () => {
  const recorder = mediaRecorder.value
  if (!recorder) {
    return
  }
  mediaRecorderStatus.value = true

  // startRecording builds a fresh Recorder each time, so the finished one is
  // closed to release the recording resources rather than being kept open.
  const release = () => {
    recorder.close()
    if (mediaRecorder.value === recorder) {
      mediaRecorder.value = null
    }
  }

  recorder.stop(
    (blob: Blob) => {
      release()
      uploadRecording(blob) // upload the recorded file
    },
    (err: any) => {
      console.error('录音失败:', err)
      release()
    }
  )
}

// Upload the recorded file
/**
 * Send the recorded audio to the speech-to-text endpoint and put the returned
 * `data` into the chat input box. Does nothing when no application id is set.
 * @param audioBlob recorded audio, uploaded as form field `file` named `recording.mp3`
 */
const uploadRecording = async (audioBlob: Blob) => {
  if (!id) {
    return
  }
  const formData = new FormData()
  formData.append('file', audioBlob, 'recording.mp3')

  try {
    // Await the request so a rejected upload reaches the catch below; a bare
    // `.then()` inside try/catch would leave the rejection unhandled.
    const response = await applicationApi.postSpeechToText(id as string, formData, loading)
    console.log('上传成功:', response.data)
    inputValue.value = response.data
  } catch (error) {
    console.error('上传失败:', error)
  }
}

/**
 * Request synthesized speech for `text` and play it through the hidden
 * <audio> element. Does nothing when no application id is set.
 * @param text answer text to be spoken
 */
const playAnswerText = (text: string) => {
  if (!id) {
    return
  }
  console.log(text)
  applicationApi
    .postTextToSpeech(id as string, { text: text }, loading)
    .then((res: any) => {
      // Make sure the template ref has been bound to the DOM element.
      const player = audioPlayer.value
      if (!(player instanceof HTMLAudioElement)) {
        console.error("audioPlayer.value is not an instance of HTMLAudioElement")
        return
      }

      // Wrap the response in an mp3 Blob and expose it through an object URL.
      const blob = new Blob([res], { type: 'audio/mp3' })
      const previousUrl = player.src
      player.src = URL.createObjectURL(blob)

      // Release the object URL of the previous playback so the blobs do not
      // accumulate for the lifetime of the page.
      if (previousUrl.startsWith('blob:')) {
        URL.revokeObjectURL(previousUrl)
      }

      // Return the play() promise so a playback rejection reaches the catch below.
      return player.play()
    })
    .catch((err) => {
      console.log('err: ', err)
    })
}

function setScrollBottom() {
// 将滚动条滚动到最下面
scrollDiv.value.setScrollTop(getMaxHeight())
Expand Down
17 changes: 17 additions & 0 deletions ui/src/request/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,23 @@ export const exportExcel: (
.catch((e) => {})
}


/**
 * Send a request whose response is read as a blob.
 * @param url request path
 * @param method HTTP method
 * @param data optional request body
 * @param params optional query parameters
 * @param loading optional loading indicator handed to `promise`
 */
export const download: (
  url: string,
  method: string,
  data?: any,
  params?: any,
  loading?: NProgress | Ref<boolean>
) => Promise<any> = (url, method, data, params, loading) => {
  const blobRequest = request({ url, method, data, params, responseType: 'blob' })
  return promise(blobRequest, loading)
}

/**
* 与服务器建立ws链接
* @param url websocket路径
Expand Down
Loading