Commit 430712c (parent: 1e28455), showing 21 changed files with 3,607 additions and 0 deletions.
@@ -0,0 +1,6 @@
__pycache__

exp/
data/

WavLM-Base+.pt
@@ -0,0 +1,309 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import numpy
import random
import os
import glob
from scipy import signal
import soundfile
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist


def round_down(num, divisor):
    return num - (num % divisor)


def worker_init_fn(worker_id):
    numpy.random.seed(numpy.random.get_state()[1][0] + worker_id)
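
# Note: forked DataLoader worker processes inherit the same numpy RNG state,
# so without this per-worker reseeding every worker would draw identical
# random crops and noise samples.
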

def loadWAV(filename, max_frames, evalmode=True, num_eval=5):

    # Maximum audio length in samples (160-sample hop per frame, plus window padding)
    max_audio = max_frames * 160 + 240

    # Read the wav file as a numpy array
    audio, sample_rate = soundfile.read(filename)

    audiosize = audio.shape[0]

    # Pad short utterances by wrapping the signal
    if audiosize <= max_audio:
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]

    if evalmode:
        # Evenly spaced crops across the utterance
        startframe = numpy.linspace(0, audiosize - max_audio, num=num_eval)
    else:
        # A single random crop
        startframe = numpy.array([numpy.int64(random.random() * (audiosize - max_audio))])

    feats = []
    if evalmode and max_frames == 0:
        feats.append(audio)
    else:
        for asf in startframe:
            feats.append(audio[int(asf):int(asf) + max_audio])

    feat = numpy.stack(feats, axis=0).astype(float)

    return feat
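
# Example (sketch, assuming 16 kHz input as in the demo paths below): with
# max_frames=200, max_audio is 200*160 + 240 = 32240 samples, about 2.0 s.
# In eval mode loadWAV returns a (num_eval, max_audio) array of evenly
# spaced crops; in train mode it returns a single (1, max_audio) random crop.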


class AugmentWAV(object):

    def __init__(self, musan_path, rir_path, max_frames):

        self.max_frames = max_frames
        self.max_audio = max_frames * 160 + 240

        self.noisetypes = ['noise', 'speech', 'music']

        # SNR ranges (dB) and number of simultaneous noise files per category
        self.noisesnr = {'noise': [0, 15], 'speech': [13, 20], 'music': [5, 15]}
        self.numnoise = {'noise': [1, 1], 'speech': [3, 8], 'music': [1, 1]}
        self.noiselist = {}

        # MUSAN layout: <musan_path>/<category>/<subdir>/<file>.wav
        augment_files = glob.glob(os.path.join(musan_path, '*/*/*.wav'))

        for file in augment_files:
            noise_type = file.split('/')[-3]
            if noise_type not in self.noiselist:
                self.noiselist[noise_type] = []
            self.noiselist[noise_type].append(file)

        self.rir_files = glob.glob(os.path.join(rir_path, '*/*/*.wav'))

    def additive_noise(self, noisecat, audio):

        # Energy of the clean signal in dB (epsilon avoids log of zero)
        clean_db = 10 * numpy.log10(numpy.mean(audio ** 2) + 1e-4)

        numnoise = self.numnoise[noisecat]
        noiselist = random.sample(self.noiselist[noisecat], random.randint(numnoise[0], numnoise[1]))

        noises = []

        for noise in noiselist:
            noiseaudio = loadWAV(noise, self.max_frames, evalmode=False)
            noise_snr = random.uniform(self.noisesnr[noisecat][0], self.noisesnr[noisecat][1])
            noise_db = 10 * numpy.log10(numpy.mean(noiseaudio[0] ** 2) + 1e-4)
            # Scale the noise so it sits noise_snr dB below the clean signal
            noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio)

        return numpy.sum(numpy.concatenate(noises, axis=0), axis=0, keepdims=True) + audio
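
    # Worked example (sketch): if clean_db = -20, noise_db = -25 and
    # noise_snr = 10, the gain is sqrt(10 ** ((-20 + 25 - 10) / 10))
    # = sqrt(10 ** -0.5) ≈ 0.562, i.e. the noise is attenuated until its
    # energy is 10 dB below the clean signal's.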

    def reverberate(self, audio):

        rir_file = random.choice(self.rir_files)

        rir, fs = soundfile.read(rir_file)
        rir = numpy.expand_dims(rir.astype(float), 0)
        # Normalise the impulse response to unit energy
        rir = rir / numpy.sqrt(numpy.sum(rir ** 2))

        return signal.convolve(audio, rir, mode='full')[:, :self.max_audio]
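
    # Full-mode convolution lengthens the signal by len(rir) - 1 samples;
    # truncating back to self.max_audio keeps every training example the
    # same length so batches can be stacked.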


class train_dataset_loader(Dataset):
    def __init__(self, train_list, augment, musan_path, rir_path, max_frames, train_path, **kwargs):

        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)

        self.train_list = train_list
        self.max_frames = max_frames
        self.musan_path = musan_path
        self.rir_path = rir_path
        self.augment = augment

        # Read training files
        with open(train_list) as dataset_file:
            lines = dataset_file.readlines()

        # Make a dictionary of ID names and ID indices
        dictkeys = list(set([x.split()[0] for x in lines]))
        dictkeys.sort()
        dictkeys = {key: ii for ii, key in enumerate(dictkeys)}

        # Parse the training list into file names and ID indices
        self.data_list = []
        self.data_label = []

        for lidx, line in enumerate(lines):
            data = line.strip().split()

            speaker_label = dictkeys[data[0]]
            filename = os.path.join(train_path, data[1])

            self.data_label.append(speaker_label)
            self.data_list.append(filename)

    def __getitem__(self, indices):

        feat_clean = []
        feat = []

        for index in indices:
            try:
                audio_clean = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
            except Exception:
                # Surface the offending file instead of silently continuing
                # with a stale audio_clean from a previous iteration
                print(self.data_list[index])
                raise

            # Debug check: flag unexpectedly shaped audio (e.g. multi-channel files)
            if len(audio_clean.shape) == 3:
                print(self.data_list[index])

            # Default to the clean crop; overwritten below when augmenting
            audio = audio_clean

            if self.augment:
                augtype = random.randint(0, 5)
                if augtype == 0:
                    audio = audio_clean
                elif augtype == 1:
                    audio = self.augment_wav.reverberate(audio_clean)
                elif augtype == 2:
                    audio = self.augment_wav.additive_noise('music', audio_clean)
                elif augtype == 3:
                    audio = self.augment_wav.additive_noise('speech', audio_clean)
                elif augtype == 4:
                    audio = self.augment_wav.additive_noise('noise', audio_clean)
                elif augtype == 5:
                    # Apply speech then music noise in sequence; the second
                    # call must take the already-noised audio, not the clean signal
                    audio = self.augment_wav.additive_noise('speech', audio_clean)
                    audio = self.augment_wav.additive_noise('music', audio)

            feat_clean.append(audio_clean)
            feat.append(audio)

        feat_clean = numpy.concatenate(feat_clean, axis=0)
        feat = numpy.concatenate(feat, axis=0)

        # All indices in one call come from the same speaker, so the label
        # of the last index is the label of the whole group
        return torch.FloatTensor(feat_clean), torch.FloatTensor(feat), self.data_label[index], self.data_list[index]

    def __len__(self):
        return len(self.data_list)
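
    # Note: unlike a standard Dataset, __getitem__ receives a list of
    # indices (one group emitted by train_dataset_sampler below), so each
    # item is already a stack of nPerSpeaker crops from a single speaker.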


class test_dataset_loader(Dataset):
    def __init__(self, test_list, test_path, eval_frames, num_eval, **kwargs):
        self.max_frames = eval_frames
        self.num_eval = num_eval
        self.test_path = test_path
        self.test_list = test_list

    def __getitem__(self, index):
        # num_eval fixed-length crops for the standard evaluation protocol
        audio = loadWAV(os.path.join(self.test_path, self.test_list[index]), self.max_frames, evalmode=True, num_eval=self.num_eval)

        # Full-length utterance (max_frames=0 loads the whole file)
        audio2 = loadWAV(os.path.join(self.test_path, self.test_list[index]), 0, evalmode=True, num_eval=self.num_eval)

        return torch.FloatTensor(audio), torch.FloatTensor(audio2), self.test_list[index]

    def __len__(self):
        return len(self.test_list)
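
    # Downstream scoring can combine the two views, e.g. averaging
    # similarity over the fixed crops and the full utterance (a common
    # evaluation pattern; the scoring itself is not part of this file).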


class train_dataset_sampler(torch.utils.data.Sampler):
    def __init__(self, data_source, nPerSpeaker, max_seg_per_spk, batch_size, distributed, seed, **kwargs):

        self.data_label = data_source.data_label
        self.nPerSpeaker = nPerSpeaker
        self.max_seg_per_spk = max_seg_per_spk
        self.batch_size = batch_size
        self.epoch = 0
        self.seed = seed
        self.distributed = distributed

    def __iter__(self):

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.data_label), generator=g).tolist()

        data_dict = {}

        # Sort into dictionary of file indices for each ID
        for index in indices:
            speaker_label = self.data_label[index]
            if speaker_label not in data_dict:
                data_dict[speaker_label] = []
            data_dict[speaker_label].append(index)

        ## Group file indices for each class
        dictkeys = list(data_dict.keys())
        dictkeys.sort()

        # Split a list into chunks of size sz
        lol = lambda lst, sz: [lst[i:i + sz] for i in range(0, len(lst), sz)]

        flattened_list = []
        flattened_label = []

        for findex, key in enumerate(dictkeys):
            data = data_dict[key]
            numSeg = round_down(min(len(data), self.max_seg_per_spk), self.nPerSpeaker)

            rp = lol(numpy.arange(numSeg), self.nPerSpeaker)
            flattened_label.extend([findex] * len(rp))
            for indices in rp:
                flattened_list.append([data[i] for i in indices])

        ## Mix data in random order
        mixid = torch.randperm(len(flattened_label), generator=g).tolist()
        mixlabel = []
        mixmap = []

        ## Prevent two pairs of the same speaker in the same batch
        for ii in mixid:
            startbatch = round_down(len(mixlabel), self.batch_size)
            if flattened_label[ii] not in mixlabel[startbatch:]:
                mixlabel.append(flattened_label[ii])
                mixmap.append(ii)

        mixed_list = [flattened_list[i] for i in mixmap]

        ## Divide data to each GPU
        if self.distributed:
            total_size = round_down(len(mixed_list), self.batch_size * dist.get_world_size())
            start_index = int(dist.get_rank() / dist.get_world_size() * total_size)
            end_index = int((dist.get_rank() + 1) / dist.get_world_size() * total_size)
            self.num_samples = end_index - start_index
            return iter(mixed_list[start_index:end_index])
        else:
            total_size = round_down(len(mixed_list), self.batch_size)
            self.num_samples = total_size
            return iter(mixed_list[:total_size])

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch
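
    # As with torch's DistributedSampler, set_epoch must be called before
    # each epoch; otherwise every epoch reuses the seed + 0 permutation in
    # __iter__ and yields identical batches.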


if __name__ == '__main__':
    train_dataset = train_dataset_loader(train_list='/mnt/proj3/open-24-5/pengjy_new/WavLM_Adapter/CNCeleb_lst/CNCeleb_trainlist_200spk.txt',
                                         augment=False,
                                         musan_path='/mnt/proj3/open-24-5/pengjy_new/musan_split/',
                                         rir_path='/mnt/proj3/open-24-5/plchot/data_augment/16kHz/simulated_rirs/',
                                         max_frames=300,
                                         train_path='/mnt/proj3/open-24-5/pengjy_new/Data/CN-Celeb_flac/data',
                                         )

    train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=1, max_seg_per_spk=500, batch_size=100, distributed=False, seed=120)
    # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=100,
        num_workers=10,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
    )

    # The dataset returns four values per item, so unpack all four
    for data_clean, data, data_label, data_path in train_loader:
        print(data.shape)
        data = data.transpose(1, 0)
        print(data.shape)
        quit()
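
    # A full training run (sketch, beyond this smoke test) would also pass
    # worker_init_fn=worker_init_fn to the DataLoader and advance the
    # sampler once per epoch:
    #
    #     for epoch in range(num_epochs):
    #         train_sampler.set_epoch(epoch)
    #         for data_clean, data, data_label, data_path in train_loader:
    #             ...  # forward/backward pass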