Skip to content

Commit

Permalink
refactor: add voice mapping for cosyvoice and update sample rate
Browse files Browse the repository at this point in the history
  • Loading branch information
aiwantaozi committed Nov 27, 2024
1 parent ac748d3 commit 50149fb
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions vox_box/backends/tts/cosyvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ def __init__(
cfg: Config,
):
self.model_load = False
self.language_map = {
"中文女": "Chinese Female",
"中文男": "Chinese Male",
"日语男": "Japanese Male",
"粤语女": "Cantonese Female",
"英文女": "English Female",
"英文男": "English Male",
"韩语女": "Korean Female",
}
self.reverse_language_map = {v: k for k, v in self.language_map.items()}
self._cfg = cfg
self._voices = None
self._model = None
Expand Down Expand Up @@ -70,13 +80,16 @@ def speech(
if voice not in self._voices:
raise ValueError(f"Voice {voice} not supported")

model_output = self._model.inference_sft(input, voice, False, speed)
original_voice = self._get_original_voice(voice)
model_output = self._model.inference_sft(
input, original_voice, stream=False, speed=speed
)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_file:
wav_file_path = temp_file.name
with wave.open(wav_file_path, "wb") as wf:
wf.setnchannels(1) # single track
wf.setsampwidth(2) # 16-bit
wf.setframerate(16000) # Sample rate
wf.setframerate(22050) # Sample rate
for i in model_output:
tts_audio = (
(i["tts_speech"].numpy() * (2**15)).astype(np.int16).tobytes()
Expand All @@ -86,10 +99,9 @@ def speech(
output_file_path = convert(wav_file_path, reponse_format, speed)
return output_file_path

def _get_required_resource(self) -> Dict:
# TODO: not accurate
Gib = 1024 * 1024 * 1024
return {"cuda": {"vram": 16 * Gib}, "cpu": {"ram": 16 * Gib}}

def _get_voices(self) -> List[str]:
return self._model.list_avaliable_spks()
voices = self._model.list_avaliable_spks()
return [self.language_map.get(voice, voice) for voice in voices]

def _get_original_voice(self, voice: str) -> str:
return self.reverse_language_map.get(voice, voice)

0 comments on commit 50149fb

Please sign in to comment.