Commit b9e3d24

updated

JinZr committed Nov 5, 2024
1 parent 06c2993 commit b9e3d24
Showing 2 changed files with 44 additions and 71 deletions.
egs/ljspeech/TTS/matcha/infer.py (4 changes: 2 additions & 2 deletions)
@@ -266,10 +266,10 @@ def main():
         raise ValueError(f"{params.vocoder} does not exist")
 
     vocoder = load_vocoder(params.vocoder)
-    vocoder = vocoder.to(device)
+    vocoder.to(device)
 
     denoiser = Denoiser(vocoder, mode="zeros")
-    denoiser = denoiser.to(device)
+    denoiser.to(device)
 
     infer_dataset(
         dl=test_dl,
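A note on the infer.py change above: for torch.nn.Module, .to() moves the module's parameters and buffers in place and returns the module itself, so the reassignment was redundant. Tensors behave differently and do need reassignment. A minimal sketch of the distinction (the Linear layer is just a stand-in, not part of the recipe):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.Module.to() is in-place and returns self, so
# `vocoder = vocoder.to(device)` and `vocoder.to(device)` are equivalent.
module = nn.Linear(4, 4)
module.to(device)
assert next(module.parameters()).device.type == device.type

# torch.Tensor.to() is NOT in-place: it returns a new tensor,
# so tensors (unlike modules) require reassignment.
tensor = torch.zeros(4)
tensor = tensor.to(device)
assert tensor.device.type == device.type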
egs/ljspeech/TTS/matcha/synth.py (111 changes: 42 additions & 69 deletions)
@@ -2,21 +2,18 @@
 # Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
 
 import argparse
-import datetime as dt
 import json
 import logging
 from pathlib import Path
 
 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1, v2, v3
+from infer import load_vocoder, synthesise, to_waveform
 from matcha.hifigan.denoiser import Denoiser
-from matcha.hifigan.models import Generator as HiFiGAN
 from tokenizer import Tokenizer
 from train import get_model, get_params
 
 from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict, setup_logger
 
 
 def get_parser():
Expand All @@ -36,7 +33,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp-new-3",
default="matcha/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -77,58 +74,14 @@ def get_parser():
         help="The filename of the wave to save the generated speech",
     )
 
-    return parser
-
-
-def load_vocoder(checkpoint_path):
-    checkpoint_path = str(checkpoint_path)
-    if checkpoint_path.endswith("v1"):
-        h = AttributeDict(v1)
-    elif checkpoint_path.endswith("v2"):
-        h = AttributeDict(v2)
-    elif checkpoint_path.endswith("v3"):
-        h = AttributeDict(v3)
-    else:
-        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
-
-    hifigan = HiFiGAN(h).to("cpu")
-    hifigan.load_state_dict(
-        torch.load(checkpoint_path, map_location="cpu")["generator"]
-    )
-    _ = hifigan.eval()
-    hifigan.remove_weight_norm()
-    return hifigan
-
-
-def to_waveform(mel, vocoder, denoiser):
-    audio = vocoder(mel).clamp(-1, 1)
-    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
-    return audio.cpu().squeeze()
-
-
-def process_text(text: str, tokenizer):
-    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
-    x = torch.tensor(x, dtype=torch.long)
-    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
-    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
-
-
-def synthesise(
-    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
-):
-    text_processed = process_text(text, tokenizer)
-    start_t = dt.datetime.now()
-    output = model.synthesise(
-        text_processed["x"],
-        text_processed["x_lengths"],
-        n_timesteps=n_timesteps,
-        temperature=temperature,
-        spks=spks,
-        length_scale=length_scale,
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=22050,
+        help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
     )
-    # merge everything to one dict
-    output.update({"start_t": start_t, **text_processed})
-    return output
 
     return parser
 
 
 @torch.inference_mode()
@@ -139,6 +92,12 @@ def main():
 
     params.update(vars(args))
 
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
@@ -151,43 +110,57 @@
 
     params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
     params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    # Number of ODE Solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
     logging.info(params)
 
     logging.info("About to create model")
     model = get_model(params)
+
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
     model.eval()
 
     if not Path(params.vocoder).is_file():
         raise ValueError(f"{params.vocoder} does not exist")
 
     vocoder = load_vocoder(params.vocoder)
-    denoiser = Denoiser(vocoder, mode="zeros")
+    vocoder.to(device)
 
-    # Number of ODE Solver steps
-    n_timesteps = 2
-
-    # Changes to the speaking rate
-    length_scale = 1.0
-
-    # Sampling temperature
-    temperature = 0.667
+    denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser.to(device)
 
     output = synthesise(
         model=model,
         tokenizer=tokenizer,
-        n_timesteps=n_timesteps,
+        n_timesteps=params.n_timesteps,
         text=params.input_text,
-        length_scale=length_scale,
-        temperature=temperature,
+        length_scale=params.length_scale,
+        temperature=params.temperature,
+        device=device,
     )
     output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
 
-    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
+    sf.write(
+        file=params.output_wav,
+        data=output["waveform"],
+        samplerate=params.sampling_rate,
+        subtype="PCM_16",
+    )
 
 
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
+
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     main()

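The rewritten sf.write call uses keyword arguments and the new params.sampling_rate in place of the hard-coded 22050. One caveat: soundfile does not resample, so the samplerate argument only determines what is written into the file header and should match the rate the vocoder actually produces (22050 Hz for LJSpeech). A small stand-alone sketch of the new call shape, using a synthetic tone in place of real vocoder output:

import numpy as np
import soundfile as sf

sampling_rate = 22050  # matches the new --sampling-rate default

# a hypothetical 1-second 440 Hz tone standing in for output["waveform"]
t = np.linspace(0, 1, sampling_rate, endpoint=False)
waveform = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

# keyword-argument form matching the updated call in synth.py;
# soundfile writes `samplerate` into the header but does not resample
sf.write(
    file="generated.wav",
    data=waveform,
    samplerate=sampling_rate,
    subtype="PCM_16",
)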