forked from ML4GW/ml4gw
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
IIRFilter
transform for building and applying scipy filters (ML…
…4GW#189) * rough high/low pass filter implementation * rename import * internalize functions * move filter to transforms and add docstrings * move to filters.py * create torch module * add tests * switch to scipy for filter coeff generation * use union type for p3.9 compatibility * add support for other filters * hardcode output since torchaudio takes (b, a) as input * export constants to be accessed as ml4gw.constants * add phenom tests * fix unassigned variable bug * fix tolerance on phenom test for filters * link scipy function * tests for other filters (cheby1, cheby2, ellip, bessel) * increase tolerance for phenom tests * update Return documentation
- Loading branch information
1 parent
b0548be
commit 6f0514e
Showing
3 changed files
with
450 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .constants import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from typing import Union | ||
|
||
import torch | ||
from scipy.signal import iirfilter | ||
from torchaudio.functional import filtfilt | ||
|
||
|
||
class IIRFilter(torch.nn.Module): | ||
r""" | ||
IIR digital and analog filter design given order and critical points. | ||
Design an Nth-order digital or analog filter and apply it to a signal. | ||
Uses SciPy's `iirfilter` function to create the filter coefficients. | ||
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html # noqa E501 | ||
The forward call of this module accepts a batch tensor of shape | ||
(n_waveforms, n_samples) and returns the filtered waveforms. | ||
Args: | ||
N: | ||
The order of the filter. | ||
Wn: | ||
A scalar or length-2 sequence giving the critical frequencies. | ||
For digital filters, Wn are in the same units as fs. By | ||
default, fs is 2 half-cycles/sample, so these are normalized | ||
from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in | ||
half-cycles / sample). For analog filters, Wn is an angular | ||
frequency (e.g., rad/s). When Wn is a length-2 sequence,`Wn[0]` | ||
must be less than `Wn[1]`. | ||
rp: | ||
For Chebyshev and elliptic filters, provides the maximum ripple in | ||
the passband. (dB) | ||
rs: | ||
For Chebyshev and elliptic filters, provides the minimum | ||
attenuation in the stop band. (dB) | ||
btype: | ||
The type of filter. Default is 'bandpass'. | ||
analog: | ||
When True, return an analog filter, otherwise a digital filter | ||
is returned. | ||
ftype: | ||
The type of IIR filter to design: | ||
- Butterworth : 'butter' | ||
- Chebyshev I : 'cheby1' | ||
- Chebyshev II : 'cheby2' | ||
- Cauer/elliptic: 'ellip' | ||
- Bessel/Thomson: 'bessel's | ||
fs: | ||
The sampling frequency of the digital system. | ||
Returns: | ||
Filtered signal on the forward pass. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
N: int, | ||
Wn: Union[float, torch.Tensor], | ||
rs: Union[None, float, torch.Tensor] = None, | ||
rp: Union[None, float, torch.Tensor] = None, | ||
btype="band", | ||
analog=False, | ||
ftype="butter", | ||
fs=None, | ||
) -> None: | ||
super().__init__() | ||
|
||
if isinstance(Wn, torch.Tensor): | ||
Wn = Wn.numpy() | ||
if isinstance(rs, torch.Tensor): | ||
rs = rs.numpy() | ||
if isinstance(rp, torch.Tensor): | ||
rp = rp.numpy() | ||
|
||
b, a = iirfilter( | ||
N, | ||
Wn, | ||
rs=rs, | ||
rp=rp, | ||
btype=btype, | ||
analog=analog, | ||
ftype=ftype, | ||
output="ba", | ||
fs=fs, | ||
) | ||
self.register_buffer("b", torch.tensor(b)) | ||
self.register_buffer("a", torch.tensor(a)) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
r""" | ||
Apply the filter to the input signal. | ||
Args: | ||
x: | ||
The input signal to be filtered. | ||
Returns: | ||
The filtered signal. | ||
""" | ||
return filtfilt(x, self.a, self.b, clamp=False) |
Oops, something went wrong.