Merge pull request #160 from ML4GW/dev
Merge dev into main
wbenoit26 authored Oct 1, 2024
2 parents c2e1937 + 2fb7e19 commit 0cdff78
Showing 41 changed files with 1,338 additions and 1,196 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -2,7 +2,7 @@
 max-line-length = 79
 max-complexity = 18
 select = B,C,E,F,W,T4,B9
-ignore = W503, E203 # ignore for consistency with black
+ignore = W503, E203, F722 # ignore for consistency with black and jaxtyping
 
 # ignore asterisk imports and unused
 # import errors in __init__ files
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yaml
@@ -12,7 +12,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ['3.8', '3.9', '3.10', '3.11']
+        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
     steps:
       - uses: actions/checkout@v2
 
10 changes: 8 additions & 2 deletions ml4gw/augmentations.py
@@ -1,4 +1,6 @@
 import torch
+from jaxtyping import Float
+from torch import Tensor
 
 
 class SignalInverter(torch.nn.Module):
@@ -16,7 +18,9 @@ def __init__(self, prob: float = 0.5):
         super().__init__()
         self.prob = prob
 
-    def forward(self, X):
+    def forward(
+        self, X: Float[Tensor, "*batch time"]
+    ) -> Float[Tensor, "*batch time"]:
         mask = torch.rand(size=X.shape[:-1]) < self.prob
         X[mask] *= -1
         return X
@@ -37,7 +41,9 @@ def __init__(self, prob: float = 0.5):
         super().__init__()
         self.prob = prob
 
-    def forward(self, X):
+    def forward(
+        self, X: Float[Tensor, "*batch time"]
+    ) -> Float[Tensor, "*batch time"]:
         mask = torch.rand(size=X.shape[:-1]) < self.prob
         X[mask] = X[mask].flip(-1)
         return X
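Usage sketch (not part of the diff): exercising the newly annotated forward signatures above. The batch shape and probabilities are illustrative assumptions.

import torch

from ml4gw.augmentations import SignalInverter, SignalReverser

X = torch.randn(8, 2, 2048)  # matches the Float[Tensor, "*batch time"] annotation
X = SignalInverter(prob=0.5)(X)   # negates a random subset of the batch/channel slices
X = SignalReverser(prob=0.5)(X)   # time-reverses a random subset of the batch/channel slices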
6 changes: 4 additions & 2 deletions ml4gw/dataloading/chunked_dataset.py
@@ -2,6 +2,8 @@
 
 import torch
 
+from ml4gw.types import WaveformTensor
+
 
 class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
     """
@@ -55,10 +57,10 @@ def __init__(
         self.coincident = coincident
         self.device = device
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.chunk_it) * self.batches_per_chunk
 
-    def __iter__(self):
+    def __iter__(self) -> WaveformTensor:
         it = iter(self.chunk_it)
         chunk = next(it)
         num_chunks, num_channels, chunk_size = chunk.shape
2 changes: 1 addition & 1 deletion ml4gw/dataloading/hdf5_dataset.py
@@ -161,7 +161,7 @@ def sample_batch(self) -> WaveformTensor:
                 x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
         return torch.Tensor(x)
 
-    def __iter__(self) -> torch.Tensor:
+    def __iter__(self) -> WaveformTensor:
         worker_info = torch.utils.data.get_worker_info()
         if worker_info is None:
             num_batches = self.batches_per_epoch
12 changes: 8 additions & 4 deletions ml4gw/dataloading/in_memory_dataset.py
@@ -2,8 +2,9 @@
 from typing import Optional, Tuple, Union
 
 import torch
+from jaxtyping import Float
+from torch import Tensor
 
-from ml4gw import types
 from ml4gw.utils.slicing import slice_kernels
 
 
@@ -76,9 +77,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
 
     def __init__(
        self,
-        X: types.TimeSeriesTensor,
+        X: Float[Tensor, "channels time"],
         kernel_size: int,
-        y: Optional[types.ScalarTensor] = None,
+        y: Optional[Float[Tensor, " time"]] = None,
         batch_size: int = 32,
         stride: int = 1,
         batches_per_epoch: Optional[int] = None,
@@ -207,7 +208,10 @@ def init_indices(self):
 
     def __iter__(
         self,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[
+        Float[Tensor, "batch channel time"],
+        Tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]],
+    ]:
 
         indices = self.init_indices()
         for i in range(len(self)):
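A hedged usage sketch of the retyped InMemoryDataset (not part of the diff). The array sizes are placeholders, and it is assumed that any constructor arguments not shown in the hunk above have defaults.

import torch

from ml4gw.dataloading.in_memory_dataset import InMemoryDataset

X = torch.randn(2, 16384)  # Float[Tensor, "channels time"]
dataset = InMemoryDataset(X, kernel_size=2048, batch_size=32, stride=1)
for kernels in dataset:
    # Float[Tensor, "batch channel time"]; a (X, y) tuple instead when y is supplied
    pass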
8 changes: 5 additions & 3 deletions ml4gw/distributions.py
@@ -9,6 +9,8 @@
 
 import torch
 import torch.distributions as dist
+from jaxtyping import Float
+from torch import Tensor
 
 
 class Cosine(dist.Distribution):
@@ -31,11 +33,11 @@ def __init__(
         self.high = torch.as_tensor(high)
         self.norm = 1 / (torch.sin(self.high) - torch.sin(self.low))
 
-    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
+    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
         u = torch.rand(sample_shape, device=self.low.device)
         return torch.arcsin(u / self.norm + torch.sin(self.low))
 
-    def log_prob(self, value):
+    def log_prob(self, value: float) -> Float[Tensor, ""]:
         value = torch.as_tensor(value)
         inside_range = (value >= self.low) & (value <= self.high)
         return value.cos().log() * inside_range
@@ -164,7 +166,7 @@ def __init__(
         super().__init__(batch_shape, validate_args=validate_args)
         self.peak = torch.as_tensor(peak)
 
-    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
+    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
         return self.peak * torch.ones(
             sample_shape, device=self.peak.device, dtype=torch.float32
        )
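A minimal sampling sketch for the annotated Cosine distribution (not part of the diff). The low/high keyword names follow the attributes set in the hunk above; the bound values, and whether defaults exist, are assumptions.

import math

import torch

from ml4gw.distributions import Cosine

cosine = Cosine(low=-math.pi / 2, high=math.pi / 2)
dec = cosine.rsample(torch.Size((1024,)))  # Tensor of declination samples
log_p = cosine.log_prob(dec)               # elementwise log-probabilities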
48 changes: 21 additions & 27 deletions ml4gw/gw.py
@@ -13,27 +13,21 @@
 from typing import List, Tuple, Union
 
 import torch
-from torchtyping import TensorType
+from jaxtyping import Float
+from torch import Tensor
 
+from ml4gw.constants import C
 from ml4gw.types import (
+    BatchTensor,
     NetworkDetectorTensors,
     NetworkVertices,
     PSDTensor,
-    ScalarTensor,
     TensorGeometry,
     VectorGeometry,
     WaveformTensor,
 )
 from ml4gw.utils.interferometer import InterferometerGeometry
 
-SPEED_OF_LIGHT = 299792458.0  # m/s
-
-
-# define some tensor shapes we'll reuse a bit
-# up front. Need to assign these variables so
-# that static linters don't give us name errors
-batch = num_ifos = polarizations = time = frequency = space = None  # noqa
-
 
 def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
     """
@@ -62,12 +56,12 @@ def breathing(m: VectorGeometry, n: VectorGeometry) -> TensorGeometry:
 
 
 def compute_antenna_responses(
-    theta: ScalarTensor,
-    psi: ScalarTensor,
-    phi: ScalarTensor,
+    theta: BatchTensor,
+    psi: BatchTensor,
+    phi: BatchTensor,
     detector_tensors: NetworkDetectorTensors,
     modes: List[str],
-) -> TensorType["batch", "polarizations", "num_ifos"]:
+) -> Float[Tensor, "batch polarizations num_ifos"]:
     """
     Compute the antenna pattern factors of a batch of
     waveforms as a function of the sky parameters of
@@ -147,8 +141,8 @@ def compute_antenna_responses(
 
 def shift_responses(
     responses: WaveformTensor,
-    theta: ScalarTensor,
-    phi: ScalarTensor,
+    theta: BatchTensor,
+    phi: BatchTensor,
     vertices: NetworkVertices,
     sample_rate: float,
 ) -> WaveformTensor:
@@ -166,7 +160,7 @@
     # Divide by c in the second line so that we only
     # need to multiply the array by a single float
     dt = -(omega * vertices).sum(axis=-1)
-    dt *= sample_rate / SPEED_OF_LIGHT
+    dt *= sample_rate / C
     dt = torch.trunc(dt).type(torch.int64)
 
     # rolling by gathering implementation based on
@@ -191,13 +185,13 @@
 
 
 def compute_observed_strain(
-    dec: ScalarTensor,
-    psi: ScalarTensor,
-    phi: ScalarTensor,
+    dec: BatchTensor,
+    psi: BatchTensor,
+    phi: BatchTensor,
     detector_tensors: NetworkDetectorTensors,
     detector_vertices: NetworkVertices,
     sample_rate: float,
-    **polarizations: TensorType["batch", "time"],
+    **polarizations: Float[Tensor, "batch time"],
 ) -> WaveformTensor:
     """
     Compute the strain timeseries $h(t)$ observed by a network
@@ -289,8 +283,8 @@ def compute_ifo_snr(
     responses: WaveformTensor,
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float, TensorType["frequency"], None] = None,
-) -> TensorType["batch", "num_ifos"]:
+    highpass: Union[float, Float[Tensor, " frequency"], None] = None,
+) -> Float[Tensor, "batch num_ifos"]:
     r"""Compute the SNRs of a batch of interferometer responses
 
     Compute the signal to noise ratio (SNR) of individual
@@ -390,8 +384,8 @@ def compute_network_snr(
     responses: WaveformTensor,
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float, TensorType["frequency"], None] = None,
-) -> ScalarTensor:
+    highpass: Union[float, Float[Tensor, " frequency"], None] = None,
+) -> BatchTensor:
     r"""
     Compute the total SNR from a gravitational waveform
     from a network of interferometers. The total SNR for
@@ -437,10 +431,10 @@ def compute_network_snr(
 
 def reweight_snrs(
     responses: WaveformTensor,
-    target_snrs: Union[float, ScalarTensor],
+    target_snrs: Union[float, BatchTensor],
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float, TensorType["frequency"], None] = None,
+    highpass: Union[float, Float[Tensor, " frequency"], None] = None,
 ) -> WaveformTensor:
     """Scale interferometer responses such that they have a desired SNR
 
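A hedged sketch of the retyped SNR utilities in use (not part of the diff). The tensor shapes, in particular the number of PSD frequency bins, and the sample rate are illustrative assumptions rather than values prescribed by the module.

import torch

from ml4gw import gw

num_ifos, kernel_size, sample_rate = 2, 4096, 2048
responses = torch.randn(128, num_ifos, kernel_size)  # WaveformTensor: (batch, num_ifos, time)
psd = torch.rand(num_ifos, kernel_size // 2 + 1)     # assumed one-sided PSD per interferometer

ifo_snrs = gw.compute_ifo_snr(responses, psd, sample_rate, highpass=32.0)
network_snrs = gw.compute_network_snr(responses, psd, sample_rate, highpass=32.0)
rescaled = gw.reweight_snrs(responses, 12.0, psd, sample_rate, highpass=32.0)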
17 changes: 11 additions & 6 deletions ml4gw/nn/autoencoder/base.py
@@ -1,7 +1,8 @@
 from collections.abc import Sequence
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
+from torch import Tensor
 
 from ml4gw.nn.autoencoder.skip_connection import SkipConnection
 
@@ -27,12 +28,16 @@ class Autoencoder(torch.nn.Module):
     and how they operate.
     """
 
-    def __init__(self, skip_connection: Optional[SkipConnection] = None):
+    def __init__(
+        self, skip_connection: Optional[SkipConnection] = None
+    ) -> None:
         super().__init__()
         self.skip_connection = skip_connection
         self.blocks = torch.nn.ModuleList()
 
-    def encode(self, *X: torch.Tensor, return_states: bool = False):
+    def encode(
+        self, *X: Tensor, return_states: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Sequence]]:
         states = []
         for block in self.blocks:
             if isinstance(X, tuple):
@@ -48,7 +53,7 @@ def encode(self, *X: torch.Tensor, return_states: bool = False):
             return X, states[:-1]
         return X
 
-    def decode(self, *X, states: Optional[Sequence[torch.Tensor]] = None):
+    def decode(self, *X, states: Optional[Sequence[Tensor]] = None) -> Tensor:
         if self.skip_connection is not None and states is None:
             raise ValueError(
                 "Must pass intermediate states when autoencoder "
@@ -76,14 +81,14 @@ def decode(self, *X, states: Optional[Sequence[torch.Tensor]] = None):
             X = self.skip_connection(X, state)
         return X
 
-    def forward(self, *X):
+    def forward(self, *X: Tensor) -> Tensor:
         return_states = self.skip_connection is not None
         X = self.encode(*X, return_states=return_states)
         if return_states:
             *X, states = X
         else:
             states = None
 
-        if isinstance(X, torch.Tensor):
+        if isinstance(X, Tensor):
             X = (X,)
         return self.decode(*X, states=states)
11 changes: 7 additions & 4 deletions ml4gw/nn/autoencoder/convolutional.py
@@ -2,6 +2,7 @@
 from typing import Optional
 
 import torch
+from torch import Tensor
 
 from ml4gw.nn.autoencoder.base import Autoencoder
 from ml4gw.nn.autoencoder.skip_connection import SkipConnection
@@ -64,12 +65,12 @@ def __init__(
         self.encode_norm = norm(out_channels)
         self.decode_norm = norm(decode_channels)
 
-    def encode(self, X):
+    def encode(self, X: Tensor) -> Tensor:
         X = self.encode_layer(X)
         X = self.encode_norm(X)
         return self.activation(X)
 
-    def decode(self, X):
+    def decode(self, X: Tensor) -> Tensor:
         X = self.decode_layer(X)
         X = self.decode_norm(X)
         return self.output_activation(X)
@@ -144,13 +145,15 @@ def __init__(
             self.blocks.append(block)
             in_channels = channels * groups
 
-    def decode(self, *X, states=None, input_size: Optional[int] = None):
+    def decode(
+        self, *X, states=None, input_size: Optional[int] = None
+    ) -> Tensor:
         X = super().decode(*X, states=states)
         if input_size is not None:
             return match_size(X, input_size)
         return X
 
-    def forward(self, X):
+    def forward(self, X: Tensor) -> Tensor:
         input_size = X.size(-1)
         X = super().forward(X)
         return match_size(X, input_size)