Added repeated dataset and test of models
rrMat committed Jan 16, 2024
1 parent cc50071 commit c3f1613
Showing 5 changed files with 1,232 additions and 86 deletions.
46 changes: 30 additions & 16 deletions birdclef/dataset.py
@@ -65,10 +65,10 @@ def __init__(


         #Augmentations
-        self.maskingFreq = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)
-        self.maskingTime = torchaudio.transforms.TimeMasking(time_mask_param=20)
+        self.maskingFreq = torchaudio.transforms.FrequencyMasking(freq_mask_param=40)
+        # self.maskingTime = torchaudio.transforms.TimeMasking(time_mask_param=20)
         self.noiser = torchaudio.transforms.AddNoise()
         self.speed_perturb = torchaudio.transforms.SpeedPerturbation(sample_rate, [0.9, 1.1, 1.0, 1.0, 1.0, 0.8, 1.2, 1.0])

         self.rnd_offset = rnd_offset

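A note on the speed_perturb transform kept as context above: torchaudio's SpeedPerturbation draws one factor uniformly from the given list on each call, so repeating 1.0 five times out of eight biases it toward leaving the clip untouched. A minimal standalone sketch of that behaviour; the 32 kHz sample rate is an assumption, not something stated in this diff:

```python
import torch
import torchaudio

sample_rate = 32000  # assumed; the real value comes from the dataset config
speed_perturb = torchaudio.transforms.SpeedPerturbation(
    sample_rate, [0.9, 1.1, 1.0, 1.0, 1.0, 0.8, 1.2, 1.0]
)

waveform = torch.randn(1, sample_rate)   # stand-in for a 1-second mono clip
out, lengths = speed_perturb(waveform)   # a factor is re-sampled on every call
print(out.shape)                         # time axis scaled by roughly 1/factor
```
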
@@ -96,25 +96,21 @@ def forward(self, filename):
         # 2 Waveform Augmentations
         if self.augmentations:
             # Random noise
-            if np.random.random() > 0.8:
+            if np.random.random() > 0.5:
                 noise = torch.randn_like(waveform)
-                snr_dbs = torch.tensor([10])
-                waveform = self.noiser(waveform, noise, snr_dbs)
+                snr_dbs = torch.tensor([20])
+                waveform = self.noiser(waveform, noise, snr_dbs)
             # Speed perturbation
             if np.random.random() > 0.8:
                 waveform = self.speed_perturb(waveform)[0]

         # 3 Convert to mel-scale
         mel = self.melspec(waveform)

         # 4 Mel Augmentations
         if self.augmentations:
-            if np.random.random() > 0.8:
-                mel = self.maskingTime(mel)
+            # if np.random.random() > 0.8:
+            #     mel = self.maskingTime(mel)

-            if np.random.random() > 0.8:
+            if np.random.random() > 0.5:
                 mel = self.maskingFreq(mel)

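Pulling the edited pieces together, the augmentation path after this commit is: additive noise at 20 dB SNR with probability 0.5, speed perturbation with probability 0.2, mel conversion, then frequency masking with probability 0.5, with time masking now disabled. A self-contained sketch of that chain; the sample rate, clip length, and mel settings are assumptions:

```python
import numpy as np
import torch
import torchaudio

sample_rate = 32000                           # assumed
waveform = torch.randn(1, sample_rate * 5)    # stand-in for a 5-second clip

noiser = torchaudio.transforms.AddNoise()
speed_perturb = torchaudio.transforms.SpeedPerturbation(
    sample_rate, [0.9, 1.1, 1.0, 1.0, 1.0, 0.8, 1.2, 1.0]
)
melspec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate)
masking_freq = torchaudio.transforms.FrequencyMasking(freq_mask_param=40)

# Waveform-level augmentations
if np.random.random() > 0.5:                  # Gaussian noise at a fixed 20 dB SNR
    noise = torch.randn_like(waveform)
    waveform = noiser(waveform, noise, torch.tensor([20]))
if np.random.random() > 0.8:                  # occasional speed perturbation
    waveform = speed_perturb(waveform)[0]

# Mel-level augmentation (the commit comments out time masking)
mel = melspec(waveform)                       # shape: (channel, n_mels, time)
if np.random.random() > 0.5:                  # zero out a random band of mel bins
    mel = masking_freq(mel)
```
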
@@ -166,6 +162,8 @@ def inverse_transform(self, mel):
 class BirdClef(Dataset):

     def __init__(self, metadata=None, classes=None, per_channel=False, augmentations=False, rnd_offset=False):
+
+
         self.metadata = metadata
         sorted_classes = classes.sort_values()
@@ -218,7 +216,19 @@ def __getitem__(self, idx)
 val_metadata_simple = val_metadata_base.loc[val_metadata_base.primary_label.isin(simple_classes)].reset_index()
 test_metadata_simple = test_metadata_base.loc[test_metadata_base.primary_label.isin(simple_classes)].reset_index()

-# %% ../nbs/02_dataset.ipynb 20
+# %% ../nbs/02_dataset.ipynb 21
+dir = DATA_DIR
+try:
+    train_metadata_repeated = pd.read_csv(dir + 'repeated/train_metadata.csv')
+    val_metadata_repeated = pd.read_csv(dir + 'repeated/val_metadata.csv')
+    test_metadata_repeated = pd.read_csv(dir + 'repeated/test_metadata.csv')
+except FileNotFoundError:
+    dir = 'data/'
+    train_metadata_repeated = pd.read_csv(dir + 'repeated/train_metadata.csv')
+    val_metadata_repeated = pd.read_csv(dir + 'repeated/val_metadata.csv')
+    test_metadata_repeated = pd.read_csv(dir + 'repeated/test_metadata.csv')
+
+# %% ../nbs/02_dataset.ipynb 22
 dataset_dict = {
     'train_base': (BirdClef, {'metadata': train_metadata_base, 'classes': train_metadata_base.primary_label}),
     'val_base': (BirdClef, {'metadata': val_metadata_base, 'classes': train_metadata_base.primary_label}),
@@ -248,17 +258,21 @@ def __getitem__(self, idx)
     'val_base_pcn_aug_rnd': (BirdClef, {'metadata': val_metadata_base, 'classes': train_metadata_base.primary_label, 'per_channel': True, 'augmentations': True, 'rnd_offset': True}),
     'test_base_pcn_aug_rnd': (BirdClef, {'metadata': test_metadata_base, 'classes': train_metadata_base.primary_label, 'per_channel': True, 'augmentations': True, 'rnd_offset': True}),

+    'train_repeated_pcn_rnd': (BirdClef, {'metadata': train_metadata_repeated, 'classes': train_metadata_repeated.primary_label, 'per_channel': True, 'rnd_offset': True}),
+    'val_repeated_pcn_rnd': (BirdClef, {'metadata': val_metadata_repeated, 'classes': train_metadata_repeated.primary_label, 'per_channel': True, 'rnd_offset': True}),
+    'test_repeated_pcn_rnd': (BirdClef, {'metadata': test_metadata_repeated, 'classes': train_metadata_repeated.primary_label, 'per_channel': True, 'rnd_offset': True}),
+
 }
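The try/except added above simply probes two candidate roots for the new repeated split files. For illustration only, the same lookup written as a loop over candidate roots; load_split is a hypothetical helper, not part of the diff, and DATA_DIR is the module's configured data directory:

```python
from pathlib import Path

import pandas as pd

def load_split(split: str, roots=(DATA_DIR, 'data/')):
    """Read repeated/<split>_metadata.csv from the first root where it exists."""
    for root in roots:
        path = Path(root) / 'repeated' / f'{split}_metadata.csv'
        if path.exists():
            return pd.read_csv(path)
    raise FileNotFoundError(f'no repeated/{split}_metadata.csv under {roots}')

train_metadata_repeated = load_split('train')
val_metadata_repeated = load_split('val')
test_metadata_repeated = load_split('test')
```
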

-# %% ../nbs/02_dataset.ipynb 21
+# %% ../nbs/02_dataset.ipynb 23
 def get_dataset(dataset_key:str # A key of the dataset dictionary
                 )->Dataset: # Pytorch dataset
     "A getter method to retrieve the wanted dataset."
     assert dataset_key in dataset_dict, f'{dataset_key} is not an existing dataset, choose one from {dataset_dict.keys()}.'
     ds_class, kwargs = dataset_dict[dataset_key]
     return ds_class(**kwargs)

-# %% ../nbs/02_dataset.ipynb 25
+# %% ../nbs/02_dataset.ipynb 27
 def get_dataloader(dataset_key:str, # The key to access the dataset
                    dataloader_kwargs:dict={} # The optional parameters for a pytorch dataloader
                    )->DataLoader: # Pytorch dataloader
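For orientation, the new registry keys plug into the existing getter. A hypothetical call; the (spectrogram, label) shape of the returned items is an assumption, since __getitem__'s body is collapsed in this diff:

```python
ds = get_dataset('train_repeated_pcn_rnd')  # one of the keys added by this commit
mel, label = ds[0]                          # assumed (spectrogram, label) pair
```
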
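get_dataloader's body is also collapsed above; assuming it wraps get_dataset in a standard PyTorch DataLoader, usage would look roughly like this, with batch_size and shuffle purely illustrative:

```python
dl = get_dataloader('train_repeated_pcn_rnd', {'batch_size': 16, 'shuffle': True})
mels, labels = next(iter(dl))  # batch shapes depend on the BirdClef transforms
```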