Commit
Giacomo Melacini committed Oct 20, 2023
1 parent 16a1af8 · commit 06c3eb0
Showing 29 changed files with 3,702 additions and 749 deletions.
@@ -149,3 +149,12 @@ checklink/cookies.txt

# Quarto
.quarto

# Data folder
data/

# wandb folder
wandb/

# weight files
**.pth
File renamed without changes.
@@ -0,0 +1,39 @@
# Autogenerated by nbdev

d = { 'settings': { 'branch': 'master',
                    'doc_baseurl': '/birdclef_2023',
                    'doc_host': 'https://Chavelanda.github.io',
                    'git_url': 'https://github.com/Chavelanda/birdclef_2023',
                    'lib_path': 'birdclef'},
      'syms': { 'birdclef.dataset': { 'birdclef.dataset.BirdClef': ('dataset.html#birdclef', 'birdclef/dataset.py'),
                                      'birdclef.dataset.BirdClef.__getitem__': ('dataset.html#birdclef.__getitem__', 'birdclef/dataset.py'),
                                      'birdclef.dataset.BirdClef.__init__': ('dataset.html#birdclef.__init__', 'birdclef/dataset.py'),
                                      'birdclef.dataset.BirdClef.__len__': ('dataset.html#birdclef.__len__', 'birdclef/dataset.py'),
                                      'birdclef.dataset.MyPipeline': ('dataset.html#mypipeline', 'birdclef/dataset.py'),
                                      'birdclef.dataset.MyPipeline.__init__': ('dataset.html#mypipeline.__init__', 'birdclef/dataset.py'),
                                      'birdclef.dataset.MyPipeline.forward': ('dataset.html#mypipeline.forward', 'birdclef/dataset.py'),
                                      'birdclef.dataset.get_dataloader': ('dataset.html#get_dataloader', 'birdclef/dataset.py'),
                                      'birdclef.dataset.get_dataset': ('dataset.html#get_dataset', 'birdclef/dataset.py')},
                'birdclef.experiment': {},
                'birdclef.network': { 'birdclef.network.EfficientNetV2': ('network.html#efficientnetv2', 'birdclef/network.py'),
                                      'birdclef.network.EfficientNetV2.__init__': ('network.html#efficientnetv2.__init__', 'birdclef/network.py'),
                                      'birdclef.network.EfficientNetV2.forward': ('network.html#efficientnetv2.forward', 'birdclef/network.py'),
                                      'birdclef.network.get_model': ('network.html#get_model', 'birdclef/network.py')},
                'birdclef.preprocessing': {'birdclef.preprocessing.foo': ('preprocessing.html#foo', 'birdclef/preprocessing.py')},
                'birdclef.trainer': { 'birdclef.trainer.log_weights': ('trainer.html#log_weights', 'birdclef/trainer.py'),
                                      'birdclef.trainer.train': ('trainer.html#train', 'birdclef/trainer.py'),
                                      'birdclef.trainer.train_one_epoch': ('trainer.html#train_one_epoch', 'birdclef/trainer.py'),
                                      'birdclef.trainer.validate_model': ('trainer.html#validate_model', 'birdclef/trainer.py')},
                'birdclef.training_utils': { 'birdclef.training_utils.compute_metrics': ('training_utils.html#compute_metrics', 'birdclef/training_utils.py'),
                                             'birdclef.training_utils.get_loss_func': ('training_utils.html#get_loss_func', 'birdclef/training_utils.py'),
                                             'birdclef.training_utils.get_optimizer': ('training_utils.html#get_optimizer', 'birdclef/training_utils.py')},
                'birdclef.utils': { 'birdclef.utils.mel_to_wave': ('utils.html#mel_to_wave', 'birdclef/utils.py'),
                                    'birdclef.utils.plot_audio': ('utils.html#plot_audio', 'birdclef/utils.py'),
                                    'birdclef.utils.plot_fbank': ('utils.html#plot_fbank', 'birdclef/utils.py'),
                                    'birdclef.utils.plot_specgram': ('utils.html#plot_specgram', 'birdclef/utils.py'),
                                    'birdclef.utils.plot_spectrogram': ('utils.html#plot_spectrogram', 'birdclef/utils.py'),
                                    'birdclef.utils.plot_waveform': ('utils.html#plot_waveform', 'birdclef/utils.py')}}}
@@ -0,0 +1,176 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_dataset.ipynb.

# %% auto 0
__all__ = ['dir', 'simple_classes', 'train_metadata_simple', 'val_metadata_simple', 'test_metadata_simple', 'dataset_dict',
           'MyPipeline', 'BirdClef', 'get_dataset', 'get_dataloader']

# %% ../nbs/02_dataset.ipynb 3
from IPython.display import Audio
import pandas as pd
from sklearn.preprocessing import LabelBinarizer

import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio

from .utils import DATA_DIR, AUDIO_DATA_DIR, mel_to_wave, plot_audio, plot_spectrogram

# %% ../nbs/02_dataset.ipynb 6
# Define the custom feature extraction pipeline:
#
# 1. Check the sample rate and resample if needed
# 2. Waveform augmentations
# 3. Convert to mel scale
# 4. Mel augmentations
# 5. Check the length and stretch shorter clips

class MyPipeline(torch.nn.Module):
    def __init__(
        self,
        c_length=10,
        sample_rate=32000,
        f_min=40,
        f_max=15000,
        n_fft=2048,
        n_mels=128,
        hop_length=512,
        power=2.0
    ):
        super().__init__()

        self.c_length = c_length * 62.6  # 626 frames correspond to 10 seconds
        self.sample_rate = sample_rate
        self.melspec = torchaudio.transforms.MelSpectrogram(sample_rate=self.sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max, power=power)
        self.amptodb = torchaudio.transforms.AmplitudeToDB()
        self.stretch = torchaudio.transforms.TimeStretch(hop_length=hop_length, n_freq=128)

        # Augmentations
        # self.maskingFreq = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)
        # self.maskingTime = torchaudio.transforms.TimeMasking(time_mask_param=30)
        # self.noiser = torchaudio.transforms.AddNoise()
        # self.pitchShift = torchaudio.transforms.PitchShift(resample_freq, 4)

    def forward(self, filename):
        # 0 Load the file (at most 320,000 samples, i.e. 10 s at 32 kHz)
        waveform, sample_rate = torchaudio.load(filename, frame_offset=0, num_frames=320000)

        # 1 Check the sample rate and, if needed, resample to 32 kHz
        if sample_rate != self.sample_rate:
            print("Wrong sample rate: resampling audio")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
            waveform = resampler(waveform)

        # # 2 Waveform augmentation
        # # 2.1 White noise
        # if random.randint(0,1) < 0.3:
        #     noise = torch.rand(1, 320000)
        #     noise = (noise - 0.5) * 0.2
        #     snr_dbs = torch.tensor([random.randint(2,8)])
        #     waveform = self.noiser(waveform, noise, snr_dbs)

        # # 2.2 Pitch shift
        # if random.randint(0,1) < 1:
        # if True:
        #     waveform = self.pitchShift(waveform)

        # 3 Convert to mel scale
        mel = self.melspec(waveform)
        mel = self.amptodb(mel)

        # 4 Mel augmentations
        # # 4.1 Frequency masking
        # if True:
        #     mel = self.maskingFreq(mel)
        # # 4.2 Time masking
        # if True:
        #     mel = self.maskingTime(mel)

        # 5 Check the length and stretch shorter clips to 10 s, to regularize the length of the data
        if mel.shape[2] < self.c_length:
            print("Audio too short: stretching it.")
            replay_rate = mel.shape[2] / self.c_length
            # print(f"replay rate {replay_rate}%")
            mel = self.stretch(mel, replay_rate)
            mel = mel[:, :, 0:626]
            # print(f"stretched shape {mel.shape}")

        return mel.float()

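For reference, a minimal usage sketch of the pipeline on a single recording (the file path below is hypothetical; the code assumes 32 kHz BirdCLEF recordings):

# Illustrative usage, not part of this commit; the path is a made-up example.
pipeline = MyPipeline()
mel = pipeline('data/train_audio/barswa/XC113914.ogg')  # hypothetical file
print(mel.shape)  # torch.Size([1, 128, 626]): 128 mel bins, ~626 frames for 10 s
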
# %% ../nbs/02_dataset.ipynb 9
class BirdClef(Dataset):

    def __init__(self, metadata=None, classes=None):

        self.metadata = metadata
        self.classes = classes

        self.length = len(self.metadata)

        # One-hot encode the primary labels
        binarizer = LabelBinarizer()
        binarizer.fit(self.classes)

        self.labels = binarizer.transform(metadata.primary_label)

        # Initialize a pipeline
        self.pipeline = MyPipeline()

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        filename = AUDIO_DATA_DIR + self.metadata['filename'][idx]
        mel_spectrogram = self.pipeline(filename)

        label = self.labels[idx]
        label = torch.from_numpy(label).float()

        return mel_spectrogram, label

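The dataset one-hot encodes each recording's primary_label with scikit-learn's LabelBinarizer. A self-contained sketch of just that step, using the three species codes of the simple split defined below and made-up filenames:

# Illustrative sketch, not part of this commit.
import pandas as pd
from sklearn.preprocessing import LabelBinarizer

meta = pd.DataFrame({
    'filename': ['barswa/a.ogg', 'thrnig1/b.ogg', 'wlwwar/c.ogg', 'barswa/d.ogg'],  # made up
    'primary_label': ['barswa', 'thrnig1', 'wlwwar', 'barswa'],
})
binarizer = LabelBinarizer()
binarizer.fit(['thrnig1', 'wlwwar', 'barswa'])
labels = binarizer.transform(meta.primary_label)
print(binarizer.classes_)  # ['barswa' 'thrnig1' 'wlwwar'] (sorted alphabetically)
print(labels[0])           # [1 0 0], one one-hot row per recording
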
# %% ../nbs/02_dataset.ipynb 13
dir = DATA_DIR
try:
    train_metadata_base = pd.read_csv(dir + 'base/train_metadata.csv')
    val_metadata_base = pd.read_csv(dir + 'base/val_metadata.csv')
    test_metadata_base = pd.read_csv(dir + 'base/test_metadata.csv')
except FileNotFoundError:
    dir = 'data/'
    train_metadata_base = pd.read_csv(dir + 'base/train_metadata.csv')
    val_metadata_base = pd.read_csv(dir + 'base/val_metadata.csv')
    test_metadata_base = pd.read_csv(dir + 'base/test_metadata.csv')

simple_classes = ['thrnig1', 'wlwwar', 'barswa']
train_metadata_simple = train_metadata_base.loc[train_metadata_base.primary_label.isin(simple_classes)].reset_index()
val_metadata_simple = val_metadata_base.loc[val_metadata_base.primary_label.isin(simple_classes)].reset_index()
test_metadata_simple = test_metadata_base.loc[test_metadata_base.primary_label.isin(simple_classes)].reset_index()

# %% ../nbs/02_dataset.ipynb 14
dataset_dict = {
    'train_base': (BirdClef, {'metadata': train_metadata_base, 'classes': train_metadata_base.primary_label}),
    'val_base': (BirdClef, {'metadata': val_metadata_base, 'classes': train_metadata_base.primary_label}),
    'test_base': (BirdClef, {'metadata': test_metadata_base, 'classes': train_metadata_base.primary_label}),

    'train_simple': (BirdClef, {'metadata': train_metadata_simple, 'classes': train_metadata_simple.primary_label}),
    'val_simple': (BirdClef, {'metadata': val_metadata_simple, 'classes': train_metadata_simple.primary_label}),
    'test_simple': (BirdClef, {'metadata': test_metadata_simple, 'classes': train_metadata_simple.primary_label})
}

# %% ../nbs/02_dataset.ipynb 15
def get_dataset(dataset_key:str # A key of the dataset dictionary
               )->Dataset: # Pytorch dataset
    "A getter method to retrieve the requested dataset."
    assert dataset_key in dataset_dict, f'{dataset_key} is not an existing dataset, choose one from {dataset_dict.keys()}.'
    ds_class, kwargs = dataset_dict[dataset_key]
    return ds_class(**kwargs)

# %% ../nbs/02_dataset.ipynb 19
def get_dataloader(dataset_key:str, # The key to access the dataset
                   dataloader_kwargs:dict={} # Optional parameters for a PyTorch DataLoader
                  )->DataLoader: # Pytorch dataloader
    "A function to get a dataloader for a specific dataset."
    dataset = get_dataset(dataset_key)

    return DataLoader(dataset, **dataloader_kwargs)
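
A possible way to exercise these getters once the metadata CSVs and audio files are in place locally (batch size and flags below are illustrative, not taken from the commit):

# Illustrative usage, assuming data/base/*.csv and the referenced audio exist.
train_dl = get_dataloader('train_simple', {'batch_size': 16, 'shuffle': True, 'num_workers': 2})
mel_batch, label_batch = next(iter(train_dl))
print(mel_batch.shape, label_batch.shape)  # e.g. torch.Size([16, 1, 128, 626]) torch.Size([16, 3])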
@@ -0,0 +1,9 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_experiment.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/06_experiment.ipynb 4
import wandb

from .trainer import train
@@ -0,0 +1,56 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_network.ipynb.

# %% auto 0
__all__ = ['model_dict', 'EfficientNetV2', 'get_model']

# %% ../nbs/03_network.ipynb 3
from typing import Union, BinaryIO, IO
from os import PathLike

import torch
import torchvision
from torch.nn import Module

from .dataset import get_dataloader

# %% ../nbs/03_network.ipynb 5
class EfficientNetV2(torch.nn.Module):
    def __init__(self, num_classes=264, size='s'):
        super().__init__()

        if size == 's':
            self.efficientnet_v2 = torchvision.models.efficientnet_v2_s(weights=None, progress=True, num_classes=num_classes)
        elif size == 'm':
            self.efficientnet_v2 = torchvision.models.efficientnet_v2_m(weights=None, progress=True, num_classes=num_classes)
        else:
            self.efficientnet_v2 = torchvision.models.efficientnet_v2_l(weights=None, progress=True, num_classes=num_classes)

        self.init_conv = torch.nn.Conv2d(1, 3, (3, 3), padding="same")
        # self.sigmoid = torch.nn.functional.sigmoid

    def forward(self, x):
        x = self.init_conv(x)
        x = self.efficientnet_v2(x)

        return x

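The torchvision EfficientNetV2 backbones expect 3-channel input, so init_conv first lifts the single-channel mel spectrogram to 3 channels. A quick shape check on random data (purely illustrative, assuming the input layout produced by MyPipeline):

# Illustrative shape check, not part of this commit.
net = EfficientNetV2(num_classes=3, size='s')
dummy = torch.randn(2, 1, 128, 626)  # (batch, channels, n_mels, frames)
logits = net(dummy)
print(logits.shape)  # torch.Size([2, 3])
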
# %% ../nbs/03_network.ipynb 10
model_dict = {
    'simple_efficient_net_v2_s': (EfficientNetV2, {'num_classes': 3}),
    'efficient_net_v2_s': (EfficientNetV2, {})
}

def get_model(model_key:str, # A key of the model dictionary
              weights_path:Union[str, PathLike, BinaryIO, IO[bytes]] = None # A path or file-like object with the model weights
             )->Module: # A pytorch model
    "A getter method to retrieve the requested (possibly pretrained) model."
    assert model_key in model_dict, f'{model_key} is not an existing network, choose one from {model_dict.keys()}.'

    net_class, kwargs = model_dict[model_key]
    model = net_class(**kwargs)

    if weights_path is not None:
        model.load_state_dict(torch.load(weights_path))

    return model
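
get_model can also restore saved weights from a .pth file; the checkpoint path below is hypothetical (this commit's .gitignore change keeps **.pth files out of version control):

# Illustrative usage, not part of this commit.
model = get_model('simple_efficient_net_v2_s')  # fresh 3-class model

# Hypothetical checkpoint path:
# model = get_model('simple_efficient_net_v2_s', weights_path='weights/simple_effnet.pth')
model.eval()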
File renamed without changes.