Merge dev into main for 0.4.2 release (#123)
* re-implement distributions using torch.distributions (#110)

* loosely re-implement distributions using torch.distributions

* further enhancement

* rename

* log_uniform as transformed dist

* delete uniform and log_normal since they are already implemented

* bring back LogNormal, change parameter sampler

* add delta function distribution, minor fixes to other distributions

* fix

* Initial commit of q-transform

* Added a MultiQTransform

* Reparameterized to number of t and f bins

* Re-factored SingleQTransform to allow for eventually using different q's for batch/channels

* Changed interpolation to better match gwpy and added option to specify frequency range in which to search for max energy tile

* Added QScan to __init__

* Changed from torch median to quantile to match numpy median

* Added documentation to qtransform

* Added gwpy to dev dependencies

* Added tests and corrected get_freqs bug

* Changed method of normalization and updated documentation

* Updated more documentation and changed interpolation method

* Changed how interpolation shape is parameterized

* Reverted documentation to state expectation of 3D input

* Fixed type hint for spectrogram_shape in qtransform

* Switched tuple to Tuple

* `InMemoryDataset` improvements (#119)

* in memory dataset inherits from torch iterable dataset

* pre-commit issues

* fix yielding logic

* fix in memory dataset tests

* handle deprecated transpose ops in phenomd (#120)

* Torch dependency fix into dev (#124)

* update poetry lock

* increment version once more

* re-add gwpy dev dep

* update poetry lock

* fix poetry conflict

* fix poetry conflict

* Poetry fix (#126)

* remove gwpy dep for now

* poetry lock file

---------

Co-authored-by: Deep Chatterjee <[email protected]>
Co-authored-by: wbenoit26 <[email protected]>
Co-authored-by: William Benoit <[email protected]>
Co-authored-by: William Benoit <[email protected]>
5 people authored May 13, 2024
1 parent d9aeca9 commit d7027f9
Showing 11 changed files with 925 additions and 279 deletions.
50 changes: 18 additions & 32 deletions ml4gw/dataloading/in_memory_dataset.py
@@ -7,7 +7,7 @@
from ml4gw.utils.slicing import slice_kernels


class InMemoryDataset:
class InMemoryDataset(torch.utils.data.IterableDataset):
"""Dataset for iterating through in-memory multi-channel timeseries
Dataset for arrays of timeseries data which can be stored
Expand Down Expand Up @@ -131,7 +131,6 @@ def __init__(
self.batches_per_epoch = batches_per_epoch
self.shuffle = shuffle
self.coincident = coincident
self._i = self._idx = None

@property
def num_kernels(self) -> int:
@@ -157,7 +156,7 @@ def __len__(self) -> int:
num_kernels = self.num_kernels ** len(self.X)
return (num_kernels - 1) // self.batch_size + 1

def __iter__(self):
def init_indices(self):
"""
Initialize arrays of indices we'll use to slice
through X and y at iteration time. This helps by
@@ -204,36 +203,23 @@ def __iter__(self):
# the simplest case: deterministic and coincident
idx = torch.arange(num_kernels, device=device)

self._idx = idx
self._i = 0
return self
return idx

def __next__(
def __iter__(
self,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self._i is None or self._idx is None:
raise TypeError(
"Must initialize InMemoryDataset iteration "
"before calling __next__"
)

# check if we're out of batches, and if so
# make sure to reset before stopping
if self._i >= len(self):
self._i = self._idx = None
raise StopIteration

# slice the array of _indices_ we'll be using to
# slice our timeseries, and scale them by the stride
slc = slice(self._i * self.batch_size, (self._i + 1) * self.batch_size)
idx = self._idx[slc] * self.stride

# slice our timeseries
X = slice_kernels(self.X, idx, self.kernel_size)
if self.y is not None:
y = slice_kernels(self.y, idx, self.kernel_size)

self._i += 1
if self.y is not None:
return X, y
return X
indices = self.init_indices()
for i in range(len(self)):
# slice the array of _indices_ we'll be using to
# slice our timeseries, and scale them by the stride
slc = slice(i * self.batch_size, (i + 1) * self.batch_size)
idx = indices[slc] * self.stride

# slice our timeseries
X = slice_kernels(self.X, idx, self.kernel_size)
if self.y is not None:
y = slice_kernels(self.y, idx, self.kernel_size)
yield X, y
else:
yield X
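
For context, a minimal sketch of consuming the refactored dataset, which now subclasses torch.utils.data.IterableDataset and yields batches directly from __iter__. The constructor argument names below are assumptions inferred from the attributes set in __init__ above, not the full documented signature.

import torch

from ml4gw.dataloading.in_memory_dataset import InMemoryDataset

# Two channels of hypothetical in-memory timeseries data
X = torch.randn(2, 65536)

# kernel_size, stride, and batch_size names are assumed from the
# attributes used in __iter__ above
dataset = InMemoryDataset(
    X,
    kernel_size=2048,
    stride=256,
    batch_size=32,
    batches_per_epoch=10,
    shuffle=True,
)

# As an IterableDataset it can be looped over directly (or handed to a
# DataLoader with batch_size=None); there is no manual __iter__/__next__
# bookkeeping anymore
for X_batch in dataset:
    pass
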
190 changes: 117 additions & 73 deletions ml4gw/distributions.py
@@ -5,93 +5,109 @@
from the corresponding distribution.
"""

import math
from typing import Optional

import torch
import torch.distributions as dist


class Uniform:
class Cosine(dist.Distribution):
"""
Sample uniformly between `low` and `high`.
Cosine distribution based on
``torch.distributions.Distribution``.
"""

def __init__(self, low: float = 0, high: float = 1) -> None:
self.low = low
self.high = high

def __call__(self, N: int) -> torch.Tensor:
return self.low + torch.rand(size=(N,)) * (self.high - self.low)


class Cosine:
"""
Sample from a raised Cosine distribution between
`low` and `high`. Based on the implementation from
bilby documented here:
https://lscsoft.docs.ligo.org/bilby/api/bilby.core.prior.analytical.Cosine.html # noqa
"""
arg_constraints = {}

def __init__(
self, low: float = -math.pi / 2, high: float = math.pi / 2
) -> None:
self,
low: float = torch.as_tensor(-torch.pi / 2),
high: float = torch.as_tensor(torch.pi / 2),
validate_args=None,
):
batch_shape = torch.Size()
super().__init__(batch_shape, validate_args=validate_args)
self.low = low
self.high = high
self.norm = 1 / (math.sin(high) - math.sin(low))
self.norm = 1 / (torch.sin(high) - torch.sin(low))

def __call__(self, N: int) -> torch.Tensor:
"""
Implementation lifted from
https://lscsoft.docs.ligo.org/bilby/_modules/bilby/core/prior/analytical.html#Cosine # noqa
"""
u = torch.rand(size=(N,))
return torch.arcsin(u / self.norm + math.sin(self.low))
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
u = torch.rand(sample_shape, device=self.low.device)
return torch.arcsin(u / self.norm + torch.sin(self.low))

def log_prob(self, value):
value = torch.as_tensor(value)
inside_range = (value >= self.low) & (value <= self.high)
return value.cos().log() * inside_range

class LogNormal:

class Sine(dist.TransformedDistribution):
"""
Sample from a log normal distribution with the
specified `mean` and standard deviation `std`.
If a `low` value is specified, values sampled
lower than this will be clipped to `low`.
Sine distribution based on
``torch.distributions.TransformedDistribution``.
"""

def __init__(
self, mean: float, std: float, low: Optional[float] = None
) -> None:
self.sigma = math.log((std / mean) ** 2 + 1) ** 0.5
self.mu = 2 * math.log(mean / (mean**2 + std**2) ** 0.25)
self.low = low

def __call__(self, N: int) -> torch.Tensor:

u = self.mu + torch.randn(N) * self.sigma
x = torch.exp(u)

if self.low is not None:
x = torch.clip(x, self.low)
return x


class LogUniform(Uniform):
self,
low: float = torch.as_tensor(0),
high: float = torch.as_tensor(torch.pi),
validate_args=None,
):
base_dist = Cosine(
low - torch.pi / 2, high - torch.pi / 2, validate_args
)
super().__init__(
base_dist,
[
dist.AffineTransform(
loc=torch.pi / 2,
scale=1,
)
],
validate_args=validate_args,
)


class LogUniform(dist.TransformedDistribution):
"""
Sample from a log uniform distribution
"""

def __init__(self, low: float, high: float) -> None:
super().__init__(math.log(low), math.log(high))
def __init__(self, low: float, high: float, validate_args=None):
base_dist = dist.Uniform(
torch.as_tensor(low).log(),
torch.as_tensor(high).log(),
validate_args,
)
super().__init__(
base_dist,
[dist.ExpTransform()],
validate_args=validate_args,
)

def __call__(self, N: int) -> torch.Tensor:
u = super().__call__(N)
return torch.exp(u)

class LogNormal(dist.LogNormal):
def __init__(
self,
mean: float,
std: float,
low: Optional[float] = None,
validate_args=None,
):
self.low = low
super().__init__(loc=mean, scale=std, validate_args=validate_args)

class PowerLaw:
def support(self):
if self.low is not None:
return dist.constraints.greater_than(self.low)


class PowerLaw(dist.TransformedDistribution):
"""
Sample from a power law distribution,
.. math::
p(x) \approx x^{-\alpha}.
p(x) \approx x^{\alpha}.
Index alpha must be greater than 1.
Index alpha cannot be 0, since it is equivalent to a Uniform distribution.
This could be used, for example, as a universal distribution of
signal-to-noise ratios (SNRs) from uniformly volume distributed
sources
@@ -102,21 +118,49 @@ class PowerLaw:
where :math:`\rho_0` is a representative minimum SNR
considered for detection. See, for example,
`Schutz (2011) <https://arxiv.org/abs/1102.5421>`_.
Or, for example, ``index=2`` for uniform in Euclidean volume.
"""

support = dist.constraints.nonnegative

def __init__(
self, minimum: float, maximum: float, index: int, validate_args=None
):
if index == 0:
raise RuntimeError("Index of 0 is the same as Uniform")
elif index == -1:
base_min = torch.as_tensor(minimum).log()
base_max = torch.as_tensor(maximum).log()
transforms = [dist.ExpTransform()]
else:
index_plus = index + 1
base_min = minimum**index_plus / index_plus
base_max = maximum**index_plus / index_plus
transforms = [
dist.AffineTransform(loc=0, scale=index_plus),
dist.PowerTransform(1 / index_plus),
]
base_dist = dist.Uniform(base_min, base_max, validate_args=False)
super().__init__(
base_dist,
transforms,
validate_args=validate_args,
)
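# Sketch of why the transforms above reproduce p(x) proportional to
# x**alpha on [minimum, maximum]: with k = alpha + 1, inverse-CDF
# sampling draws
#     u ~ Uniform(minimum**k / k, maximum**k / k)
# and maps x = (k * u)**(1 / k), which is exactly the AffineTransform
# (scale=k) followed by the PowerTransform (exponent 1/k). When
# alpha == -1 the CDF is logarithmic, so the base distribution is
# Uniform(log(minimum), log(maximum)) followed by an ExpTransform.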


class DeltaFunction(dist.Distribution):
arg_constraints = {}

def __init__(
self, x_min: float, x_max: float = float("inf"), alpha: float = 2
) -> None:
self.x_min = x_min
self.x_max = x_max
self.alpha = alpha

self.normalization = x_min ** (-self.alpha + 1)
self.normalization -= x_max ** (-self.alpha + 1)

def __call__(self, N: int) -> torch.Tensor:
u = torch.rand(N)
u *= self.normalization
samples = self.x_min ** (-self.alpha + 1) - u
samples = torch.pow(samples, -1.0 / (self.alpha - 1))
return samples
self,
peak: float = torch.as_tensor(0.0),
validate_args=None,
):
batch_shape = torch.Size()
super().__init__(batch_shape, validate_args=validate_args)
self.peak = peak

def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
return self.peak * torch.ones(
sample_shape, device=self.peak.device, dtype=torch.float32
)
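
A minimal sketch of sampling from the reworked distributions. Because they are now built on torch.distributions, as the diff above shows, batches are drawn with .sample(sample_shape) rather than by calling the object with a count; the parameter values below are purely illustrative.

import torch

from ml4gw.distributions import Cosine, DeltaFunction, LogUniform, PowerLaw

cosine = Cosine()                              # defaults to [-pi/2, pi/2]
log_uniform = LogUniform(low=20.0, high=2000.0)
power_law = PowerLaw(minimum=4.0, maximum=100.0, index=-4)
delta = DeltaFunction(peak=torch.as_tensor(8.0))

# .sample(sample_shape) replaces the old __call__(N) interface
angles = cosine.sample((2048,))
snrs = power_law.sample((2048,))
fixed = delta.sample((2048,))
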
1 change: 1 addition & 0 deletions ml4gw/transforms/__init__.py
@@ -1,4 +1,5 @@
from .pearson import ShiftedPearsonCorrelation
from .qtransform import QScan, SingleQTransform
from .scaler import ChannelWiseScaler
from .snr_rescaler import SnrRescaler
from .spectral import SpectralDensity
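
A hedged sketch of the newly exported q-transform classes. Only the import path is confirmed by the line added above; the constructor arguments (duration, sample_rate, spectrogram_shape) and the 3D input layout are assumptions taken from the commit messages, so they may not match the released signature exactly.

import torch

from ml4gw.transforms import SingleQTransform

# All constructor parameters here are assumed, not confirmed by this diff
qtransform = SingleQTransform(
    duration=4,                    # assumed: segment length in seconds
    sample_rate=2048,              # assumed: sample rate in Hz
    spectrogram_shape=(64, 128),   # assumed: output (frequency, time) bins
)

# The commit notes above state the documentation expects 3D input,
# interpreted here as (batch, channels, time)
X = torch.randn(8, 2, 4 * 2048)
spectrograms = qtransform(X)
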
