Skip to content

Commit

Permalink
Add IIRFilter transform for building and applying scipy filters (ML…
Browse files Browse the repository at this point in the history
…4GW#189)

* rough high/low pass filter implementation

* rename import

* internalize functions

* move filter to transforms and add docstrings

* move to filters.py

* create torch module

* add tests

* switch to scipy for filter coeff generation

* use union type for p3.9 compatibility

* add support for other filters

* hardcode output since torchaudio takes (b, a) as input

* export constants to be accessed as ml4gw.constants

* add phenom tests

* fix unassigned variable bug

* fix tolerance on phenom test for filters

* link scipy function

* tests for other filters (cheby1, cheby2, ellip, bessel)

* increase tolerance for phenom tests

* update Return documentation
  • Loading branch information
ravioli1369 authored Feb 4, 2025
1 parent b0548be commit 6f0514e
Show file tree
Hide file tree
Showing 3 changed files with 450 additions and 0 deletions.
1 change: 1 addition & 0 deletions ml4gw/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .constants import *
100 changes: 100 additions & 0 deletions ml4gw/transforms/iirfilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Union

import torch
from scipy.signal import iirfilter
from torchaudio.functional import filtfilt


class IIRFilter(torch.nn.Module):
r"""
IIR digital and analog filter design given order and critical points.
Design an Nth-order digital or analog filter and apply it to a signal.
Uses SciPy's `iirfilter` function to create the filter coefficients.
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html # noqa E501
The forward call of this module accepts a batch tensor of shape
(n_waveforms, n_samples) and returns the filtered waveforms.
Args:
N:
The order of the filter.
Wn:
A scalar or length-2 sequence giving the critical frequencies.
For digital filters, Wn are in the same units as fs. By
default, fs is 2 half-cycles/sample, so these are normalized
from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in
half-cycles / sample). For analog filters, Wn is an angular
frequency (e.g., rad/s). When Wn is a length-2 sequence,`Wn[0]`
must be less than `Wn[1]`.
rp:
For Chebyshev and elliptic filters, provides the maximum ripple in
the passband. (dB)
rs:
For Chebyshev and elliptic filters, provides the minimum
attenuation in the stop band. (dB)
btype:
The type of filter. Default is 'bandpass'.
analog:
When True, return an analog filter, otherwise a digital filter
is returned.
ftype:
The type of IIR filter to design:
- Butterworth : 'butter'
- Chebyshev I : 'cheby1'
- Chebyshev II : 'cheby2'
- Cauer/elliptic: 'ellip'
- Bessel/Thomson: 'bessel's
fs:
The sampling frequency of the digital system.
Returns:
Filtered signal on the forward pass.
"""

def __init__(
self,
N: int,
Wn: Union[float, torch.Tensor],
rs: Union[None, float, torch.Tensor] = None,
rp: Union[None, float, torch.Tensor] = None,
btype="band",
analog=False,
ftype="butter",
fs=None,
) -> None:
super().__init__()

if isinstance(Wn, torch.Tensor):
Wn = Wn.numpy()
if isinstance(rs, torch.Tensor):
rs = rs.numpy()
if isinstance(rp, torch.Tensor):
rp = rp.numpy()

b, a = iirfilter(
N,
Wn,
rs=rs,
rp=rp,
btype=btype,
analog=analog,
ftype=ftype,
output="ba",
fs=fs,
)
self.register_buffer("b", torch.tensor(b))
self.register_buffer("a", torch.tensor(a))

def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
Apply the filter to the input signal.
Args:
x:
The input signal to be filtered.
Returns:
The filtered signal.
"""
return filtfilt(x, self.a, self.b, clamp=False)
Loading

0 comments on commit 6f0514e

Please sign in to comment.