Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

add multi speaker training #12

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions VITS_with_XPhoneBERT/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
3) computes spectrograms from audio files.
"""

def __init__(self, audiopaths_sid_text, hparams):
def __init__(self, audiopaths_sid_text, hparams, tokenizer):
self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
self.text_cleaners = hparams.text_cleaners
self.max_wav_value = hparams.max_wav_value
Expand All @@ -178,7 +178,9 @@ def __init__(self, audiopaths_sid_text, hparams):

self.add_blank = hparams.add_blank
self.min_text_len = getattr(hparams, "min_text_len", 1)
self.max_text_len = getattr(hparams, "max_text_len", 190)
self.max_text_len = getattr(hparams, "max_text_len", 500)

self.tokenizer = tokenizer

random.seed(1234)
random.shuffle(self.audiopaths_sid_text)
Expand All @@ -204,10 +206,10 @@ def _filter(self):
def get_audio_text_speaker_pair(self, audiopath_sid_text):
# separate filename, speaker_id and text
audiopath, sid, text = audiopath_sid_text[0], audiopath_sid_text[1], audiopath_sid_text[2]
text = self.get_text(text)
input_ids, attention_mask = self.get_text(text)
spec, wav = self.get_audio(audiopath)
sid = self.get_sid(sid)
return (text, spec, wav, sid)
return (input_ids, attention_mask, spec, wav, sid)

def get_audio(self, filename):
audio, sampling_rate = load_wav_to_torch(filename)
Expand All @@ -228,14 +230,13 @@ def get_audio(self, filename):
return spec, audio_norm

def get_text(self, text):
if self.cleaned_text:
text_norm = cleaned_text_to_sequence(text)
else:
text_norm = text_to_sequence(text, self.text_cleaners)
if self.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = torch.LongTensor(text_norm)
return text_norm
tokenized_text = self.tokenizer(text)
input_ids = tokenized_text['input_ids']
attention_mask = tokenized_text['attention_mask']

input_ids = torch.LongTensor(input_ids)
attention_mask = torch.LongTensor(attention_mask)
return input_ids, attention_mask

def get_sid(self, sid):
sid = torch.LongTensor([int(sid)])
Expand All @@ -252,55 +253,61 @@ class TextAudioSpeakerCollate():
""" Zero-pads model inputs and targets
"""

def __init__(self, return_ids=False):
def __init__(self, return_ids=False, pad_token_id=1):
self.return_ids = return_ids
self.pad_token_id = pad_token_id

def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities
PARAMS
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
batch: [input_ids, attention_mask, spec_normalized, wav_normalized, sid]
"""
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
torch.LongTensor([x[2].size(1) for x in batch]),
dim=0, descending=True)

max_text_len = max([len(x[0]) for x in batch])
max_spec_len = max([x[1].size(1) for x in batch])
max_wav_len = max([x[2].size(1) for x in batch])
max_spec_len = max([x[2].size(1) for x in batch])
max_wav_len = max([x[3].size(1) for x in batch])

text_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
wav_lengths = torch.LongTensor(len(batch))
sid = torch.LongTensor(len(batch))

text_padded = torch.LongTensor(len(batch), max_text_len)
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
input_padded = torch.LongTensor(len(batch), max_text_len)
attention_padded = torch.LongTensor(len(batch), max_text_len)
spec_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
text_padded.zero_()
input_padded.fill_(self.pad_token_id)
attention_padded.zero_()
spec_padded.zero_()
wav_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]

text = row[0]
text_padded[i, :text.size(0)] = text
input_padded[i, :text.size(0)] = text
text_lengths[i] = text.size(0)

spec = row[1]
attention = row[1]
attention_padded[i, :attention.size(0)] = attention

spec = row[2]
spec_padded[i, :, :spec.size(1)] = spec
spec_lengths[i] = spec.size(1)

wav = row[2]
wav = row[3]
wav_padded[i, :, :wav.size(1)] = wav
wav_lengths[i] = wav.size(1)

sid[i] = row[3]
sid[i] = row[4]

if self.return_ids:
return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing
return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid
return input_padded, attention_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing
return input_padded, attention_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid


class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
Expand Down
Loading