Skip to content

Commit

Permalink
Merge pull request #92 from EthanMarx/augmentations
Browse files Browse the repository at this point in the history
Add `SignalInverter` and `SignalReverser` Augmentations
  • Loading branch information
EthanMarx authored Jan 26, 2024
2 parents 0eba829 + 3fae734 commit a8e64bf
Show file tree
Hide file tree
Showing 3 changed files with 1,419 additions and 1,102 deletions.
43 changes: 43 additions & 0 deletions ml4gw/augmentations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch


class SignalInverter(torch.nn.Module):
"""
Takes a tensor of timeseries of arbitrary dimension
and randomly inverts (i.e. h(t) -> -h(t))
each timeseries with probability `prob`.
Args:
prob:
Probability that a timeseries is inverted
"""

def __init__(self, prob: float = 0.5):
super().__init__()
self.prob = prob

def forward(self, X):
mask = torch.rand(size=X.shape[:-1]) < self.prob
X[mask] *= -1
return X


class SignalReverser(torch.nn.Module):
"""
Takes a tensor of timeseries of arbitrary dimension
and randomly reverses (i.e. h(t) -> h(-t))
each timeseries with probability `prob`.
Args:
prob:
Probability that a kernel is reversed
"""

def __init__(self, prob: float = 0.5):
super().__init__()
self.prob = prob

def forward(self, X):
mask = torch.rand(size=X.shape[:-1]) < self.prob
X[mask] = X[mask].flip(-1)
return X
Loading

0 comments on commit a8e64bf

Please sign in to comment.