-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding spectral gate + transform (#54)
* Adding spectral gate v1. * Bump version. * Adding cache and other stuff. * Typo * Adding spectral gate + transform. * Running linter. * Removing warning. Co-authored-by: pseeth <[email protected]>
- Loading branch information
Showing
13 changed files
with
204 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from . import layers | ||
from . import tricks | ||
from .accelerator import Accelerator | ||
from .experiment import Experiment | ||
from .layers.base import BaseModel | ||
from .layers import BaseModel | ||
from .trainer import BaseTrainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .base import BaseModel | ||
from .spectral_gate import SpectralGate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn | ||
|
||
from audiotools import AudioSignal | ||
from audiotools import STFTParams | ||
from audiotools.core import util | ||
|
||
|
||
class SpectralGate(nn.Module): | ||
def __init__(self, n_freq: int = 3, n_time: int = 5): | ||
"""Spectral gating algorithm for noise reduction, | ||
as in Audacity/Ocenaudio. The steps are as follows: | ||
1. An FFT is calculated over the noise audio clip | ||
2. Statistics are calculated over FFT of the the noise | ||
(in frequency) | ||
3. A threshold is calculated based upon the statistics | ||
of the noise (and the desired sensitivity of the algorithm) | ||
4. An FFT is calculated over the signal | ||
5. A mask is determined by comparing the signal FFT to the | ||
threshold | ||
6. The mask is smoothed with a filter over frequency and time | ||
7. The mask is appled to the FFT of the signal, and is inverted | ||
Implementation inspired by Tim Sainburg's noisereduce: | ||
https://timsainburg.com/noise-reduction-python.html | ||
Parameters | ||
---------- | ||
model : wav2wav.modules.BaseModel | ||
The model to generate line noise from. | ||
n_freq : int, optional | ||
Number of frequency bins to smooth by, by default 3 | ||
n_time : int, optional | ||
Number of time bins to smooth by, by default 5 | ||
""" | ||
super().__init__() | ||
|
||
smoothing_filter = torch.outer( | ||
torch.cat( | ||
[ | ||
torch.linspace(0, 1, n_freq + 2)[:-1], | ||
torch.linspace(1, 0, n_freq + 2), | ||
] | ||
)[..., 1:-1], | ||
torch.cat( | ||
[ | ||
torch.linspace(0, 1, n_time + 2)[:-1], | ||
torch.linspace(1, 0, n_time + 2), | ||
] | ||
)[..., 1:-1], | ||
) | ||
smoothing_filter = smoothing_filter / smoothing_filter.sum() | ||
smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0) | ||
self.register_buffer("smoothing_filter", smoothing_filter) | ||
|
||
def forward( | ||
self, | ||
audio_signal: AudioSignal, | ||
nz_signal: AudioSignal, | ||
denoise_amount: float = 1.0, | ||
n_std: float = 3.0, | ||
win_length: int = 2048, | ||
hop_length: int = 512, | ||
): | ||
"""Perform noise reduction. | ||
Parameters | ||
---------- | ||
audio_signal : AudioSignal | ||
Audio signal that noise will be removed from. | ||
nz_signal : AudioSignal, optional | ||
Noise signal to compute noise statistics from. | ||
denoise_amount : float, optional | ||
Amount to denoise by, by default 1.0 | ||
n_std : float, optional | ||
Number of standard deviations above which to consider | ||
noise, by default 3.0 | ||
win_length : int, optional | ||
Length of window for STFT, by default 2048 | ||
hop_length : int, optional | ||
Hop length for STFT, by default 512 | ||
Returns | ||
------- | ||
AudioSignal | ||
Denoised audio signal. | ||
""" | ||
stft_params = STFTParams(win_length, hop_length, "sqrt_hann") | ||
|
||
audio_signal = audio_signal.clone() | ||
audio_signal.stft_data = None | ||
audio_signal.stft_params = stft_params | ||
|
||
nz_signal = nz_signal.clone() | ||
nz_signal.stft_params = stft_params | ||
|
||
nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10() | ||
nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1) | ||
nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1) | ||
|
||
nz_thresh = nz_freq_mean + nz_freq_std * n_std | ||
|
||
stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10() | ||
nb, nac, nf, nt = stft_db.shape | ||
db_thresh = nz_thresh.expand(nb, nac, -1, nt) | ||
|
||
stft_mask = (stft_db < db_thresh).float() | ||
shape = stft_mask.shape | ||
|
||
stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt) | ||
pad_tuple = ( | ||
self.smoothing_filter.shape[-2] // 2, | ||
self.smoothing_filter.shape[-1] // 2, | ||
) | ||
stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple) | ||
stft_mask = stft_mask.reshape(*shape) | ||
stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim) | ||
stft_mask = 1 - stft_mask | ||
|
||
audio_signal.stft_data *= stft_mask | ||
audio_signal.istft() | ||
|
||
return audio_signal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Git LFS file not shown