initial commit

theolepage committed Sep 19, 2024
1 parent 1e28455 commit 430712c

Showing 21 changed files with 3,607 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
__pycache__

exp/
data/

WavLM-Base+.pt
309 changes: 309 additions & 0 deletions DatasetLoader.py
@@ -0,0 +1,309 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-

import glob
import os
import random

import numpy
import soundfile
import torch
import torch.distributed as dist
from scipy import signal
from torch.utils.data import Dataset

def round_down(num, divisor):
    return num - (num % divisor)


def worker_init_fn(worker_id):
    # Derive a distinct numpy seed for each DataLoader worker
    numpy.random.seed(numpy.random.get_state()[1][0] + worker_id)


def loadWAV(filename, max_frames, evalmode=True, num_eval=5):

    # Maximum audio length in samples (160 samples per frame at 16 kHz,
    # plus 240 samples of context)
    max_audio = max_frames * 160 + 240

    # Read wav file
    audio, sample_rate = soundfile.read(filename)

    audiosize = audio.shape[0]

    # Pad short recordings by wrapping the signal
    if audiosize <= max_audio:
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]

    if evalmode:
        # Evenly spaced crops for evaluation
        startframe = numpy.linspace(0, audiosize - max_audio, num=num_eval)
    else:
        # Single random crop for training
        startframe = numpy.array([numpy.int64(random.random() * (audiosize - max_audio))])

    feats = []
    if evalmode and max_frames == 0:
        # max_frames == 0 means "use the full utterance"
        feats.append(audio)
    else:
        for asf in startframe:
            feats.append(audio[int(asf):int(asf) + max_audio])

    feat = numpy.stack(feats, axis=0).astype(float)

    return feat
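
# A minimal usage sketch (hypothetical path; assumes 16 kHz mono audio, so
# max_audio = max_frames * 160 + 240 samples):
#   feats = loadWAV('spk1/utt1.wav', max_frames=300, evalmode=True, num_eval=5)
#   # feats.shape == (5, 48240): five linearly spaced crops of 300*160+240 samples
#   crop = loadWAV('spk1/utt1.wav', max_frames=300, evalmode=False)
#   # crop.shape == (1, 48240): a single random crop for training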

class AugmentWAV(object):

    def __init__(self, musan_path, rir_path, max_frames):

        self.max_frames = max_frames
        self.max_audio = max_frames * 160 + 240

        self.noisetypes = ['noise', 'speech', 'music']

        # Per-category SNR ranges (dB) and how many noise files are mixed at once
        self.noisesnr = {'noise': [0, 15], 'speech': [13, 20], 'music': [5, 15]}
        self.numnoise = {'noise': [1, 1], 'speech': [3, 8], 'music': [1, 1]}
        self.noiselist = {}

        augment_files = glob.glob(os.path.join(musan_path, '*/*/*.wav'))

        # Group MUSAN files by category (third-to-last path component)
        for file in augment_files:
            category = file.split('/')[-3]
            if category not in self.noiselist:
                self.noiselist[category] = []
            self.noiselist[category].append(file)

        self.rir_files = glob.glob(os.path.join(rir_path, '*/*/*.wav'))

    def additive_noise(self, noisecat, audio):

        clean_db = 10 * numpy.log10(numpy.mean(audio ** 2) + 1e-4)

        numnoise = self.numnoise[noisecat]
        noiselist = random.sample(self.noiselist[noisecat], random.randint(numnoise[0], numnoise[1]))

        noises = []

        for noise in noiselist:
            noiseaudio = loadWAV(noise, self.max_frames, evalmode=False)
            noise_snr = random.uniform(self.noisesnr[noisecat][0], self.noisesnr[noisecat][1])
            noise_db = 10 * numpy.log10(numpy.mean(noiseaudio[0] ** 2) + 1e-4)
            # Scale the noise so the clean/noise power ratio matches noise_snr (dB)
            noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio)

        return numpy.sum(numpy.concatenate(noises, axis=0), axis=0, keepdims=True) + audio
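
    # Worked example of the scaling above (hypothetical numbers): with
    # clean_db = -20 dB, noise_db = -30 dB and a drawn noise_snr of 10 dB,
    # the factor is sqrt(10 ** ((-20 - (-30) - 10) / 10)) = 1.0, i.e. the
    # noise already sits 10 dB below the clean signal and is mixed unchanged.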

    def reverberate(self, audio):

        rir_file = random.choice(self.rir_files)

        rir, fs = soundfile.read(rir_file)
        rir = numpy.expand_dims(rir.astype(float), 0)
        # Normalise the impulse response to unit energy
        rir = rir / numpy.sqrt(numpy.sum(rir ** 2))

        return signal.convolve(audio, rir, mode='full')[:, :self.max_audio]
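
# A minimal sketch of driving AugmentWAV directly (hypothetical MUSAN and RIR
# roots; both corpora must be available locally):
#   aug = AugmentWAV(musan_path='musan_split/', rir_path='simulated_rirs/', max_frames=300)
#   clean = loadWAV('utt.wav', 300, evalmode=False)   # (1, 48240)
#   noisy = aug.additive_noise('music', clean)        # same shape, music mixed at 5-15 dB SNR
#   reverb = aug.reverberate(clean)                   # same shape, convolved with a random RIR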


class train_dataset_loader(Dataset):
    def __init__(self, train_list, augment, musan_path, rir_path, max_frames, train_path, **kwargs):

        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)

        self.train_list = train_list
        self.max_frames = max_frames
        self.musan_path = musan_path
        self.rir_path = rir_path
        self.augment = augment

        # Read training files
        with open(train_list) as dataset_file:
            lines = dataset_file.readlines()

        # Make a dictionary of ID names and ID indices
        dictkeys = list(set([x.split()[0] for x in lines]))
        dictkeys.sort()
        dictkeys = {key: ii for ii, key in enumerate(dictkeys)}

        # Parse the training list into file names and ID indices
        self.data_list = []
        self.data_label = []

        for lidx, line in enumerate(lines):
            data = line.strip().split()

            speaker_label = dictkeys[data[0]]
            filename = os.path.join(train_path, data[1])

            self.data_label.append(speaker_label)
            self.data_list.append(filename)

    def __getitem__(self, indices):

        feat_clean = []
        feat = []

        for index in indices:
            try:
                audio_clean = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
            except Exception:
                print(self.data_list[index])
                raise

            if len(audio_clean.shape) == 3:
                print(self.data_list[index])

            audio = audio_clean
            if self.augment:
                augtype = random.randint(0, 5)
                if augtype == 1:
                    audio = self.augment_wav.reverberate(audio_clean)
                elif augtype == 2:
                    audio = self.augment_wav.additive_noise('music', audio_clean)
                elif augtype == 3:
                    audio = self.augment_wav.additive_noise('speech', audio_clean)
                elif augtype == 4:
                    audio = self.augment_wav.additive_noise('noise', audio_clean)
                elif augtype == 5:
                    # Chain the two augmentations rather than discarding the first
                    audio = self.augment_wav.additive_noise('speech', audio_clean)
                    audio = self.augment_wav.additive_noise('music', audio)

            feat_clean.append(audio_clean)
            feat.append(audio)

        feat_clean = numpy.concatenate(feat_clean, axis=0)
        feat = numpy.concatenate(feat, axis=0)

        # Note: the label and path returned are those of the last index in `indices`
        return torch.FloatTensor(feat_clean), torch.FloatTensor(feat), self.data_label[index], self.data_list[index]

    def __len__(self):
        return len(self.data_list)



class test_dataset_loader(Dataset):
    def __init__(self, test_list, test_path, eval_frames, num_eval, **kwargs):
        self.max_frames = eval_frames
        self.num_eval = num_eval
        self.test_path = test_path
        self.test_list = test_list

    def __getitem__(self, index):
        # Fixed-length crops plus the full utterance for each test file
        audio = loadWAV(os.path.join(self.test_path, self.test_list[index]), self.max_frames, evalmode=True, num_eval=self.num_eval)
        audio2 = loadWAV(os.path.join(self.test_path, self.test_list[index]), 0, evalmode=True, num_eval=self.num_eval)

        return torch.FloatTensor(audio), torch.FloatTensor(audio2), self.test_list[index]

    def __len__(self):
        return len(self.test_list)
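
# Usage sketch (hypothetical relative path and root): each item yields five
# fixed-length crops plus the full utterance, so both can be scored:
#   test_set = test_dataset_loader(['id001/utt1.wav'], 'voxceleb1/', eval_frames=400, num_eval=5)
#   crops, full, name = test_set[0]   # crops: (5, 64240), full: (1, T)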


class train_dataset_sampler(torch.utils.data.Sampler):
    def __init__(self, data_source, nPerSpeaker, max_seg_per_spk, batch_size, distributed, seed, **kwargs):

        self.data_label = data_source.data_label
        self.nPerSpeaker = nPerSpeaker
        self.max_seg_per_spk = max_seg_per_spk
        self.batch_size = batch_size
        self.epoch = 0
        self.seed = seed
        self.distributed = distributed

    def __iter__(self):

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.data_label), generator=g).tolist()

        data_dict = {}

        # Sort into dictionary of file indices for each ID
        for index in indices:
            speaker_label = self.data_label[index]
            if speaker_label not in data_dict:
                data_dict[speaker_label] = []
            data_dict[speaker_label].append(index)

        ## Group file indices for each class
        dictkeys = list(data_dict.keys())
        dictkeys.sort()

        # Split a list into chunks of size sz
        lol = lambda lst, sz: [lst[i:i + sz] for i in range(0, len(lst), sz)]

        flattened_list = []
        flattened_label = []

        for findex, key in enumerate(dictkeys):
            data = data_dict[key]
            numSeg = round_down(min(len(data), self.max_seg_per_spk), self.nPerSpeaker)

            rp = lol(numpy.arange(numSeg), self.nPerSpeaker)
            flattened_label.extend([findex] * (len(rp)))
            for indices in rp:
                flattened_list.append([data[i] for i in indices])

        ## Mix data in random order
        mixid = torch.randperm(len(flattened_label), generator=g).tolist()
        mixlabel = []
        mixmap = []

        ## Prevent two pairs of the same speaker in the same batch
        for ii in mixid:
            startbatch = round_down(len(mixlabel), self.batch_size)
            if flattened_label[ii] not in mixlabel[startbatch:]:
                mixlabel.append(flattened_label[ii])
                mixmap.append(ii)

        mixed_list = [flattened_list[i] for i in mixmap]

        ## Divide data to each GPU
        if self.distributed:
            total_size = round_down(len(mixed_list), self.batch_size * dist.get_world_size())
            start_index = int(dist.get_rank() / dist.get_world_size() * total_size)
            end_index = int((dist.get_rank() + 1) / dist.get_world_size() * total_size)
            self.num_samples = end_index - start_index
            return iter(mixed_list[start_index:end_index])
        else:
            total_size = round_down(len(mixed_list), self.batch_size)
            self.num_samples = total_size
            return iter(mixed_list[:total_size])

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch
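
# Sketch of the intended epoch loop (assumed training driver, not part of this
# file): calling set_epoch before each epoch reseeds the shuffle so every epoch
# (and every DDP rank) sees a different but reproducible ordering:
#   for epoch in range(num_epochs):
#       train_sampler.set_epoch(epoch)
#       for batch in train_loader:
#           ...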


if __name__ == '__main__':
    train_dataset = train_dataset_loader(train_list='/mnt/proj3/open-24-5/pengjy_new/WavLM_Adapter/CNCeleb_lst/CNCeleb_trainlist_200spk.txt',
                                         augment=False,
                                         musan_path='/mnt/proj3/open-24-5/pengjy_new/musan_split/',
                                         rir_path='/mnt/proj3/open-24-5/plchot/data_augment/16kHz/simulated_rirs/',
                                         max_frames=300,
                                         train_path='/mnt/proj3/open-24-5/pengjy_new/Data/CN-Celeb_flac/data',
                                         )

    train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=1, max_seg_per_spk=500, batch_size=100, distributed=False, seed=120)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=100,
        num_workers=10,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
    )

    # __getitem__ returns (clean, augmented, label, path), so unpack all four
    for data_clean, data, data_label, data_path in train_loader:
        print(data.shape)
        data = data.transpose(1, 0)
        print(data.shape)
        quit()