Merge dev into main for 0.4.2 release #123

Merged: 30 commits from dev into main, May 13, 2024

Commits
7738d5a  re-implement distributions using torch.distributions (#110) (deepchatterjeeligo, Feb 21, 2024)
cd22197  add delta function distribution, minor fixes to other distributions (deepchatterjeeligo, Feb 22, 2024)
e98fc42  fix (deepchatterjeeligo, Feb 22, 2024)
ea03986  Merge pull request #111 from deepchatterjeeligo/delta-dist (wbenoit26, Feb 22, 2024)
9ca21cb  Initial commit of q-transform (Mar 8, 2024)
10837ba  Added a MultiQTransform (Mar 9, 2024)
6c9a89f  Reparameterized to number of t and f bins (Mar 9, 2024)
05315a1  Re-factored SingleQTransform to allow for eventually using different … (Mar 10, 2024)
be56a29  Changed interpolation to better match gwpy and added option to specif… (Mar 10, 2024)
bf7592e  Added QScan to __init__ (Mar 10, 2024)
a763cb0  Changed from torch median to quantile to match numpy median (Mar 13, 2024)
dd5a719  Added documentation to qtransform (wbenoit26, Mar 20, 2024)
ac3dfab  Added gwpy to dev dependencies (wbenoit26, Mar 20, 2024)
310e0fc  Added tests and corrected get_freqs bug (Mar 22, 2024)
45b90c7  Merge pull request #113 from wbenoit26/qtransform (wbenoit26, Apr 16, 2024)
70c1692  Changed method of normalization and updated documentation (Apr 22, 2024)
9a6b2ea  Updated more documentation and changed interpolation method (Apr 22, 2024)
ddc0be1  Changed how interpolation shape is parameterized (Apr 23, 2024)
6b7f390  Reverted documentation to state expectation of 3D input (Apr 23, 2024)
14e0314  Merge pull request #116 from wbenoit26/generalize_q_dims (wbenoit26, Apr 23, 2024)
d16712d  Fixed type hint for spectrogram_shape in qtransform (Apr 24, 2024)
01c1b01  Switched tuple to Tuple (Apr 24, 2024)
05b38cd  Merge pull request #118 from wbenoit26/fix-type-hints (wbenoit26, Apr 24, 2024)
14639ec  `InMemoryDataset` improvements (#119) (EthanMarx, May 9, 2024)
b33fb12  handle deprecated transpose ops in phenomd (#120) (deepchatterjeeligo, May 9, 2024)
8c6a706  Torch dependency fix into dev (#124) (EthanMarx, May 13, 2024)
f697588  update poetry lock (EthanMarx, May 13, 2024)
b0404c7  fix poetry conflic (EthanMarx, May 13, 2024)
101678c  fix poety conflict (EthanMarx, May 13, 2024)
df1ebdc  Poetry fix (#126) (EthanMarx, May 13, 2024)
50 changes: 18 additions & 32 deletions ml4gw/dataloading/in_memory_dataset.py
@@ -7,7 +7,7 @@
from ml4gw.utils.slicing import slice_kernels


class InMemoryDataset:
class InMemoryDataset(torch.utils.data.IterableDataset):
"""Dataset for iterating through in-memory multi-channel timeseries

Dataset for arrays of timeseries data which can be stored
@@ -131,7 +131,6 @@ def __init__(
self.batches_per_epoch = batches_per_epoch
self.shuffle = shuffle
self.coincident = coincident
self._i = self._idx = None

@property
def num_kernels(self) -> int:
@@ -157,7 +156,7 @@ def __len__(self) -> int:
num_kernels = self.num_kernels ** len(self.X)
return (num_kernels - 1) // self.batch_size + 1

def __iter__(self):
def init_indices(self):
"""
Initialize arrays of indices we'll use to slice
through X and y at iteration time. This helps by
@@ -204,36 +203,23 @@ def __iter__(self):
# the simplest case: deterministic and coincident
idx = torch.arange(num_kernels, device=device)

self._idx = idx
self._i = 0
return self
return idx

def __next__(
def __iter__(
self,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self._i is None or self._idx is None:
raise TypeError(
"Must initialize InMemoryDataset iteration "
"before calling __next__"
)

# check if we're out of batches, and if so
# make sure to reset before stopping
if self._i >= len(self):
self._i = self._idx = None
raise StopIteration

# slice the array of _indices_ we'll be using to
# slice our timeseries, and scale them by the stride
slc = slice(self._i * self.batch_size, (self._i + 1) * self.batch_size)
idx = self._idx[slc] * self.stride

# slice our timeseries
X = slice_kernels(self.X, idx, self.kernel_size)
if self.y is not None:
y = slice_kernels(self.y, idx, self.kernel_size)

self._i += 1
if self.y is not None:
return X, y
return X
indices = self.init_indices()
for i in range(len(self)):
# slice the array of _indices_ we'll be using to
# slice our timeseries, and scale them by the stride
slc = slice(i * self.batch_size, (i + 1) * self.batch_size)
idx = indices[slc] * self.stride

# slice our timeseries
X = slice_kernels(self.X, idx, self.kernel_size)
if self.y is not None:
y = slice_kernels(self.y, idx, self.kernel_size)
yield X, y
else:
yield X
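
With InMemoryDataset now subclassing torch.utils.data.IterableDataset and implementing __iter__ as a generator, iteration no longer depends on the internal _i/_idx state that the old __next__ managed. A minimal usage sketch follows; the keyword arguments are inferred from the attributes visible in this diff (kernel_size, batch_size, stride, shuffle, coincident) and are assumptions about the constructor, so the actual signature and expected input types should be checked against the class itself.

import torch

from ml4gw.dataloading.in_memory_dataset import InMemoryDataset

# Hypothetical input: 2 channels of 128 s of data sampled at 2048 Hz
X = torch.randn(2, 128 * 2048)

# The keyword names below mirror attributes referenced in the diff and are
# assumptions about the constructor, not a verbatim copy of its signature
dataset = InMemoryDataset(
    X,
    kernel_size=2048,  # 1 s kernels
    batch_size=32,
    stride=1,
    shuffle=True,
    coincident=True,
)

# As an IterableDataset whose __iter__ yields batches directly, the dataset
# can be consumed with a plain for-loop; no manual __next__ bookkeeping or
# re-initialization between epochs is needed
for batch in dataset:
    # batch is a tensor of sliced kernels, or an (X, y) tuple
    # when a target timeseries was provided
    pass
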
190 changes: 117 additions & 73 deletions ml4gw/distributions.py
@@ -5,93 +5,109 @@
from the corresponding distribution.
"""

import math
from typing import Optional

import torch
import torch.distributions as dist


class Uniform:
class Cosine(dist.Distribution):
"""
Sample uniformly between `low` and `high`.
Cosine distribution based on
``torch.distributions.TransformedDistribution``.
"""

def __init__(self, low: float = 0, high: float = 1) -> None:
self.low = low
self.high = high

def __call__(self, N: int) -> torch.Tensor:
return self.low + torch.rand(size=(N,)) * (self.high - self.low)


class Cosine:
"""
Sample from a raised Cosine distribution between
`low` and `high`. Based on the implementation from
bilby documented here:
https://lscsoft.docs.ligo.org/bilby/api/bilby.core.prior.analytical.Cosine.html # noqa
"""
arg_constraints = {}

def __init__(
self, low: float = -math.pi / 2, high: float = math.pi / 2
) -> None:
self,
low: float = torch.as_tensor(-torch.pi / 2),
high: float = torch.as_tensor(torch.pi / 2),
validate_args=None,
):
batch_shape = torch.Size()
super().__init__(batch_shape, validate_args=validate_args)
self.low = low
self.norm = 1 / (math.sin(high) - math.sin(low))
self.norm = 1 / (torch.sin(high) - torch.sin(low))

def __call__(self, N: int) -> torch.Tensor:
"""
Implementation lifted from
https://lscsoft.docs.ligo.org/bilby/_modules/bilby/core/prior/analytical.html#Cosine # noqa
"""
u = torch.rand(size=(N,))
return torch.arcsin(u / self.norm + math.sin(self.low))
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
u = torch.rand(sample_shape, device=self.low.device)
return torch.arcsin(u / self.norm + torch.sin(self.low))

def log_prob(self, value):
value = torch.as_tensor(value)
inside_range = (value >= self.low) & (value <= self.high)
return value.cos().log() * inside_range

class LogNormal:

class Sine(dist.TransformedDistribution):
"""
Sample from a log normal distribution with the
specified `mean` and standard deviation `std`.
If a `low` value is specified, values sampled
lower than this will be clipped to `low`.
Sine distribution based on
``torch.distributions.TransformedDistribution``.
"""

def __init__(
self, mean: float, std: float, low: Optional[float] = None
) -> None:
self.sigma = math.log((std / mean) ** 2 + 1) ** 0.5
self.mu = 2 * math.log(mean / (mean**2 + std**2) ** 0.25)
self.low = low

def __call__(self, N: int) -> torch.Tensor:

u = self.mu + torch.randn(N) * self.sigma
x = torch.exp(u)

if self.low is not None:
x = torch.clip(x, self.low)
return x


class LogUniform(Uniform):
self,
low: float = torch.as_tensor(0),
high: float = torch.as_tensor(torch.pi),
validate_args=None,
):
base_dist = Cosine(
low - torch.pi / 2, high - torch.pi / 2, validate_args
)
super().__init__(
base_dist,
[
dist.AffineTransform(
loc=torch.pi / 2,
scale=1,
)
],
validate_args=validate_args,
)


class LogUniform(dist.TransformedDistribution):
"""
Sample from a log uniform distribution
"""

def __init__(self, low: float, high: float) -> None:
super().__init__(math.log(low), math.log(high))
def __init__(self, low: float, high: float, validate_args=None):
base_dist = dist.Uniform(
torch.as_tensor(low).log(),
torch.as_tensor(high).log(),
validate_args,
)
super().__init__(
base_dist,
[dist.ExpTransform()],
validate_args=validate_args,
)

def __call__(self, N: int) -> torch.Tensor:
u = super().__call__(N)
return torch.exp(u)

class LogNormal(dist.LogNormal):
def __init__(
self,
mean: float,
std: float,
low: Optional[float] = None,
validate_args=None,
):
self.low = low
super().__init__(loc=mean, scale=std, validate_args=validate_args)

class PowerLaw:
def support(self):
if self.low is not None:
return dist.constraints.greater_than(self.low)


class PowerLaw(dist.TransformedDistribution):
"""
Sample from a power law distribution,
.. math::
p(x) \approx x^{-\alpha}.
p(x) \approx x^{\alpha}.

Index alpha must be greater than 1.
Index alpha cannot be 0, since it is equivalent to a Uniform distribution.
This could be used, for example, as a universal distribution of
signal-to-noise ratios (SNRs) from uniformly volume distributed
sources
@@ -102,21 +118,49 @@ class PowerLaw:
where :math:`\rho_0` is a representative minimum SNR
considered for detection. See, for example,
`Schutz (2011) <https://arxiv.org/abs/1102.5421>`_.
Or, for example, ``index=2`` for uniform in Euclidean volume.
"""

support = dist.constraints.nonnegative

def __init__(
self, minimum: float, maximum: float, index: int, validate_args=None
):
if index == 0:
raise RuntimeError("Index of 0 is the same as Uniform")
elif index == -1:
base_min = torch.as_tensor(minimum).log()
base_max = torch.as_tensor(maximum).log()
transforms = [dist.ExpTransform()]
else:
index_plus = index + 1
base_min = minimum**index_plus / index_plus
base_max = maximum**index_plus / index_plus
transforms = [
dist.AffineTransform(loc=0, scale=index_plus),
dist.PowerTransform(1 / index_plus),
]
base_dist = dist.Uniform(base_min, base_max, validate_args=False)
super().__init__(
base_dist,
transforms,
validate_args=validate_args,
)


class DeltaFunction(dist.Distribution):
arg_constraints = {}

def __init__(
self, x_min: float, x_max: float = float("inf"), alpha: float = 2
) -> None:
self.x_min = x_min
self.x_max = x_max
self.alpha = alpha

self.normalization = x_min ** (-self.alpha + 1)
self.normalization -= x_max ** (-self.alpha + 1)

def __call__(self, N: int) -> torch.Tensor:
u = torch.rand(N)
u *= self.normalization
samples = self.x_min ** (-self.alpha + 1) - u
samples = torch.pow(samples, -1.0 / (self.alpha - 1))
return samples
self,
peak: float = torch.as_tensor(0.0),
validate_args=None,
):
batch_shape = torch.Size()
super().__init__(batch_shape, validate_args=validate_args)
self.peak = peak

def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
return self.peak * torch.ones(
sample_shape, device=self.peak.device, dtype=torch.float32
)
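
Because the distributions now subclass torch.distributions.Distribution or TransformedDistribution, sampling goes through the standard sample()/rsample() interface instead of the old __call__(N) convention. Below is a short sketch of the new usage, based only on the constructors shown in this diff; the index=2 power law follows the uniform-in-Euclidean-volume example from the docstring, and the specific bounds are illustrative.

import torch

from ml4gw.distributions import (
    Cosine,
    DeltaFunction,
    LogUniform,
    PowerLaw,
    Sine,
)

n = 1024

# Raised-cosine distribution on [-pi/2, pi/2] and sine distribution on
# [0, pi], e.g. for declination- and inclination-like angles
dec = Cosine().sample((n,))
inc = Sine().sample((n,))

# Log-uniform samples between two positive bounds
freq = LogUniform(low=20.0, high=500.0).sample((n,))

# Power law p(x) ~ x**index on [minimum, maximum]; index=2 is the
# uniform-in-Euclidean-volume case mentioned in the docstring
dist_samples = PowerLaw(minimum=10.0, maximum=1000.0, index=2).sample((n,))

# Delta function: every sample equals the peak (passed as a tensor so its
# device can be read off it, per the rsample implementation above)
fixed = DeltaFunction(peak=torch.as_tensor(30.0)).sample((n,))

# log-probabilities follow from the base distribution and its transforms
log_p = LogUniform(20.0, 500.0).log_prob(freq)
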
1 change: 1 addition & 0 deletions ml4gw/transforms/__init__.py
@@ -1,4 +1,5 @@
from .pearson import ShiftedPearsonCorrelation
from .qtransform import QScan, SingleQTransform
from .scaler import ChannelWiseScaler
from .snr_rescaler import SnrRescaler
from .spectral import SpectralDensity