diff --git a/egs/ljspeech/TTS/matcha/infer.py b/egs/ljspeech/TTS/matcha/infer.py
index bcf9a8b4a0..bb68a42f15 100755
--- a/egs/ljspeech/TTS/matcha/infer.py
+++ b/egs/ljspeech/TTS/matcha/infer.py
@@ -266,10 +266,10 @@ def main():
         raise ValueError(f"{params.vocoder} does not exist")
 
     vocoder = load_vocoder(params.vocoder)
-    vocoder = vocoder.to(device)
+    vocoder.to(device)
 
     denoiser = Denoiser(vocoder, mode="zeros")
-    denoiser = denoiser.to(device)
+    denoiser.to(device)
 
     infer_dataset(
         dl=test_dl,
diff --git a/egs/ljspeech/TTS/matcha/synth.py b/egs/ljspeech/TTS/matcha/synth.py
index a4880fd3a2..f411ce4fae 100755
--- a/egs/ljspeech/TTS/matcha/synth.py
+++ b/egs/ljspeech/TTS/matcha/synth.py
@@ -2,21 +2,18 @@
 # Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
 
 import argparse
-import datetime as dt
 import json
 import logging
 from pathlib import Path
 
 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1, v2, v3
+from infer import load_vocoder, synthesise, to_waveform
 from matcha.hifigan.denoiser import Denoiser
-from matcha.hifigan.models import Generator as HiFiGAN
 from tokenizer import Tokenizer
 from train import get_model, get_params
 
 from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict, setup_logger
 
 
 def get_parser():
@@ -36,7 +33,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp-new-3",
+        default="matcha/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -77,58 +74,14 @@ def get_parser():
         help="The filename of the wave to save the generated speech",
     )
 
-    return parser
-
-
-def load_vocoder(checkpoint_path):
-    checkpoint_path = str(checkpoint_path)
-    if checkpoint_path.endswith("v1"):
-        h = AttributeDict(v1)
-    elif checkpoint_path.endswith("v2"):
-        h = AttributeDict(v2)
-    elif checkpoint_path.endswith("v3"):
-        h = AttributeDict(v3)
-    else:
-        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
-
-    hifigan = HiFiGAN(h).to("cpu")
-    hifigan.load_state_dict(
-        torch.load(checkpoint_path, map_location="cpu")["generator"]
-    )
-    _ = hifigan.eval()
-    hifigan.remove_weight_norm()
-    return hifigan
-
-
-def to_waveform(mel, vocoder, denoiser):
-    audio = vocoder(mel).clamp(-1, 1)
-    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
-    return audio.cpu().squeeze()
-
-
-def process_text(text: str, tokenizer):
-    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
-    x = torch.tensor(x, dtype=torch.long)
-    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
-    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
-
-
-def synthesise(
-    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
-):
-    text_processed = process_text(text, tokenizer)
-    start_t = dt.datetime.now()
-    output = model.synthesise(
-        text_processed["x"],
-        text_processed["x_lengths"],
-        n_timesteps=n_timesteps,
-        temperature=temperature,
-        spks=spks,
-        length_scale=length_scale,
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=22050,
+        help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
     )
-    # merge everything to one dict
-    output.update({"start_t": start_t, **text_processed})
-    return output
+
+    return parser
 
 
 @torch.inference_mode()
@@ -139,6 +92,12 @@ def main():
 
     params.update(vars(args))
 
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
@@ -151,43 +110,57 @@ def main():
     params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
     params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    # Number of ODE Solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
 
     logging.info(params)
 
     logging.info("About to create model")
     model = get_model(params)
 
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
     model.eval()
 
     if not Path(params.vocoder).is_file():
         raise ValueError(f"{params.vocoder} does not exist")
 
     vocoder = load_vocoder(params.vocoder)
-    denoiser = Denoiser(vocoder, mode="zeros")
+    vocoder.to(device)
 
-    # Number of ODE Solver steps
-    n_timesteps = 2
-
-    # Changes to the speaking rate
-    length_scale = 1.0
-
-    # Sampling temperature
-    temperature = 0.667
+    denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser.to(device)
 
     output = synthesise(
         model=model,
         tokenizer=tokenizer,
-        n_timesteps=n_timesteps,
+        n_timesteps=params.n_timesteps,
         text=params.input_text,
-        length_scale=length_scale,
-        temperature=temperature,
+        length_scale=params.length_scale,
+        temperature=params.temperature,
+        device=device,
     )
     output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
 
-    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
+    sf.write(
+        file=params.output_wav,
+        data=output["waveform"],
+        samplerate=params.sampling_rate,
+        subtype="PCM_16",
+    )
 
 
 if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     main()