From 59b5dd14076a87a795d39abe1e325ced58295b7b Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 16 Jan 2025 08:31:55 +0100 Subject: [PATCH 1/2] fix(fairseq): handle change of model file name --- TTS/tts/models/vits.py | 20 ++++++++++++-------- tests/zoo_tests/test_models.py | 1 + 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 3d66b50598..135b8e5016 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -3,7 +3,8 @@ import os from dataclasses import dataclass, field, replace from itertools import chain -from typing import Dict, List, Tuple, Union +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union import numpy as np import torch @@ -1581,13 +1582,16 @@ def load_fairseq_checkpoint( self.disc = None # set paths - config_file = os.path.join(checkpoint_dir, "config.json") - checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth") - vocab_file = os.path.join(checkpoint_dir, "vocab.txt") + checkpoint_dir = Path(checkpoint_dir) + config_file = checkpoint_dir / "config.json" + checkpoint_file = checkpoint_dir / "model.pth" + if not checkpoint_file.is_file(): + checkpoint_file = checkpoint_dir / "G_100000.pth" + vocab_file = checkpoint_dir / "vocab.txt" # set config params - with open(config_file, "r", encoding="utf-8") as file: + with open(config_file, "r", encoding="utf-8") as f: # Load the JSON data as a dictionary - config_org = json.load(file) + config_org = json.load(f) self.config.audio.sample_rate = config_org["data"]["sampling_rate"] # self.config.add_blank = config['add_blank'] # set tokenizer @@ -1821,7 +1825,7 @@ def to_config(self) -> "CharactersConfig": class FairseqVocab(BaseVocabulary): - def __init__(self, vocab: str): + def __init__(self, vocab: Union[str, os.PathLike[Any]]): super(FairseqVocab).__init__() self.vocab = vocab @@ -1831,7 +1835,7 @@ def vocab(self): return self._vocab @vocab.setter - def vocab(self, vocab_file): + def vocab(self, vocab_file: Union[str, os.PathLike[Any]]): with open(vocab_file, encoding="utf-8") as f: self._vocab = [x.replace("\n", "") for x in f.readlines()] self.blank = self._vocab[0] diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index b7c88e0730..9f02672ef1 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -37,6 +37,7 @@ def manager(tmp_path): num_partitions = int(os.getenv("NUM_PARTITIONS", "1")) partition = int(os.getenv("TEST_PARTITION", "0")) model_names = [name for name in TTS.list_models() if name not in MODELS_WITH_SEP_TESTS] +model_names.extend(["tts_models/deu/fairseq/vits", "tts_models/sqi/fairseq/vits"]) model_names = [name for i, name in enumerate(model_names) if i % num_partitions == partition] From 8224adf6e5febe1d62d41e91486a6e343301d343 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 16 Jan 2025 08:32:44 +0100 Subject: [PATCH 2/2] chore: bump version to 0.25.3 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 44c5fb7127..4b87a10b20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ build-backend = "hatchling.build" [project] name = "coqui-tts" -version = "0.25.2" +version = "0.25.3" description = "Deep learning for Text to Speech." readme = "README.md" requires-python = ">=3.9, <3.13"