From 50149fbba780b9580dddedead669c1e3a6d2211f Mon Sep 17 00:00:00 2001 From: michelia Date: Wed, 27 Nov 2024 14:33:45 +0800 Subject: [PATCH] refactor: add voice mapping for cosyvoice and update sample rate --- vox_box/backends/tts/cosyvoice.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/vox_box/backends/tts/cosyvoice.py b/vox_box/backends/tts/cosyvoice.py index 86146b7..ae1173d 100644 --- a/vox_box/backends/tts/cosyvoice.py +++ b/vox_box/backends/tts/cosyvoice.py @@ -25,6 +25,16 @@ def __init__( cfg: Config, ): self.model_load = False + self.language_map = { + "中文女": "Chinese Female", + "中文男": "Chinese Male", + "日语男": "Japanese Male", + "粤语女": "Cantonese Female", + "英文女": "English Female", + "英文男": "English Male", + "韩语女": "Korean Female", + } + self.reverse_language_map = {v: k for k, v in self.language_map.items()} self._cfg = cfg self._voices = None self._model = None @@ -70,13 +80,16 @@ def speech( if voice not in self._voices: raise ValueError(f"Voice {voice} not supported") - model_output = self._model.inference_sft(input, voice, False, speed) + original_voice = self._get_original_voice(voice) + model_output = self._model.inference_sft( + input, original_voice, stream=False, speed=speed + ) with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_file: wav_file_path = temp_file.name with wave.open(wav_file_path, "wb") as wf: wf.setnchannels(1) # single track wf.setsampwidth(2) # 16-bit - wf.setframerate(16000) # Sample rate + wf.setframerate(22050) # Sample rate for i in model_output: tts_audio = ( (i["tts_speech"].numpy() * (2**15)).astype(np.int16).tobytes() @@ -86,10 +99,9 @@ def speech( output_file_path = convert(wav_file_path, reponse_format, speed) return output_file_path - def _get_required_resource(self) -> Dict: - # TODO: not accurate - Gib = 1024 * 1024 * 1024 - return {"cuda": {"vram": 16 * Gib}, "cpu": {"ram": 16 * Gib}} - def _get_voices(self) -> List[str]: - return self._model.list_avaliable_spks() + voices = self._model.list_avaliable_spks() + return [self.language_map.get(voice, voice) for voice in voices] + + def _get_original_voice(self, voice: str) -> str: + return self.reverse_language_map.get(voice, voice)