Skip to content

Commit

Permalink
Adding spectral gate + transform (#54)
Browse files Browse the repository at this point in the history
* Adding spectral gate v1.

* Bump version.

* Adding cache and other stuff.

* Typo

* Adding spectral gate + transform.

* Running linter.

* Removing warning.

Co-authored-by: pseeth <[email protected]>
  • Loading branch information
pseeth and pseeth authored Oct 12, 2022
1 parent 39cdf47 commit 0658f2d
Show file tree
Hide file tree
Showing 13 changed files with 204 additions and 29 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.3"
__version__ = "0.4.4"
from .core import AudioSignal, STFTParams, Meter, util
from . import metrics
from . import data
Expand Down
17 changes: 15 additions & 2 deletions audiotools/core/audio_signal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import functools
import hashlib
import math
import pathlib
Expand Down Expand Up @@ -392,6 +393,7 @@ def num_channels(self):

# STFT
@staticmethod
@functools.lru_cache(None)
def get_window(window_type, window_length, device):
"""
Wrapper around scipy.signal.get_window so one can also get the
Expand All @@ -402,7 +404,7 @@ def get_window(window_type, window_length, device):
window_length (int): Length of the window
Returns:
np.ndarray: Window returned by scipy.signa.get_window
np.ndarray: Window returned by scipy.signal.get_window
"""
if window_type == "average":
window = np.ones(window_length) / window_length
Expand Down Expand Up @@ -564,12 +566,23 @@ def istft(

return self

@staticmethod
@functools.lru_cache(maxsize=None)
def get_mel_filters(sr, n_fft, n_mels, fmin=0.0, fmax=None):
    """Return the mel filterbank produced by ``librosa_mel_fn``, memoized.

    The arguments are forwarded unchanged, by keyword, to
    ``librosa_mel_fn``. Results are kept in an unbounded ``lru_cache``
    keyed on the call arguments, so every argument must be hashable and
    repeated calls with the same arguments return the same cached object.

    Args:
        sr: Forwarded as ``sr``.
        n_fft: Forwarded as ``n_fft``.
        n_mels: Forwarded as ``n_mels``.
        fmin: Forwarded as ``fmin``. Defaults to 0.0.
        fmax: Forwarded as ``fmax``. Defaults to None.

    Returns:
        Whatever ``librosa_mel_fn`` returns for these arguments.
    """
    mel_filterbank = librosa_mel_fn(
        sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
    )
    return mel_filterbank

def mel_spectrogram(self, n_mels=80, mel_fmin=0.0, mel_fmax=None, **kwargs):
stft = self.stft(**kwargs)
magnitude = torch.abs(stft)

nf = magnitude.shape[2]
mel_basis = librosa_mel_fn(
mel_basis = self.get_mel_filters(
sr=self.sample_rate,
n_fft=2 * (nf - 1),
n_mels=n_mels,
Expand Down
18 changes: 10 additions & 8 deletions audiotools/core/dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ def _preprocess_signal_for_windowing(self, window_duration, hop_duration):

return window_length, hop_length

def windows(self, window_duration, hop_duration):
window_length, hop_length = self._preprocess_signal_for_windowing(
window_duration, hop_duration
)
def windows(self, window_duration, hop_duration, preprocess: bool = True):
if preprocess:
window_length, hop_length = self._preprocess_signal_for_windowing(
window_duration, hop_duration
)

self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)

Expand All @@ -45,7 +46,7 @@ def windows(self, window_duration, hop_duration):
break
yield self[b, ..., start_idx:end_idx]

def collect_windows(self, window_duration, hop_duration):
def collect_windows(self, window_duration, hop_duration, preprocess: bool = True):
"""Function which collects overlapping windows from
an AudioSignal.
Expand All @@ -58,9 +59,10 @@ def collect_windows(self, window_duration, hop_duration):
Returns:
AudioSignal: Signal of shape (nb * num_windows, nc, window_length).
"""
window_length, hop_length = self._preprocess_signal_for_windowing(
window_duration, hop_duration
)
if preprocess:
window_length, hop_length = self._preprocess_signal_for_windowing(
window_duration, hop_duration
)

# self.audio_data: (nb, nch, nt).
unfolded = torch.nn.functional.unfold(
Expand Down
32 changes: 32 additions & 0 deletions audiotools/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from numpy.random import RandomState
from yaml import load

from .. import ml
from ..core import AudioSignal
from ..core import util
from .datasets import AudioLoader
Expand Down Expand Up @@ -811,3 +812,34 @@ def _transform(self, signal, fmin_hz: float, fmax_hz: float):
signal.magnitude = mag
signal.phase = phase
return signal


class SpectralDenoising(Equalizer):
    """Transform that runs a signal through a ``SpectralGate`` using an
    equalized random-noise reference.

    Args:
        eq_amount: Forwarded to ``Equalizer.__init__``.
        denoise_amount: Distribution spec handed to
            ``util.sample_from_dist`` each time parameters are drawn.
        nz_volume: Value passed to ``normalize`` on the noise reference.
        n_bands: Forwarded to ``Equalizer.__init__``.
        n_freq: First positional argument of ``SpectralGate``.
        n_time: Second positional argument of ``SpectralGate``.
        name: Forwarded to ``Equalizer.__init__``.
        prob: Forwarded to ``Equalizer.__init__``.
    """

    def __init__(
        self,
        eq_amount: tuple = ("const", 1.0),
        denoise_amount: tuple = ("uniform", 0.8, 1.0),
        nz_volume: float = -40,
        n_bands: int = 6,
        n_freq: int = 3,
        n_time: int = 5,
        name: str = None,
        prob: float = 1,
    ):
        super().__init__(
            eq_amount=eq_amount,
            n_bands=n_bands,
            name=name,
            prob=prob,
        )

        self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time)
        self.denoise_amount = denoise_amount
        self.nz_volume = nz_volume

    def _transform(self, signal, nz, eq, denoise_amount):
        """Gate ``signal`` against the normalized, equalized noise ``nz``."""
        shaped_noise = nz.normalize(self.nz_volume).equalizer(eq)

        # Keep the gate on the same device as the signal being processed.
        gate = self.spectral_gate.to(signal.device)
        self.spectral_gate = gate

        return gate(signal, shaped_noise, denoise_amount)

    def _instantiate(self, state: RandomState):
        """Draw the parent's parameters, then a denoise amount and a noise clip.

        The draws from ``state`` happen in this order: parent parameters,
        ``denoise_amount``, then the noise samples.
        """
        params = super()._instantiate(state)

        sampled_amount = util.sample_from_dist(self.denoise_amount, state)
        params["denoise_amount"] = sampled_amount

        # NOTE(review): the noise reference is always 22050 Gaussian samples
        # built with 44100 as the second AudioSignal argument (presumably the
        # sample rate), regardless of the input signal -- confirm intended.
        noise_samples = state.randn(22050)
        params["nz"] = AudioSignal(noise_samples, 44100)
        return params
3 changes: 2 additions & 1 deletion audiotools/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import layers
from . import tricks
from .accelerator import Accelerator
from .experiment import Experiment
from .layers.base import BaseModel
from .layers import BaseModel
from .trainer import BaseTrainer
7 changes: 2 additions & 5 deletions audiotools/ml/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,8 @@ def prepare_model(self, model, **kwargs):
return model

# Automatic mixed-precision utilities
def autocast(self):
if self.amp:
return torch.cuda.amp.autocast()
else:
return contextlib.nullcontext()
def autocast(self, *args, **kwargs):
return torch.cuda.amp.autocast(self.amp, *args, **kwargs)

def backward(self, loss):
self.scaler.scale(loss).backward()
Expand Down
2 changes: 2 additions & 0 deletions audiotools/ml/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .base import BaseModel
from .spectral_gate import SpectralGate
126 changes: 126 additions & 0 deletions audiotools/ml/layers/spectral_gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import torch
import torch.nn.functional as F
from torch import nn

from audiotools import AudioSignal
from audiotools import STFTParams
from audiotools.core import util


class SpectralGate(nn.Module):
    def __init__(self, n_freq: int = 3, n_time: int = 5):
        """Spectral gating algorithm for noise reduction,
        as in Audacity/Ocenaudio. The steps are as follows:

        1. An FFT is calculated over the noise audio clip
        2. Statistics are calculated over the FFT of the noise
           (in frequency)
        3. A threshold is calculated based upon the statistics
           of the noise (and the desired sensitivity of the algorithm)
        4. An FFT is calculated over the signal
        5. A mask is determined by comparing the signal FFT to the
           threshold
        6. The mask is smoothed with a filter over frequency and time
        7. The mask is applied to the FFT of the signal, and is inverted

        Implementation inspired by Tim Sainburg's noisereduce:
        https://timsainburg.com/noise-reduction-python.html

        Parameters
        ----------
        n_freq : int, optional
            Number of frequency bins to smooth by, by default 3
        n_time : int, optional
            Number of time bins to smooth by, by default 5
        """
        super().__init__()

        # Each axis gets a symmetric triangular ramp of length 2 * n + 1
        # (rising then falling, with the zero endpoints trimmed off). The
        # outer product of the two ramps is the 2-D smoothing kernel.
        smoothing_filter = torch.outer(
            torch.cat(
                [
                    torch.linspace(0, 1, n_freq + 2)[:-1],
                    torch.linspace(1, 0, n_freq + 2),
                ]
            )[..., 1:-1],
            torch.cat(
                [
                    torch.linspace(0, 1, n_time + 2)[:-1],
                    torch.linspace(1, 0, n_time + 2),
                ]
            )[..., 1:-1],
        )
        # Normalize to unit sum so smoothing keeps mask values within [0, 1].
        smoothing_filter = smoothing_filter / smoothing_filter.sum()
        # conv2d expects a weight of shape (out_channels, in_channels, kH, kW).
        smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
        self.register_buffer("smoothing_filter", smoothing_filter)

    def forward(
        self,
        audio_signal: AudioSignal,
        nz_signal: AudioSignal,
        denoise_amount: float = 1.0,
        n_std: float = 3.0,
        win_length: int = 2048,
        hop_length: int = 512,
    ):
        """Perform noise reduction.

        Both inputs are cloned before use; the returned signal is the
        processed clone of ``audio_signal``.

        Parameters
        ----------
        audio_signal : AudioSignal
            Audio signal that noise will be removed from.
        nz_signal : AudioSignal
            Noise signal to compute noise statistics from.
        denoise_amount : float, optional
            Amount to denoise by, by default 1.0. The smoothed mask is
            scaled by this value before it is applied.
        n_std : float, optional
            Number of standard deviations added to the per-frequency
            noise mean (in dB) to form the threshold; signal bins below
            the threshold are treated as noise. By default 3.0
        win_length : int, optional
            Length of window for STFT, by default 2048
        hop_length : int, optional
            Hop length for STFT, by default 512

        Returns
        -------
        AudioSignal
            Denoised audio signal.
        """
        stft_params = STFTParams(win_length, hop_length, "sqrt_hann")

        audio_signal = audio_signal.clone()
        audio_signal.stft_data = None
        audio_signal.stft_params = stft_params

        nz_signal = nz_signal.clone()
        nz_signal.stft_params = stft_params

        # Noise statistics in dB, reduced over the last (time) axis. The
        # clamp floors magnitudes at 1e-4, i.e. -80 dB, so log10 stays finite.
        nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
        nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
        nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)

        nz_thresh = nz_freq_mean + nz_freq_std * n_std

        stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
        nb, nac, nf, nt = stft_db.shape
        # expand() only broadcasts size-1 dims, so the noise batch/channel
        # dims must be 1 or already match those of the signal.
        db_thresh = nz_thresh.expand(nb, nac, -1, nt)

        # 1.0 where the signal is below the noise threshold, else 0.0.
        stft_mask = (stft_db < db_thresh).float()

        # Fold batch and channel together so one single-channel conv2d
        # smooths every spectrogram; the kernel is odd-sized on both axes,
        # so half-size padding preserves the (nf, nt) shape.
        stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
        pad_tuple = (
            self.smoothing_filter.shape[-2] // 2,
            self.smoothing_filter.shape[-1] // 2,
        )
        stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
        stft_mask = stft_mask.reshape(nb, nac, nf, nt)

        # A float or CPU tensor amount would otherwise clash with a mask
        # living on an accelerator; .to() is a no-op when devices match.
        amount = util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim)
        stft_mask *= amount.to(stft_mask.device)
        # Invert: the mask now holds the gain to keep for each bin.
        stft_mask = 1 - stft_mask

        audio_signal.stft_data *= stft_mask
        audio_signal.istft()

        return audio_signal
9 changes: 4 additions & 5 deletions audiotools/ml/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
rank=0,
quiet: bool = False,
record_memory: bool = False,
log_file: str = "log.txt",
**kwargs,
):
"""
Expand Down Expand Up @@ -123,7 +124,7 @@ def __init__(
>>> if self.is_best(engine, 'mse/val'):
>>> model.save('checkpoints/best.model.pth')
>>> if self.top_k(engine, 'mse/val', 5):
>>> model.save(f'checkpoints/top{k}.{epoch}.model.pth')
>>> model.save(f'checkpoints/top5.{epoch}.model.pth')
>>>
>>> trainer = Trainer(writer=tb)
>>> trainer.run(train_data, val_data, num_epochs=3)
Expand Down Expand Up @@ -161,7 +162,7 @@ def __init__(
)
self.live = self.pbar
self.epoch_summary = None
self.log_file = None
self.log_file = log_file

# Set up trainer engine
self.trainer = ignite.engine.Engine(self._train_loop)
Expand All @@ -187,8 +188,6 @@ def __init__(
)

for k, v in kwargs.items():
if hasattr(self, k):
raise ValueError(f"{k} set in kwargs but overwrites self.{k}.")
setattr(self, k, v)

f = lambda: {
Expand Down Expand Up @@ -226,7 +225,7 @@ def __init__(
)

self.stdout_console = Console()
self.file_console = Console(file=open("log.txt", "w"))
self.file_console = Console(file=open(self.log_file, "w"))

@property
def state(self):
Expand Down
10 changes: 6 additions & 4 deletions audiotools/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,22 @@ def upload_figure_to_discourse(
return formatted, info


def audio_table(audio_dict, first_column=None, format_fn=None): # pragma: no cover
def audio_table(
audio_dict, first_column=None, format_fn=None, **kwargs
): # pragma: no cover
from audiotools import AudioSignal

output = []
columns = None

def _default_format_fn(label, x):
def _default_format_fn(label, x, **kwargs):
if torch.is_tensor(x):
x = x.tolist()

if x is None:
return "."
elif isinstance(x, AudioSignal):
return x.embed(display=False, return_html=True)
return x.embed(display=False, return_html=True, **kwargs)
else:
return str(x)

Expand All @@ -135,7 +137,7 @@ def _default_format_fn(label, x):

formatted_audio = []
for col in columns[1:]:
formatted_audio.append(format_fn(col, v[col]))
formatted_audio.append(format_fn(col, v[col], **kwargs))

row = f"| {k} | "
row += " | ".join(formatted_audio)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.4.3",
version="0.4.4",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
Expand Down
2 changes: 0 additions & 2 deletions tests/ml/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def checkpoint(self, engine):

assert trainer.state == trainer.trainer.state

pytest.raises(ValueError, Trainer, log_file="something")

trainer = Trainer(quiet=True)
trainer.run(train_data, val_data, num_epochs=5)

Expand Down
3 changes: 3 additions & 0 deletions tests/regression/transforms/SpectralDenoising.wav
Git LFS file not shown

0 comments on commit 0658f2d

Please sign in to comment.