Skip to content

Commit

Permalink
refactor: 支持使用浏览器语音播放
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruibin committed Sep 11, 2024
1 parent 3c461b6 commit 2270f97
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 6 deletions.
2 changes: 1 addition & 1 deletion apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class Model(APIView):
tags=["应用"],
manual_parameters=ApplicationApi.Model.get_request_params_api())
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))],
compare=CompareConstants.AND))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# coding=utf-8

from typing import Dict

from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential


class BrowserTextToSpeechCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
return True

def encryption_dict(self, model: Dict[str, object]):
return model
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,25 @@
ModelInfoManage
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
from setting.models_provider.impl.local_model_provider.credential.tts import BrowserTextToSpeechCredential
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
from setting.models_provider.impl.local_model_provider.model.tts import BrowserTextToSpeech
from smartdoc.conf import PROJECT_DIR

embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)
bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER,
LocalRerankerCredential(), LocalReranker)

browser_tts = ModelInfo('browser_tts', '', ModelTypeConst.TTS, BrowserTextToSpeechCredential(), BrowserTextToSpeech)


model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
.append_default_model_info(embedding_text2vec_base_chinese)
.append_model_info(bge_reranker_v2_m3)
.append_default_model_info(bge_reranker_v2_m3)
.append_model_info(browser_tts)
.build())


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Dict

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tts import BaseTextToSpeech



class BrowserTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
model: str

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model = kwargs.get('model')

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return BrowserTextToSpeech(
model=model_name,
**optional_params,
)

def check_auth(self):
pass

def text_to_speech(self, text):
pass
20 changes: 17 additions & 3 deletions ui/src/components/ai-chat/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
@regeneration="regenerationChart(item)"
/>
</div>
<!-- 语音播放 -->
<div style="float: right;" v-if="props.data.tts_model_enable">
<el-button :disabled="!item.write_ed" @click="playAnswerText(item.answer_text)">
<el-icon>
Expand Down Expand Up @@ -247,7 +248,13 @@ const props = defineProps({
chatId: {
type: String,
default: ''
} // 历史记录Id
}, // 历史记录Id
ttsModelOptions: {
type: Object,
default: () => {
return {}
}
}
})
const emit = defineEmits(['refresh', 'scroll'])
Expand Down Expand Up @@ -321,8 +328,7 @@ watch(
)
function handleInputFieldList() {
props.data.work_flow?.nodes
.filter((v: any) => v.id === 'base-node')
props.data.work_flow?.nodes?.filter((v: any) => v.id === 'base-node')
.map((v: any) => {
inputFieldList.value = v.properties.input_field_list.map((v: any) => {
switch (v.type) {
Expand Down Expand Up @@ -763,6 +769,14 @@ const uploadRecording = async (audioBlob: Blob) => {
}
const playAnswerText = (text: string) => {
if (props.ttsModelOptions?.model_local_provider?.filter((v: any) => v.id === props.data.tts_model_id).length > 0) {
// 创建一个新的 SpeechSynthesisUtterance 实例
const utterance = new SpeechSynthesisUtterance(text);
// 调用浏览器的朗读功能
window.speechSynthesis.speak(utterance);
return
}
applicationApi.postTextToSpeech(props.data.id as string, { 'text': text }, loading)
.then((res: any) => {
Expand Down
19 changes: 18 additions & 1 deletion ui/src/views/application-workflow/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
</div>
</div>
<div class="scrollbar-height">
<AiChat :data="detail"></AiChat>
<AiChat :data="detail" :tts-model-options="ttsModelOptions"></AiChat>
</div>
</div>
</el-collapse-transition>
Expand All @@ -157,6 +157,7 @@ import { datetimeFormat } from '@/utils/time'
import useStore from '@/stores'
import { WorkFlowInstance } from '@/workflow/common/validate'
import { hasPermission } from '@/utils/permission'
import { groupBy } from 'lodash'

const { user, application } = useStore()
const router = useRouter()
Expand All @@ -181,6 +182,7 @@ const enlarge = ref(false)
const saveTime = ref<any>('')
const activeName = ref('base')
const functionLibList = ref<any[]>([])
const ttsModelOptions = ref<any>(null)

function publicHandle() {
workflowRef.value
Expand Down Expand Up @@ -310,6 +312,20 @@ function getList() {
})
}

function getTTSModel() {
loading.value = true
applicationApi
.getApplicationTTSModel(id)
.then((res: any) => {
ttsModelOptions.value = groupBy(res?.data, 'provider')
loading.value = false
})
.catch(() => {
loading.value = false
})
}


/**
* 定时保存
*/
Expand All @@ -329,6 +345,7 @@ const closeInterval = () => {
}

onMounted(() => {
getTTSModel()
getDetail()
getList()
// 初始化定时任务
Expand Down
2 changes: 1 addition & 1 deletion ui/src/views/application/ApplicationSetting.vue
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@
</h4>
</div>
<div class="scrollbar-height">
<AiChat :data="applicationForm"></AiChat>
<AiChat :data="applicationForm" :tts-model-options="ttsModelOptions"></AiChat>
</div>
</div>
</el-col>
Expand Down
20 changes: 20 additions & 0 deletions ui/src/views/chat/pc/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
:appId="applicationDetail?.id"
:record="currentRecordList"
:chatId="currentChatId"
:tts-model-options="ttsModelOptions"
@refresh="refresh"
@scroll="handleScroll"
>
Expand All @@ -130,6 +131,8 @@ import { marked } from 'marked'
import { saveAs } from 'file-saver'
import { isAppIcon } from '@/utils/application'
import useStore from '@/stores'
import applicationApi from '@/api/application'
import { groupBy } from 'lodash'
import useResize from '@/layout/hooks/useResize'
useResize()
Expand Down Expand Up @@ -167,6 +170,8 @@ const left_loading = ref(false)
const applicationDetail = ref<any>({})
const applicationAvailable = ref<boolean>(true)
const chatLogeData = ref<any[]>([])
const ttsModelOptions = ref<any>(null)
const paginationConfig = ref({
current_page: 1,
Expand Down Expand Up @@ -228,6 +233,7 @@ function getAppProfile() {
if (res.data?.show_history || !user.isEnterprise()) {
getChatLog(applicationDetail.value.id)
}
getTTSModel()
})
.catch(() => {
applicationAvailable.value = false
Expand Down Expand Up @@ -336,6 +342,20 @@ async function exportHTML(): Promise<void> {
saveAs(blob, suggestedName)
}
function getTTSModel() {
loading.value = true
applicationApi
.getApplicationTTSModel(applicationDetail.value.id)
.then((res: any) => {
ttsModelOptions.value = groupBy(res?.data, 'provider')
loading.value = false
})
.catch(() => {
loading.value = false
})
}
onMounted(() => {
user.changeUserType(2)
getAccessToken(accessToken)
Expand Down

0 comments on commit 2270f97

Please sign in to comment.