Set up BirdClef
Giacomo Melacini committed Oct 20, 2023
1 parent 16a1af8 commit 06c3eb0
Showing 29 changed files with 3,702 additions and 749 deletions.
9 changes: 9 additions & 0 deletions .gitignore
@@ -149,3 +149,12 @@ checklink/cookies.txt

# Quarto
.quarto

# Data folder
data/

# wandb folder
wandb/

# weight files
**.pth
4 changes: 2 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# dm-4-cat
# birdclef_2023

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

@@ -8,7 +8,7 @@ documentation.
## Install

``` sh
pip install dm_4_cat
pip install birdclef_2023
```

## How to use
File renamed without changes.
39 changes: 39 additions & 0 deletions birdclef/_modidx.py
@@ -0,0 +1,39 @@
# Autogenerated by nbdev

d = { 'settings': { 'branch': 'master',
'doc_baseurl': '/birdclef_2023',
'doc_host': 'https://Chavelanda.github.io',
'git_url': 'https://github.com/Chavelanda/birdclef_2023',
'lib_path': 'birdclef'},
'syms': { 'birdclef.dataset': { 'birdclef.dataset.BirdClef': ('dataset.html#birdclef', 'birdclef/dataset.py'),
'birdclef.dataset.BirdClef.__getitem__': ('dataset.html#birdclef.__getitem__', 'birdclef/dataset.py'),
'birdclef.dataset.BirdClef.__init__': ('dataset.html#birdclef.__init__', 'birdclef/dataset.py'),
'birdclef.dataset.BirdClef.__len__': ('dataset.html#birdclef.__len__', 'birdclef/dataset.py'),
'birdclef.dataset.MyPipeline': ('dataset.html#mypipeline', 'birdclef/dataset.py'),
'birdclef.dataset.MyPipeline.__init__': ('dataset.html#mypipeline.__init__', 'birdclef/dataset.py'),
'birdclef.dataset.MyPipeline.forward': ('dataset.html#mypipeline.forward', 'birdclef/dataset.py'),
'birdclef.dataset.get_dataloader': ('dataset.html#get_dataloader', 'birdclef/dataset.py'),
'birdclef.dataset.get_dataset': ('dataset.html#get_dataset', 'birdclef/dataset.py')},
'birdclef.experiment': {},
'birdclef.network': { 'birdclef.network.EfficientNetV2': ('network.html#efficientnetv2', 'birdclef/network.py'),
'birdclef.network.EfficientNetV2.__init__': ( 'network.html#efficientnetv2.__init__',
'birdclef/network.py'),
'birdclef.network.EfficientNetV2.forward': ('network.html#efficientnetv2.forward', 'birdclef/network.py'),
'birdclef.network.get_model': ('network.html#get_model', 'birdclef/network.py')},
'birdclef.preprocessing': {'birdclef.preprocessing.foo': ('preprocessing.html#foo', 'birdclef/preprocessing.py')},
'birdclef.trainer': { 'birdclef.trainer.log_weights': ('trainer.html#log_weights', 'birdclef/trainer.py'),
'birdclef.trainer.train': ('trainer.html#train', 'birdclef/trainer.py'),
'birdclef.trainer.train_one_epoch': ('trainer.html#train_one_epoch', 'birdclef/trainer.py'),
'birdclef.trainer.validate_model': ('trainer.html#validate_model', 'birdclef/trainer.py')},
'birdclef.training_utils': { 'birdclef.training_utils.compute_metrics': ( 'training_utils.html#compute_metrics',
'birdclef/training_utils.py'),
'birdclef.training_utils.get_loss_func': ( 'training_utils.html#get_loss_func',
'birdclef/training_utils.py'),
'birdclef.training_utils.get_optimizer': ( 'training_utils.html#get_optimizer',
'birdclef/training_utils.py')},
'birdclef.utils': { 'birdclef.utils.mel_to_wave': ('utils.html#mel_to_wave', 'birdclef/utils.py'),
'birdclef.utils.plot_audio': ('utils.html#plot_audio', 'birdclef/utils.py'),
'birdclef.utils.plot_fbank': ('utils.html#plot_fbank', 'birdclef/utils.py'),
'birdclef.utils.plot_specgram': ('utils.html#plot_specgram', 'birdclef/utils.py'),
'birdclef.utils.plot_spectrogram': ('utils.html#plot_spectrogram', 'birdclef/utils.py'),
'birdclef.utils.plot_waveform': ('utils.html#plot_waveform', 'birdclef/utils.py')}}}
176 changes: 176 additions & 0 deletions birdclef/dataset.py
@@ -0,0 +1,176 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_dataset.ipynb.

# %% auto 0
__all__ = ['dir', 'simple_classes', 'train_metadata_simple', 'val_metadata_simple', 'test_metadata_simple', 'dataset_dict',
'MyPipeline', 'BirdClef', 'get_dataset', 'get_dataloader']

# %% ../nbs/02_dataset.ipynb 3
from IPython.display import Audio
import pandas as pd
from sklearn.preprocessing import LabelBinarizer

import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio

from .utils import DATA_DIR, AUDIO_DATA_DIR, mel_to_wave, plot_audio, plot_spectrogram

# %% ../nbs/02_dataset.ipynb 6
# Define custom feature extraction pipeline.
#
# 1. Check for sample rate and resample
# 2. Waveform Augmentations
# 3. Convert to mel-scale
# 4. Mel Augmentations
# 5. Check the length and stretch shorter clips


class MyPipeline(torch.nn.Module):
def __init__(
self,
c_length = 10,
sample_rate=32000,
f_min = 40,
f_max = 15000,
n_fft=2048,
n_mels=128,
hop_length = 512,
power = 2.0
):
super().__init__()

self.c_length = c_length * 62.6  # 626 frames correspond to 10 seconds
self.sample_rate = sample_rate
self.melspec = torchaudio.transforms.MelSpectrogram(sample_rate=self.sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max, power=power)
self.amptodb = torchaudio.transforms.AmplitudeToDB()
self.stretch = torchaudio.transforms.TimeStretch(hop_length=hop_length, n_freq=128)

#Augmentations
# self.maskingFreq = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)
# self.maskingTime = torchaudio.transforms.TimeMasking(time_mask_param=30)
# self.noiser = torchaudio.transforms.AddNoise()
# self.pitchShift = torchaudio.transforms.PitchShift(resample_freq, 4)


def forward(self, filename):
# 0 Load the File
waveform, sample_rate = torchaudio.load(filename, frame_offset=0, num_frames=320000)

# 1 Check the sample rate and resample to 32 kHz if necessary
if sample_rate != self.sample_rate:
print("Wrong sample rate: resampling audio")
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
waveform = resampler(waveform)

# # 2 Waveform Augmentation
# #2.1 White noise
# if random.randint(0,1) < 0.3:
# noise = torch.rand(1, 320000)
# noise = (noise - 0.5) * 0.2
# snr_dbs = torch.tensor([random.randint(2,8)])
# waveform = self.noiser(waveform, noise, snr_dbs)

# # 2.2 Pitch Shift
# if random.randint(0,1) < 1:
# if True:
# waveform = self.pitchShift(waveform)


# 3 Convert to mel-scale
mel = self.melspec(waveform)
mel = self.amptodb(mel)

# 4 Mel Augmentations
# 4.1 Frequency Masking
# if True:
# mel = self.maskingFreq(mel)
# # 4.2 Time Masking
# if True:
# mel = self.maskingTime(mel)


# 5 Check the length and stretch shorter clips to 10 s; this regularizes the length of the data
if mel.shape[2] < self.c_length:
print("Audio too short: stretching it.")
replay_rate = mel.shape[2]/self.c_length
#print(f"replay rate {replay_rate}%")
mel = self.stretch(mel, replay_rate)
mel = mel[:,:,0:626]
#print(f"stretched shape {stretched.shape}")

return mel.float()

# %% ../nbs/02_dataset.ipynb 9
class BirdClef(Dataset):

def __init__(self, metadata=None, classes=None):

self.metadata = metadata
self.classes = classes

self.length = len(self.metadata)

binarizer = LabelBinarizer()
binarizer.fit(self.classes)

self.labels = binarizer.transform(metadata.primary_label)

# Initialize a pipeline
self.pipeline = MyPipeline()

def __len__(self):
return self.length

def __getitem__(self, idx):
filename = AUDIO_DATA_DIR + self.metadata['filename'][idx]
mel_spectrogram = self.pipeline(filename)

label = self.labels[idx]
label = torch.from_numpy(label).float()

return mel_spectrogram, label

# %% ../nbs/02_dataset.ipynb 13
dir = DATA_DIR
try:
train_metadata_base = pd.read_csv(dir + 'base/train_metadata.csv')
val_metadata_base = pd.read_csv(dir + 'base/val_metadata.csv')
test_metadata_base = pd.read_csv(dir + 'base/test_metadata.csv')
except FileNotFoundError:
dir = 'data/'
train_metadata_base = pd.read_csv(dir + 'base/train_metadata.csv')
val_metadata_base = pd.read_csv(dir + 'base/val_metadata.csv')
test_metadata_base = pd.read_csv(dir + 'base/test_metadata.csv')

simple_classes = ['thrnig1', 'wlwwar', 'barswa']
train_metadata_simple = train_metadata_base.loc[train_metadata_base.primary_label.isin(simple_classes)].reset_index()
val_metadata_simple = val_metadata_base.loc[val_metadata_base.primary_label.isin(simple_classes)].reset_index()
test_metadata_simple = test_metadata_base.loc[test_metadata_base.primary_label.isin(simple_classes)].reset_index()

# %% ../nbs/02_dataset.ipynb 14
dataset_dict = {
'train_base': (BirdClef, {'metadata': train_metadata_base, 'classes': train_metadata_base.primary_label}),
'val_base': (BirdClef, {'metadata': val_metadata_base, 'classes': train_metadata_base.primary_label}),
'test_base': (BirdClef, {'metadata': test_metadata_base, 'classes': train_metadata_base.primary_label}),

'train_simple': (BirdClef, {'metadata': train_metadata_simple, 'classes': train_metadata_simple.primary_label}),
'val_simple': (BirdClef, {'metadata': val_metadata_simple, 'classes': train_metadata_simple.primary_label}),
'test_simple': (BirdClef, {'metadata': test_metadata_simple, 'classes': train_metadata_simple.primary_label})
}

# %% ../nbs/02_dataset.ipynb 15
def get_dataset(dataset_key:str # A key of the dataset dictionary
)->Dataset: # Pytorch dataset
"A getter method to retrieve the wanted dataset."
assert dataset_key in dataset_dict, f'{dataset_key} is not an existing dataset, choose one from {dataset_dict.keys()}.'
ds_class, kwargs = dataset_dict[dataset_key]
return ds_class(**kwargs)

# %% ../nbs/02_dataset.ipynb 19
def get_dataloader(dataset_key:str, # The key to access the dataset
dataloader_kwargs:dict={} # The optional parameters for a pytorch dataloader
)->DataLoader: # Pytorch dataloader
"A function to get a dataloader from a specific dataset"
dataset = get_dataset(dataset_key)

return DataLoader(dataset, **dataloader_kwargs)
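
For reference, a minimal sketch of how these dataset helpers might be exercised. It is not part of the commit and assumes the BirdCLEF metadata CSVs and audio files referenced above are already available locally (under `DATA_DIR` or `data/`):

``` python
# Illustrative usage of the helpers added in birdclef/dataset.py.
# Assumes the metadata CSVs and audio expected by BirdClef exist locally.
from birdclef.dataset import get_dataloader

loader = get_dataloader('train_simple', {'batch_size': 4, 'shuffle': True})

mels, labels = next(iter(loader))
print(mels.shape, labels.shape)  # roughly torch.Size([4, 1, 128, 626]) and torch.Size([4, 3])
```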
9 changes: 9 additions & 0 deletions birdclef/experiment.py
@@ -0,0 +1,9 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_experiment.ipynb.

# %% auto 0
__all__ = []

# %% ../nbs/06_experiment.ipynb 4
import wandb

from .trainer import train
56 changes: 56 additions & 0 deletions birdclef/network.py
@@ -0,0 +1,56 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_network.ipynb.

# %% auto 0
__all__ = ['model_dict', 'EfficientNetV2', 'get_model']

# %% ../nbs/03_network.ipynb 3
from typing import Union, BinaryIO, IO
from os import PathLike

import torch
import torchvision
from torch.nn import Module

from .dataset import get_dataloader


# %% ../nbs/03_network.ipynb 5
class EfficientNetV2(torch.nn.Module):
def __init__(self, num_classes=264, size='s'):
super().__init__()

if size=='s':
self.efficientnet_v2 = torchvision.models.efficientnet_v2_s(weights=None, progress=True, num_classes=num_classes)
elif size=='m':
self.efficientnet_v2 = torchvision.models.efficientnet_v2_m(weights=None, progress=True, num_classes=num_classes)
else:
self.efficientnet_v2 = torchvision.models.efficientnet_v2_l(weights=None, progress=True, num_classes=num_classes)

self.init_conv = torch.nn.Conv2d(1, 3, (3,3), padding="same")
#self.sigmoid = torch.nn.functional.sigmoid

def forward(self, x):
x = self.init_conv(x)
x = self.efficientnet_v2(x)

return x

# %% ../nbs/03_network.ipynb 10
model_dict = {
'simple_efficient_net_v2_s': (EfficientNetV2, {'num_classes': 3}),
'efficient_net_v2_s': (EfficientNetV2, {})
}

def get_model(model_key:str, # A key of the model dictionary
weights_path:Union[str, PathLike, BinaryIO, IO[bytes]] = None # A file-like object to the model weights
)->Module: # A pytorch model
"A getter method to retrieve the wanted (possibly pretrained) model"
assert model_key in model_dict, f'{model_key} is not an existing network, choose one from {model_dict.keys()}.'

net_class, kwargs = model_dict[model_key]
model = net_class(**kwargs)

if weights_path is not None:
model.load_state_dict(torch.load(weights_path))

return model
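
Similarly, a hedged sketch of wiring the model getter to the dataloader defined earlier in this commit (illustrative only; with no `weights_path` the network starts from random initialization):

``` python
# Illustrative only: run one validation batch through the 3-class EfficientNetV2-S.
import torch
from birdclef.dataset import get_dataloader
from birdclef.network import get_model

model = get_model('simple_efficient_net_v2_s')   # no weights_path -> random init
model.eval()

mels, labels = next(iter(get_dataloader('val_simple', {'batch_size': 2})))
with torch.no_grad():
    logits = model(mels)
print(logits.shape)  # expected: torch.Size([2, 3])
```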
File renamed without changes.
