From d9bbc926f123b69ebd9a0772d67db0c9fbec08cb Mon Sep 17 00:00:00 2001 From: CaptainB Date: Tue, 27 Aug 2024 17:46:52 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AF=AD=E9=9F=B3=E8=BD=AC=E6=96=87?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../models_provider/base_model_provider.py | 2 + apps/setting/models_provider/impl/base_stt.py | 8 + .../impl/openai_model_provider/model/stt.py | 38 ++ .../openai_model_provider.py | 6 +- .../credential/stt.py | 45 +++ .../model/stt.py | 326 ++++++++++++++++++ .../volcanic_engine_model_provider.py | 11 +- .../impl/xf_model_provider/credential/stt.py | 46 +++ .../impl/xf_model_provider/model/stt.py | 145 ++++++++ .../xf_model_provider/xf_model_provider.py | 7 +- pyproject.toml | 2 +- 11 files changed, 631 insertions(+), 5 deletions(-) create mode 100644 apps/setting/models_provider/impl/base_stt.py create mode 100644 apps/setting/models_provider/impl/openai_model_provider/model/stt.py create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py create mode 100644 apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py create mode 100644 apps/setting/models_provider/impl/xf_model_provider/credential/stt.py create mode 100644 apps/setting/models_provider/impl/xf_model_provider/model/stt.py diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 6b33139838f..4de2f006742 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -139,6 +139,8 @@ def encryption(message: str): class ModelTypeConst(Enum): LLM = {'code': 'LLM', 'message': '大语言模型'} EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'} + STT = {'code': 'STT', 'message': '语音识别'} + TTS = {'code': 'TTS', 'message': '语音合成'} class ModelInfo: diff --git a/apps/setting/models_provider/impl/base_stt.py 
b/apps/setting/models_provider/impl/base_stt.py
new file mode 100644
index 00000000000..af6221a5b96
--- /dev/null
+++ b/apps/setting/models_provider/impl/base_stt.py
@@ -0,0 +1,13 @@
+# coding=utf-8
+from abc import abstractmethod
+
+
+class BaseSpeechToText:
+    """Common interface implemented by every speech-to-text model wrapper."""
+
+    # NOTE(review): this class is not an abc.ABC, so @abstractmethod is not
+    # enforced by this class on its own; subclasses are expected to override.
+    @abstractmethod
+    def speech_to_text(self, audio_file):
+        """Transcribe ``audio_file``; each provider supplies the implementation."""
+        pass
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/stt.py b/apps/setting/models_provider/impl/openai_model_provider/model/stt.py
new file mode 100644
index 00000000000..629357731d6
--- /dev/null
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/stt.py
@@ -0,0 +1,59 @@
+from typing import Dict
+
+from langchain_openai.chat_models import ChatOpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_stt import BaseSpeechToText
+
+
+def custom_get_token_ids(text: str):
+    """Encode ``text`` with the tokenizer returned by ``TokenizerManage``."""
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class OpenAISpeechToText(MaxKBBaseModel, ChatOpenAI, BaseSpeechToText):
+    """Speech-to-text wrapper that calls the OpenAI audio transcription endpoint."""
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        """Build an instance from the stored credential and optional model kwargs.
+
+        Only ``max_tokens`` and ``temperature`` are forwarded from ``model_kwargs``,
+        and only when they are present and not ``None``.
+        """
+        optional_params = {
+            key: model_kwargs[key]
+            for key in ('max_tokens', 'temperature')
+            if model_kwargs.get(key) is not None
+        }
+        return OpenAISpeechToText(
+            model=model_name,
+            openai_api_base=model_credential.get('api_base'),
+            openai_api_key=model_credential.get('api_key'),
+            **optional_params,
+            streaming=True,
+            stream_usage=True,
+            custom_get_token_ids=custom_get_token_ids
+        )
+
+    def speech_to_text(self, audio_file):
+        """Transcribe ``audio_file`` and return the raw transcription response.
+
+        :param audio_file: binary file object handed to the transcription call
+        :raises AttributeError: when the root OpenAI client cannot be resolved
+        """
+        # NOTE(review): in langchain_openai, ``self.client`` is the
+        # ``chat.completions`` resource and the model id is stored as
+        # ``model_name`` (``model`` is only the constructor alias), so the
+        # former ``self.client.audio`` / ``self.model`` lookups fail. Reach
+        # the root client instead -- confirm against the pinned version.
+        root_client = getattr(self, 'root_client', None) or getattr(self.client, '_client', None)
+        if root_client is None:
+            raise AttributeError('root OpenAI client is not available for audio transcription')
+        return root_client.audio.transcriptions.create(
+            model=self.model_name,
+            language="zh",
+            file=audio_file
+        )
diff --git
a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index fb4c89d7b84..f8038ee6a8d 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -15,6 +15,7 @@ from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel +from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText from smartdoc.conf import PROJECT_DIR openai_llm_model_credential = OpenAILLMModelCredential() @@ -58,7 +59,10 @@ OpenAIChatModel), ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens', ModelTypeConst.LLM, openai_llm_model_credential, - OpenAIChatModel) + OpenAIChatModel), + ModelInfo('whisper-1', '', + ModelTypeConst.STT, openai_llm_model_credential, + OpenAISpeechToText) ] open_ai_embedding_credential = OpenAIEmbeddingCredential() model_info_embedding_list = [
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py
new file mode 100644
index 00000000000..fac26f31050
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py
@@ -0,0 +1,59 @@
+# coding=utf-8
+
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
+    """Credential form and validator for the Volcanic Engine speech-to-text model."""
+
+    volcanic_api_url = forms.TextInputField('API 域名', required=True)
+    volcanic_app_id = forms.TextInputField('App ID', required=True)
+    volcanic_token = forms.PasswordInputField('Token', required=True)
+    volcanic_cluster = forms.TextInputField('Cluster', required=True)
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        """Check that the credential is complete and that a model can be built from it.
+
+        :param model_type: model type code, matched against ``provider.get_model_type_list()``
+        :param model_name: model name handed to ``provider.get_model``
+        :param model_credential: credential values keyed by the field names of this form
+        :param provider: provider used to list supported types and to build the model
+        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
+        :return: ``True`` when every check passes, otherwise ``False``
+        :raises AppApiException: unsupported model type (always), or any other
+            failure when ``raise_exception`` is true
+        """
+        model_type_list = provider.get_model_type_list()
+        if not any(mt.get('value') == model_type for mt in model_type_list):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                return False
+        try:
+            # Building the model is the only check performed here; the instance
+            # is not needed afterwards, so it is neither kept nor printed.
+            provider.get_model(model_type, model_name, model_credential)
+        except AppApiException:
+            raise
+        except Exception as e:
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        """Return a copy of ``model`` with ``volcanic_token`` passed through ``encryption``."""
+        return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}
+
+    def get_model_params_setting_form(self, model_name):
+        """No model parameter form is defined for this model; returns ``None``."""
+        pass
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py
new file mode 100644
index 00000000000..601f79daa95
--- /dev/null
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py
@@ -0,0 +1,360 @@
+# coding=utf-8
+
+"""
+Volcanic Engine streaming ASR client (binary WebSocket protocol).
+
+requires Python 3.6 or later
+
+pip install websockets
+"""
+import asyncio
+import base64
+import gzip
+import hmac
+import json
+import uuid
+import wave
+from enum import Enum
+from hashlib import sha256
+from io import BytesIO
+from typing import Dict
+from urllib.parse import urlparse
+
+import websockets
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
+from setting.models_provider.impl.base_stt import BaseSpeechToText
+
+audio_format = "mp3"  # "wav" or "mp3"; set according to the actual audio format
+
+PROTOCOL_VERSION = 0b0001
+DEFAULT_HEADER_SIZE = 0b0001
+
+PROTOCOL_VERSION_BITS = 4
+HEADER_BITS = 4
+MESSAGE_TYPE_BITS = 4
+MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
+MESSAGE_SERIALIZATION_BITS = 4
+MESSAGE_COMPRESSION_BITS = 4
+RESERVED_BITS = 8
+
+# Message Type:
+CLIENT_FULL_REQUEST = 0b0001
+CLIENT_AUDIO_ONLY_REQUEST = 0b0010
+SERVER_FULL_RESPONSE = 0b1001
+SERVER_ACK = 0b1011
+SERVER_ERROR_RESPONSE = 0b1111
+
+# Message Type Specific Flags
+NO_SEQUENCE = 0b0000  # no check sequence
+POS_SEQUENCE = 0b0001
+NEG_SEQUENCE = 0b0010
+NEG_SEQUENCE_1 = 0b0011
+
+# Message Serialization
+NO_SERIALIZATION = 0b0000
+JSON = 0b0001
+THRIFT = 0b0011
+CUSTOM_TYPE = 0b1111
+
+# Message Compression
+NO_COMPRESSION = 0b0000
+GZIP = 0b0001
+CUSTOM_COMPRESSION = 0b1111
+
+
+def generate_header(
+        version=PROTOCOL_VERSION,
+        message_type=CLIENT_FULL_REQUEST,
+        message_type_specific_flags=NO_SEQUENCE,
+        serial_method=JSON,
+        compression_type=GZIP,
+        reserved_data=0x00,
+        extension_header=bytes()
+):
+    """
+    Build the binary protocol header.
+
+    protocol_version(4 bits), header_size(4 bits),
+    message_type(4 bits), message_type_specific_flags(4 bits)
+    serialization_method(4 bits) message_compression(4 bits)
+    reserved (8bits) reserved field
+    header_extensions extension header (size equals 8 * 4 * (header_size - 1) )
+    """
+    header = bytearray()
+    # header_size counts 4-byte words: one fixed word plus the extension words
+    header_size = int(len(extension_header) / 4) + 1
+    header.append((version << 4) | header_size)
+    header.append((message_type << 4) | message_type_specific_flags)
+    header.append((serial_method << 4) | compression_type)
+    header.append(reserved_data)
+    header.extend(extension_header)
+    return header
+
+
+def generate_full_default_header():
+    """Header for a full client request, built from the ``generate_header`` defaults."""
+    return generate_header()
+
+
+def generate_audio_default_header():
+    """Header for an audio-only client request."""
+    return generate_header(
+        message_type=CLIENT_AUDIO_ONLY_REQUEST
+    )
+
+
+def generate_last_audio_default_header():
+    """Header for the final audio-only request (flagged with ``NEG_SEQUENCE``)."""
+    return generate_header(
+        message_type=CLIENT_AUDIO_ONLY_REQUEST,
+        message_type_specific_flags=NEG_SEQUENCE
+    )
+
+
+def parse_response(res):
+    """
+    Decode one server frame into a dict.
+
+    protocol_version(4 bits), header_size(4 bits),
+    message_type(4 bits), message_type_specific_flags(4 bits)
+    serialization_method(4 bits) message_compression(4 bits)
+    reserved (8bits) reserved field
+    header_extensions extension header (size equals 8 * 4 * (header_size - 1) )
+    payload, comparable to an HTTP request body
+
+    :return: dict with optional ``seq`` / ``code`` and, when the frame carries a
+        payload, ``payload_msg`` and ``payload_size``
+    """
+    protocol_version = res[0] >> 4
+    header_size = res[0] & 0x0f
+    message_type = res[1] >> 4
+    message_type_specific_flags = res[1] & 0x0f
+    serialization_method = res[2] >> 4
+    message_compression = res[2] & 0x0f
+    reserved = res[3]
+    header_extensions = res[4:header_size * 4]
+    payload = res[header_size * 4:]
+    result = {}
+    payload_msg = None
+    payload_size = 0
+    if message_type == SERVER_FULL_RESPONSE:
+        payload_size = int.from_bytes(payload[:4], "big", signed=True)
+        payload_msg = payload[4:]
+    elif message_type == SERVER_ACK:
+        seq = int.from_bytes(payload[:4], "big", signed=True)
+        result['seq'] = seq
+        if len(payload) >= 8:
+            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
+            payload_msg = payload[8:]
+    elif message_type == SERVER_ERROR_RESPONSE:
+        code = int.from_bytes(payload[:4], "big", signed=False)
+        result['code'] = code
+        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
+        payload_msg = payload[8:]
+    if payload_msg is None:
+        return result
+    if message_compression == GZIP:
+        payload_msg = gzip.decompress(payload_msg)
+    if serialization_method == JSON:
+        payload_msg = json.loads(str(payload_msg, "utf-8"))
+    elif serialization_method != NO_SERIALIZATION:
+        payload_msg = str(payload_msg, "utf-8")
+    result['payload_msg'] = payload_msg
+    result['payload_size'] = payload_size
+    return result
+
+
+def read_wav_info(data: bytes = None) -> (int, int, int, int, int):
+    """Return (nchannels, sampwidth, framerate, nframes, frame byte length) of WAV ``data``."""
+    # Both the buffer and the wave reader are context-managed so neither is left open.
+    with BytesIO(data) as _f, wave.open(_f, 'rb') as wave_fp:
+        nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
+        wave_bytes = wave_fp.readframes(nframes)
+    return nchannels, sampwidth, framerate, nframes, len(wave_bytes)
+
+
+class AudioType(Enum):
+    LOCAL = 1  # use a local audio file
+
+
+class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseChatOpenAI, BaseSpeechToText):
+    """Speech-to-text client for the Volcanic Engine streaming ASR service."""
+
+    # NOTE(review): the methods below read ``self.appid``, ``self.cluster``,
+    # ``self.token``, ``self.uid`` and ``self.ws_url``, while ``new_instance``
+    # passes ``volcanic_app_id``, ``volcanic_cluster``, ``volcanic_token`` and
+    # ``volcanic_api_url``. Nothing in this file maps one set onto the other --
+    # confirm how the base classes expose these values.
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        """Build an instance from the stored credential and optional model kwargs."""
+        optional_params = {}
+        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
+            optional_params['max_tokens'] = model_kwargs['max_tokens']
+        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
+            optional_params['temperature'] = model_kwargs['temperature']
+        return VolcanicEngineSpeechToText(
+            openai_api_base=model_credential.get('volcanic_api_url'),
+            openai_api_key=model_credential.get('volcanic_token'),
+            volcanic_app_id=model_credential.get('volcanic_app_id'),
+            volcanic_token=model_credential.get('volcanic_token'),
+            volcanic_api_url=model_credential.get('volcanic_api_url'),
+            volcanic_cluster=model_credential.get('volcanic_cluster'),
+            volcanic_llm_domain=model_name,
+            streaming=model_kwargs.get('streaming', False),
+            workflow=model_kwargs.get("workflow", "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate"),
+            show_language=model_kwargs.get("show_language", False),
+            show_utterances=model_kwargs.get("show_utterances", False),
+            result_type=model_kwargs.get("result_type", "full"),
+            format=model_kwargs.get("format", "wav"),
+            rate=model_kwargs.get("sample_rate", 16000),
+            language=model_kwargs.get("language", "zh-CN"),
+            bits=model_kwargs.get("bits", 16),
+            channel=model_kwargs.get("channel", 1),
+            codec=model_kwargs.get("codec", "raw"),
+            audio_type=model_kwargs.get("audio_type", AudioType.LOCAL),
+            secret=model_kwargs.get("secret", "access_secret"),
+            auth_method=model_kwargs.get("auth_method", "token"),
+            mp3_seg_size=int(model_kwargs.get("mp3_seg_size", 10000)),
+            success_code=1000,  # success code, default is 1000
+            seg_duration=int(model_kwargs.get("seg_duration", 15000)),
+            nbest=int(model_kwargs.get("nbest", 1)),
+            **optional_params
+        )
+
+    def construct_request(self, reqid):
+        """Build the JSON body of the full client request for ``reqid``."""
+        req = {
+            'app': {
+                'appid': self.appid,
+                'cluster': self.cluster,
+                'token': self.token,
+            },
+            'user': {
+                'uid': self.uid
+            },
+            'request': {
+                'reqid': reqid,
+                'nbest': self.nbest,
+                'workflow': self.workflow,
+                'show_language': self.show_language,
+                'show_utterances': self.show_utterances,
+                'result_type': self.result_type,
+                "sequence": 1
+            },
+            'audio': {
+                'format': self.format,
+                'rate': self.rate,
+                'language': self.language,
+                'bits': self.bits,
+                'channel': self.channel,
+                'codec': self.codec
+            }
+        }
+        return req
+
+    @staticmethod
+    def slice_data(data: bytes, chunk_size: int) -> (list, bool):
+        """
+        slice data
+        :param data: wav data
+        :param chunk_size: the segment size in one request
+        :return: segment data, last flag
+        """
+        data_len = len(data)
+        offset = 0
+        while offset + chunk_size < data_len:
+            yield data[offset: offset + chunk_size], False
+            offset += chunk_size
+        else:
+            yield data[offset: data_len], True
+
+    def _real_processor(self, request_params: dict) -> dict:
+        pass
+
+    def token_auth(self):
+        """Return the ``Authorization`` header used when ``auth_method`` is ``token``."""
+        return {'Authorization': 'Bearer; {}'.format(self.token)}
+
+    def signature_auth(self, data):
+        """Return headers with an HMAC-SHA256 ``Authorization`` value computed over ``data``."""
+        header_dicts = {
+            'Custom': 'auth_custom',
+        }
+
+        url_parse = urlparse(self.ws_url)
+        input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
+        auth_headers = 'Custom'
+        for header in auth_headers.split(','):
+            input_str += '{}\n'.format(header_dicts[header])
+        input_data = bytearray(input_str, 'utf-8')
+        input_data += data
+        mac = base64.urlsafe_b64encode(
+            hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest())
+        header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.token,
+                                                                                              str(mac, 'utf-8'),
+                                                                                              auth_headers)
+        return header_dicts
+
+    async def segment_data_processor(self, wav_data: bytes, segment_size: int):
+        """Send the request and every audio segment; return the last parsed response.
+
+        Returns early with the offending response when its ``payload_msg`` code
+        differs from ``self.success_code``.
+        """
+        reqid = str(uuid.uuid4())
+        # build the full client request, then serialize and compress it
+        request_params = self.construct_request(reqid)
+        payload_bytes = str.encode(json.dumps(request_params))
+        payload_bytes = gzip.compress(payload_bytes)
+        full_client_request = bytearray(generate_full_default_header())
+        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
+        full_client_request.extend(payload_bytes)  # payload
+        header = None
+        if self.auth_method == "token":
+            header = self.token_auth()
+        elif self.auth_method == "signature":
+            header = self.signature_auth(full_client_request)
+        async with websockets.connect(self.ws_url, extra_headers=header, max_size=1000000000) as ws:
+            # send the full client request
+            await ws.send(full_client_request)
+            res = await ws.recv()
+            result = parse_response(res)
+            if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
+                return result
+            for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1):
+                # if no compression, comment this line
+                payload_bytes = gzip.compress(chunk)
+                audio_only_request = bytearray(generate_audio_default_header())
+                if last:
+                    audio_only_request = bytearray(generate_last_audio_default_header())
+                audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
+                audio_only_request.extend(payload_bytes)  # payload
+                # send the audio-only client request
+                await ws.send(audio_only_request)
+                res = await ws.recv()
+                result = parse_response(res)
+                if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
+                    return result
+        return result
+
+    def speech_to_text(self, file):
+        """Read ``file`` and run recognition; supports the ``wav`` and ``mp3`` formats.
+
+        :raises Exception: when ``self.format`` is neither ``wav`` nor ``mp3``
+        """
+        data = file.read()
+        audio_data = bytes(data)
+        if self.format == "mp3":
+            segment_size = self.mp3_seg_size
+            return asyncio.run(self.segment_data_processor(audio_data, segment_size))
+        if self.format != "wav":
+            raise Exception("format should in wav or mp3")
+        nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info(
+            audio_data)
+        size_per_sec = nchannels * sampwidth * framerate
+        segment_size = int(size_per_sec * self.seg_duration / 1000)
+        return asyncio.run(self.segment_data_processor(audio_data, segment_size))
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index 48802f6b83e..8b4be196322 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -15,17 +15,24 @@ from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel +from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential +from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText from smartdoc.conf import PROJECT_DIR volcanic_engine_llm_model_credential = OpenAILLMModelCredential() - +volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential() model_info_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用', ModelTypeConst.LLM, volcanic_engine_llm_model_credential, VolcanicEngineChatModel - ) + ), + ModelInfo('asr', + '', + ModelTypeConst.STT, + volcanic_engine_stt_model_credential, VolcanicEngineSpeechToText + ), ] open_ai_embedding_credential = OpenAIEmbeddingCredential() diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py
b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py
new file mode 100644
index 00000000000..7e7d2003490
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py
@@ -0,0 +1,60 @@
+# coding=utf-8
+
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
+    """Credential form and validator for the iFlytek (XunFei) speech-to-text models."""
+
+    spark_api_url = forms.TextInputField('API 域名', required=True)
+    spark_app_id = forms.TextInputField('APP ID', required=True)
+    spark_api_key = forms.PasswordInputField("API Key", required=True)
+    spark_api_secret = forms.PasswordInputField('API Secret', required=True)
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        """Check that the credential is complete and that a model can be built from it.
+
+        :param model_type: model type code, matched against ``provider.get_model_type_list()``
+        :param model_name: model name handed to ``provider.get_model``
+        :param model_credential: credential values keyed by the field names of this form
+        :param provider: provider used to list supported types and to build the model
+        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
+        :return: ``True`` when every check passes, otherwise ``False``
+        :raises AppApiException: unsupported model type (always), or any other
+            failure when ``raise_exception`` is true
+        """
+        model_type_list = provider.get_model_type_list()
+        if not any(mt.get('value') == model_type for mt in model_type_list):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            # Building the signed URL exercises the credential fields. The result is
+            # deliberately not printed: it embeds the API key and request signature.
+            model.create_url()
+        except AppApiException:
+            raise
+        except Exception as e:
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        """Return a copy of ``model`` with ``spark_api_secret`` passed through ``encryption``."""
+        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
+
+    def get_model_params_setting_form(self, model_name):
+        """No model parameter form is defined for these models; returns ``None``."""
+        pass
diff --git
a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py
new file mode 100644
index 00000000000..72c4fcd86fb
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py
@@ -0,0 +1,171 @@
+# -*- coding:utf-8 -*-
+#
+# Error code reference: https://www.xfyun.cn/document/error-code (check it whenever a response carries an error code)
+#
+import asyncio
+import base64
+import datetime
+import hashlib
+import hmac
+import json
+from datetime import datetime
+from typing import Dict
+from urllib.parse import urlencode, urlparse
+
+import websockets
+from langchain_community.chat_models.sparkllm import ChatSparkLLM
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_stt import BaseSpeechToText
+
+STATUS_FIRST_FRAME = 0  # marker of the first frame
+STATUS_CONTINUE_FRAME = 1  # marker of an intermediate frame
+STATUS_LAST_FRAME = 2  # marker of the last frame
+
+
+class XFChatSparkSpeechToText(MaxKBBaseModel, ChatSparkLLM, BaseSpeechToText):
+    """Speech-to-text client for the iFlytek (XunFei) IAT WebSocket service."""
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        """Build an instance from the stored credential and optional model kwargs."""
+        optional_params = {}
+        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
+            optional_params['max_tokens'] = model_kwargs['max_tokens']
+        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
+            optional_params['temperature'] = model_kwargs['temperature']
+        return XFChatSparkSpeechToText(
+            spark_app_id=model_credential.get('spark_app_id'),
+            spark_api_key=model_credential.get('spark_api_key'),
+            spark_api_secret=model_credential.get('spark_api_secret'),
+            spark_api_url=model_credential.get('spark_api_url'),
+            spark_llm_domain=model_name,
+            streaming=model_kwargs.get('streaming', False),
+            **optional_params
+        )
+
+    def create_url(self):
+        """Return ``spark_api_url`` with the signed authentication query string appended."""
+        url = self.spark_api_url
+        host = urlparse(url).hostname
+        # RFC 1123 formatted timestamp
+        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
+        date = datetime.utcnow().strftime(gmt_format)
+
+        # string to sign
+        signature_origin = "host: " + host + "\n"
+        signature_origin += "date: " + date + "\n"
+        signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
+        # sign it with HMAC-SHA256, keyed by the API secret
+        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
+                                 digestmod=hashlib.sha256).digest()
+        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
+
+        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
+            self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
+        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+        # authentication parameters, sent as the query string
+        v = {
+            "authorization": authorization,
+            "date": date,
+            "host": host
+        }
+        # NOTE: the result embeds the API key and signature; do not log it.
+        return url + '?' + urlencode(v)
+
+    def speech_to_text(self, file):
+        """Stream ``file`` to the recognition service and return the recognized text.
+
+        :param file: binary file object, read in fixed-size frames by ``send``
+        :return: the transcription assembled by ``handle_message``
+        :raises Exception: when the service answers with a non-zero ``code``
+        """
+
+        async def transcribe():
+            # ``websockets.connect`` must be entered with ``async with``, and the
+            # connection has to stay on one event loop, so sending and receiving
+            # share a single coroutine driven by a single ``asyncio.run``.
+            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
+                await self.send(ws, file)
+                return await self.handle_message(ws)
+
+        return asyncio.run(transcribe())
+
+    @staticmethod
+    async def handle_message(ws):
+        """Read recognition results from ``ws`` and return the assembled text.
+
+        :raises Exception: when a message carries a non-zero ``code``
+        """
+        fragments = []
+        while True:
+            try:
+                res = await ws.recv()
+            except websockets.ConnectionClosed:
+                # the server hung up; return whatever has been recognized so far
+                break
+            message = json.loads(res)
+            code = message["code"]
+            sid = message["sid"]
+            if code != 0:
+                errMsg = message["message"]
+                raise Exception("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
+            data = message.get("data") or {}
+            for i in (data.get("result") or {}).get("ws", []):
+                for w in i["cw"]:
+                    fragments.append(w["w"])
+            # NOTE(review): the IAT documentation marks the final result with
+            # data.status == 2 -- confirm against the service reference.
+            if data.get("status") == 2:
+                break
+        return "".join(fragments)
+
+    async def send(self, ws, file):
+        """Send ``file`` over ``ws`` as base64 JSON frames of ``frameSize`` bytes.
+
+        The first frame also carries the ``common`` and ``business`` sections; an
+        empty read marks the end of the audio and produces the closing frame.
+        """
+        frameSize = 8000  # audio bytes per frame
+        status = STATUS_FIRST_FRAME  # whether the frame is the first, an intermediate or the last one
+
+        while True:
+            buf = file.read(frameSize)
+            # end of file
+            if not buf:
+                status = STATUS_LAST_FRAME
+            # first frame:
+            # it carries the business parameters;
+            # the app id is mandatory and only sent with this frame
+            if status == STATUS_FIRST_FRAME:
+                d = {
+                    "common": {"app_id": self.spark_app_id},
+                    "business": {
+                        "domain": "iat",
+                        "language": "zh_cn",
+                        "accent": "mandarin",
+                        "vinfo": 1,
+                        "vad_eos": 10000
+                    },
+                    "data": {
+                        "status": 0, "format": "audio/L16;rate=16000",
+                        "audio": str(base64.b64encode(buf), 'utf-8'),
+                        "encoding": "lame"}
+                }
+                d = json.dumps(d)
+                await ws.send(d)
+                status = STATUS_CONTINUE_FRAME
+            # intermediate frame
+            elif status == STATUS_CONTINUE_FRAME:
+                d = {"data": {"status": 1, "format": "audio/L16;rate=16000",
+                              "audio": str(base64.b64encode(buf), 'utf-8'),
+                              "encoding": "lame"}}
+                await ws.send(json.dumps(d))
+            # last frame
+            elif status == STATUS_LAST_FRAME:
+                d = {"data": {"status": 2, "format": "audio/L16;rate=16000",
+                              "audio": str(base64.b64encode(buf), 'utf-8'),
+                              "encoding": "lame"}}
+                await ws.send(json.dumps(d))
+                break
+
diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py index d33b944e375..26b158286aa 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -13,15 +13,20 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential +from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM +from setting.models_provider.impl.xf_model_provider.model.stt import XFChatSparkSpeechToText from smartdoc.conf import PROJECT_DIR ssl._create_default_https_context = ssl.create_default_context() qwen_model_credential = XunFeiLLMModelCredential() +stt_model_credential = 
XunFeiSTTModelCredential() model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM) + ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFChatSparkSpeechToText), + ModelInfo('iat-niche', '小语种识别', ModelTypeConst.STT, stt_model_credential, XFChatSparkSpeechToText), ] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( diff --git a/pyproject.toml b/pyproject.toml index d37963caaf0..37d1f61dfdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dashscope = "^1.17.0" zhipuai = "^2.0.1" httpx = "^0.27.0" httpx-sse = "^0.4.0" -websocket-client = "^1.7.0" +websockets = "^13.0" langchain-google-genai = "^1.0.3" openpyxl = "^3.1.2" xlrd = "^2.0.1"