feat: speech-to-text
liuruibin committed Sep 2, 2024
1 parent 614f17b commit d9bbc92
Showing 11 changed files with 631 additions and 5 deletions.
2 changes: 2 additions & 0 deletions apps/setting/models_provider/base_model_provider.py
@@ -139,6 +139,8 @@ def encryption(message: str):
class ModelTypeConst(Enum):
    LLM = {'code': 'LLM', 'message': '大语言模型'}
    EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
    STT = {'code': 'STT', 'message': '语音识别'}
    TTS = {'code': 'TTS', 'message': '语音合成'}


class ModelInfo:
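
For orientation, each ModelTypeConst member wraps a code and a display message; a quick, hedged illustration of reading the new STT entry (not part of the commit):

# Illustrative only, not part of this commit: inspecting the new STT type.
from setting.models_provider.base_model_provider import ModelTypeConst

print(ModelTypeConst.STT.value.get('code'))     # 'STT'
print(ModelTypeConst.STT.value.get('message'))  # '语音识别' (speech recognition)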
8 changes: 8 additions & 0 deletions apps/setting/models_provider/impl/base_stt.py
@@ -0,0 +1,8 @@
# coding=utf-8
from abc import abstractmethod


class BaseSpeechToText:
    @abstractmethod
    def speech_to_text(self, audio_file):
        pass
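
Any concrete speech-to-text model is expected to subclass this interface; a minimal sketch with an illustrative class name (not part of the commit):

# Illustrative only: a hypothetical provider model implementing BaseSpeechToText.
from setting.models_provider.impl.base_stt import BaseSpeechToText


class DummySpeechToText(BaseSpeechToText):
    def speech_to_text(self, audio_file):
        # A real implementation would send audio_file to the provider's
        # transcription API and return the recognized text.
        return "transcribed text"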
38 changes: 38 additions & 0 deletions apps/setting/models_provider/impl/openai_model_provider/model/stt.py
@@ -0,0 +1,38 @@
from typing import Dict

from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class OpenAISpeechToText(MaxKBBaseModel, ChatOpenAI, BaseSpeechToText):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return OpenAISpeechToText(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            **optional_params,
            streaming=True,
            stream_usage=True,
            custom_get_token_ids=custom_get_token_ids
        )

    def speech_to_text(self, audio_file):
        return self.client.audio.transcriptions.create(
            model=self.model,
            language="zh",
            file=audio_file
        )
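
A rough usage sketch of the class above, assuming valid OpenAI credentials; the credential values and audio path are placeholders, and the response shape follows the OpenAI SDK transcriptions endpoint:

# Hypothetical usage; credential values and the audio path are placeholders.
stt_model = OpenAISpeechToText.new_instance(
    'STT',
    'whisper-1',
    {'api_base': 'https://api.openai.com/v1', 'api_key': 'sk-...'}
)
with open('sample.mp3', 'rb') as audio_file:
    result = stt_model.speech_to_text(audio_file)
# With the OpenAI SDK, the transcription response exposes the recognized
# text via its `text` attribute.
print(result.text)
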
@@ -15,6 +15,7 @@
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
from smartdoc.conf import PROJECT_DIR

openai_llm_model_credential = OpenAILLMModelCredential()
@@ -58,7 +59,10 @@
              OpenAIChatModel),
    ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
              ModelTypeConst.LLM, openai_llm_model_credential,
              OpenAIChatModel),
    ModelInfo('whisper-1', '',
              ModelTypeConst.STT, openai_llm_model_credential,
              OpenAISpeechToText)
]
open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [
@@ -0,0 +1,45 @@
# coding=utf-8

from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
    volcanic_api_url = forms.TextInputField('API 域名', required=True)
    volcanic_app_id = forms.TextInputField('App ID', required=True)
    volcanic_token = forms.PasswordInputField('Token', required=True)
    volcanic_cluster = forms.TextInputField('Cluster', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            print(model)
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}

    def get_model_params_setting_form(self, model_name):
        pass
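
A hedged sketch of how this credential form might be validated; the provider object, model name, and credential values below are placeholders, not part of the commit:

# Illustrative only: validating volcanic engine STT credentials.
credential_form = VolcanicEngineSTTModelCredential()
is_ok = credential_form.is_valid(
    'STT',                   # model type
    'volcanic-stt-model',    # placeholder model name
    {
        'volcanic_api_url': 'wss://example.com/api/v2/asr',
        'volcanic_app_id': 'your-app-id',
        'volcanic_token': 'your-token',
        'volcanic_cluster': 'your-cluster',
    },
    provider,                # a provider exposing get_model_type_list()/get_model()
    raise_exception=True
)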