Merge in dev #93

Merged 24 commits on Jan 26, 2024.

Commits (24)
d28501e  reformulating things to build batch-level dataloader (#67) (alecgunny, Sep 6, 2023)
2b1f3e9  Snapshotter and online averaging models (#76) (alecgunny, Nov 9, 2023)
4170b79  shifted pearson correlation (#78) (alecgunny, Nov 26, 2023)
6ae7a76  adding autoencoder lib (alecgunny, Nov 30, 2023)
8139271  adding init imports (alecgunny, Nov 30, 2023)
b883116  running pre-commit hooks (alecgunny, Dec 1, 2023)
49234c1  swapping num_ifos for in_channels (alecgunny, Dec 1, 2023)
0488f12  Merge pull request #84 from alecgunny/autoencoders (EthanMarx, Dec 1, 2023)
9adbcdc  add taylorF2 3.5PN in torch (deepchatterjeeligo, Dec 1, 2023)
cb796cb  Merge pull request #85 from deepchatterjeeligo/taylorf2 (wbenoit26, Dec 7, 2023)
0e227af  add spin contributions to PN coeffs (deepchatterjeeligo, Dec 14, 2023)
7bfabea  fix (deepchatterjeeligo, Dec 15, 2023)
7a2effc  avoid extra kernel launch (deepchatterjeeligo, Dec 15, 2023)
f62b1f8  Merge pull request #88 from deepchatterjeeligo/spin-taylor (wbenoit26, Dec 15, 2023)
30c025d  add batched IMRPhenomD (deepchatterjeeligo, Dec 22, 2023)
ab15d6b  add phenom_d data (deepchatterjeeligo, Jan 5, 2024)
0eba829  Merge pull request #89 from deepchatterjeeligo/phenomd (wbenoit26, Jan 10, 2024)
0929463  add signal inverter and reversal augmentations (EthanMarx, Jan 25, 2024)
03b38d8  add signal inverter and reversal augmentations (EthanMarx, Jan 25, 2024)
4b750ff  add tests (EthanMarx, Jan 25, 2024)
3fae734  remove prob fixture (EthanMarx, Jan 26, 2024)
a8e64bf  Merge pull request #92 from EthanMarx/augmentations (EthanMarx, Jan 26, 2024)
5f2d1a9  bump version (EthanMarx, Jan 26, 2024)
7c83d41  Merge pull request #94 from EthanMarx/bump-version (wbenoit26, Jan 26, 2024)
43 changes: 43 additions & 0 deletions ml4gw/augmentations.py
@@ -0,0 +1,43 @@
import torch


class SignalInverter(torch.nn.Module):
    """
    Takes a tensor of timeseries of arbitrary dimension
    and randomly inverts (i.e. h(t) -> -h(t))
    each timeseries with probability `prob`.

    Args:
        prob:
            Probability that a timeseries is inverted
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, X):
        mask = torch.rand(size=X.shape[:-1]) < self.prob
        X[mask] *= -1
        return X


class SignalReverser(torch.nn.Module):
    """
    Takes a tensor of timeseries of arbitrary dimension
    and randomly reverses (i.e. h(t) -> h(-t))
    each timeseries with probability `prob`.

    Args:
        prob:
            Probability that a timeseries is reversed
    """

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, X):
        mask = torch.rand(size=X.shape[:-1]) < self.prob
        X[mask] = X[mask].flip(-1)
        return X
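
A minimal usage sketch of the two new modules (shapes and sample rate are chosen purely for illustration). Note that both modules modify their input in place, and that the random mask is built over `X.shape[:-1]`, so each channel of each batch element is augmented independently; as written, the mask is drawn on the CPU, so CPU-resident inputs are assumed:

import torch

from ml4gw.augmentations import SignalInverter, SignalReverser

# toy batch: 8 samples, 2 interferometer channels, 1 s at 2048 Hz
X = torch.randn(8, 2, 2048)

# each (sample, channel) timeseries is independently inverted
# and/or time-reversed with probability 0.5
augment = torch.nn.Sequential(SignalInverter(0.5), SignalReverser(0.5))
X = augment(X)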
3 changes: 2 additions & 1 deletion ml4gw/dataloading/__init__.py
@@ -1,2 +1,3 @@
-from .chunked_dataset import ChunkedDataset
+from .chunked_dataset import ChunkedTimeSeriesDataset
+from .hdf5_dataset import Hdf5TimeSeriesDataset
 from .in_memory_dataset import InMemoryDataset
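
Downstream code needs the matching rename; a before/after sketch using only the names in the diff above:

# before this PR
from ml4gw.dataloading import ChunkedDataset

# after this PR
from ml4gw.dataloading import ChunkedTimeSeriesDataset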
278 changes: 66 additions & 212 deletions ml4gw/dataloading/chunked_dataset.py
@@ -1,262 +1,113 @@
-from typing import List
+from collections.abc import Iterable
 
-import h5py
-import numpy as np
 import torch
 
 
-class ChunkLoader(torch.utils.data.IterableDataset):
-    def __init__(
-        self,
-        fnames: List[str],
-        channels: List[str],
-        chunk_size: int,
-        reads_per_chunk: int,
-        chunks_per_epoch: int,
-        coincident: bool = True,
-    ) -> None:
-        self.fnames = fnames
-        self.channels = channels
-        self.chunk_size = chunk_size
-        self.reads_per_chunk = reads_per_chunk
-        self.chunks_per_epoch = chunks_per_epoch
-        self.coincident = coincident
-
-        sizes = []
-        for f in self.fnames:
-            with h5py.File(f, "r") as f:
-                size = len(f[self.channels[0]])
-                sizes.append(size)
-        total = sum(sizes)
-        self.probs = np.array([i / total for i in sizes])
-
-    def sample_fnames(self):
-        return np.random.choice(
-            self.fnames,
-            p=self.probs,
-            size=(self.reads_per_chunk,),
-            replace=True,
-        )
-
-    def load_coincident(self):
-        fnames = self.sample_fnames()
-        chunks = []
-        for fname in fnames:
-            with h5py.File(fname, "r") as f:
-                chunk, idx = [], None
-                for channel in self.channels:
-                    if idx is None:
-                        end = len(f[channel]) - self.chunk_size
-                        idx = np.random.randint(0, end)
-                    x = f[channel][idx : idx + self.chunk_size]
-                    chunk.append(x)
-            chunks.append(np.stack(chunk))
-        return np.stack(chunks)
-
-    def load_noncoincident(self):
-        chunks = []
-        for channel in self.channels:
-            fnames = self.sample_fnames()
-            chunk = []
-            for fname in fnames:
-                with h5py.File(fname, "r") as f:
-                    end = len(f[channel]) - self.chunk_size
-                    idx = np.random.randint(0, end)
-                    x = f[channel][idx : idx + self.chunk_size]
-                chunk.append(x)
-            chunks.append(np.stack(chunk))
-        return np.stack(chunks, axis=1)
-
-    def iter_epoch(self):
-        for _ in range(self.chunks_per_epoch):
-            if self.coincident:
-                yield torch.Tensor(self.load_coincident())
-            else:
-                yield torch.Tensor(self.load_noncoincident())
-
-    def collate(self, xs):
-        return torch.cat(xs, axis=0)
-
-    def __iter__(self):
-        return self.iter_epoch()
-
-
-class ChunkedDataset(torch.utils.data.IterableDataset):
+class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
     """
-    Iterable dataset for generating batches of background data
-    loaded on-the-fly from multiple HDF5 files. Loads
-    `chunk_length`-sized randomly-sampled stretches of
-    background from `reads_per_chunk` randomly sampled
-    files up front, then samples `batches_per_chunk`
-    batches of kernels from this chunk before loading
-    in the next one. Terminates after `chunks_per_epoch`
-    chunks have been exhausted, which amounts to
-    `chunks_per_epoch * batches_per_chunk` batches.
-
-    Note that filenames are not sampled uniformly
-    at chunk-loading time, but are weighted according
-    to the amount of data each file contains. This ensures
-    a uniform sampling over time across the whole dataset.
-
-    To load chunks asynchronously in the background,
-    specify `num_workers > 0`. Note that if the
-    number of workers is not an even multiple of
-    `chunks_per_epoch`, the last chunks of an epoch
-    will be composed of fewer than `reads_per_chunk`
-    individual segments.
+    Wrapper dataset that will loop through chunks of timeseries
+    data produced by another iterable and sample windows from
+    these chunks.
 
     Args:
-        fnames:
-            List of HDF5 archives containing data to read.
-            Each file should have all of the channels specified
-            in `channels` as top-level datasets.
-        channels:
-            Datasets to load from each filename in `fnames`
-        kernel_length:
-            Length of the windows returned at iteration time
-            in seconds
-        sample_rate:
-            Rate at which the data in the specified `fnames`
-            has been sampled.
+        chunk_it:
+            Iterator which will produce chunks of timeseries
+            data to sample windows from. Should have shape
+            `(N, C, T)`, where `N` is the number of chunks
+            to sample from, `C` is the number of channels,
+            and `T` is the number of samples along the
+            time dimension for each chunk.
+        kernel_size:
+            Size of windows to be sampled from each chunk.
+            Should be less than the size of each chunk
+            along the time dimension.
         batch_size:
-            Number of samples to return at iteration time
-        reads_per_chunk:
-            Number of file reads to perform when generating
-            each chunk
-        chunk_length:
-            Amount of data to read for each segment loaded
-            into each chunk, in seconds
+            Number of windows to sample at each iteration
         batches_per_chunk:
-            Number of batches to sample from each chunk
-            before loading the next one
-        chunks_per_epoch:
-            Number of chunks to generate before iteration
-            terminates
+            Number of batches of windows to sample from
+            each chunk before moving on to the next one.
+            Sampling fewer batches from each chunk means
+            a lower likelihood of sampling duplicate windows,
+            but an increase in chunk-loading overhead.
         coincident:
-            Flag indicating whether windows returned at iteration
-            time should come from the same point in time for
-            each channel in a given batch sample.
-        num_workers:
-            Number of workers for performing chunk loading
-            asynchronously. If left as 0, chunk loading will
-            be performed in serial with batch sampling.
+            Whether the windows sampled from individual
+            channels in each batch element should be
+            sampled coincidentally, i.e. consisting of
+            the same timesteps, or whether each window
+            should be sampled independently from the others.
         device:
-            Device on which to host loaded chunks
+            Which device chunks should be moved to upon loading.
     """
 
     def __init__(
         self,
-        fnames: List[str],
-        channels: List[str],
-        kernel_length: float,
-        sample_rate: float,
+        chunk_it: Iterable,
+        kernel_size: float,
         batch_size: int,
-        reads_per_chunk: int,
-        chunk_length: float,
         batches_per_chunk: int,
-        chunks_per_epoch: int,
         coincident: bool = True,
-        num_workers: int = 0,
         device: str = "cpu",
-        pin_memory: bool = False,
     ) -> None:
-        if not num_workers:
-            reads_per_worker = reads_per_chunk
-        elif reads_per_chunk < num_workers:
-            raise ValueError(
-                "Too many workers {} for number of reads_per_chunk {}".format(
-                    num_workers, reads_per_chunk
-                )
-            )
-        else:
-            reads_per_worker = int(reads_per_chunk // num_workers)
-
-        if kernel_length > chunk_length:
-            raise ValueError(
-                "Kernel length {} must be shorter than "
-                "chunk length {}".format(kernel_length, chunk_length)
-            )
-        self.kernel_size = int(kernel_length * sample_rate)
-        self.chunk_size = int(chunk_length * sample_rate)
-
-        chunk_loader = ChunkLoader(
-            fnames,
-            channels,
-            self.chunk_size,
-            reads_per_worker,
-            chunks_per_epoch,
-            coincident=coincident,
-        )
-
-        if not num_workers:
-            self.chunk_loader = chunk_loader
-        else:
-            self.chunk_loader = torch.utils.data.DataLoader(
-                chunk_loader,
-                batch_size=num_workers,
-                num_workers=num_workers,
-                pin_memory=pin_memory,
-                collate_fn=chunk_loader.collate,
-            )
-
-        self.device = device
-        self.num_channels = len(channels)
-        self.coincident = coincident
-
+        self.chunk_it = chunk_it
+        self.kernel_size = kernel_size
         self.batch_size = batch_size
         self.batches_per_chunk = batches_per_chunk
-        self.chunks_per_epoch = chunks_per_epoch
-        self.num_workers = num_workers
+        self.coincident = coincident
+        self.device = device
 
     def __len__(self):
-        if not self.num_workers:
-            return self.chunks_per_epoch * self.batches_per_chunk
+        return len(self.chunk_it) * self.batches_per_chunk
 
-        num_chunks = (self.chunks_per_epoch - 1) // self.num_workers + 1
-        return num_chunks * self.num_workers * self.batches_per_chunk
+    def __iter__(self):
+        it = iter(self.chunk_it)
+        chunk = next(it)
+        num_chunks, num_channels, chunk_size = chunk.shape
 
+        # if we're sampling coincidentally, we only need
+        # to sample indices on a per-batch-element basis.
+        # Otherwise, we'll need indices for both each
+        # batch sample _and_ each channel with each sample
+        if self.coincident:
+            sample_size = (self.batch_size,)
+        else:
+            sample_size = (self.batch_size, num_channels)
 
-    def iter_epoch(self):
         # slice kernels out a flattened chunk tensor
         # index-for-index. We'll account for batch/
         # channel indices by introducing offsets later on
         idx = torch.arange(self.kernel_size, device=self.device)
         idx = idx.view(1, 1, -1)
-        idx = idx.repeat(self.batch_size, self.num_channels, 1)
+        idx = idx.repeat(self.batch_size, num_channels, 1)
 
         # this will just be a set of aranged channel indices
         # repeated to offset the kernel indices in the
         # flattened chunk tensor
-        channel_idx = torch.arange(self.num_channels, device=self.device)
+        channel_idx = torch.arange(num_channels, device=self.device)
         channel_idx = channel_idx.view(1, -1, 1)
         channel_idx = channel_idx.repeat(self.batch_size, 1, self.kernel_size)
-        idx += channel_idx * self.chunk_size
+        idx += channel_idx * chunk_size
 
-        for chunk in self.chunk_loader:
+        while True:
             # record the number of rows in the chunk, then
             # flatten it to make it easier to slice
-            num_chunks, _, chunk_size = chunk.shape
-            chunk = chunk.to(self.device).reshape(-1)
+            if chunk_size < self.kernel_size:
+                raise ValueError(
+                    "Can't sample kernels of size {} from chunk "
+                    "with size {}".format(self.kernel_size, chunk_size)
+                )
+            chunk = chunk.reshape(-1)
 
             # generate batches from the current chunk
             for _ in range(self.batches_per_chunk):
-                # if we're sampling coincidentally, we only need
-                # to sample indices on a per-batch-element basis.
-                # Otherwise, we'll need indices for both each
-                # batch sample _and_ each channel with each sample
-                if self.coincident:
-                    size = (self.batch_size,)
-                else:
-                    size = (self.batch_size, self.num_channels)
-
                 # first sample the indices of which chunk elements
                 # we're going to read batch elements from
                 chunk_idx = torch.randint(
-                    0, num_chunks, size=size, device=self.device
+                    0, num_chunks, size=sample_size, device=self.device
                 )
 
                 # account for the offset this batch element
                 # introduces in the flattened array
-                chunk_idx *= self.num_channels * self.chunk_size
+                chunk_idx *= num_channels * chunk_size
                 chunk_idx = chunk_idx.view(self.batch_size, -1, 1)
                 chunk_idx = chunk_idx + idx
 
@@ -265,7 +116,7 @@ def iter_epoch(self):
                 time_idx = torch.randint(
                     0,
                     chunk_size - self.kernel_size,
-                    size=size,
+                    size=sample_size,
                     device=self.device,
                 )
                 time_idx = time_idx.view(self.batch_size, -1, 1)
@@ -276,5 +127,8 @@ def iter_epoch(self):
                 # now slice this 3D tensor from our flattened chunk
                 yield chunk[chunk_idx]
 
-    def __iter__(self):
-        return self.iter_epoch()
+            try:
+                chunk = next(it)
+            except StopIteration:
+                break
+            num_chunks, num_channels, chunk_size = chunk.shape
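
To tie the refactor together, here is a minimal sketch of driving the new class. `RandomChunks` is a hypothetical stand-in for a real chunk producer (such as the HDF5-backed loader added to `ml4gw.dataloading` in this PR); only the constructor signature visible in the diff above is assumed. Since the visible diff no longer moves chunks to `self.device` itself, the sketch keeps everything on the CPU:

import torch

from ml4gw.dataloading import ChunkedTimeSeriesDataset


class RandomChunks:
    """Hypothetical stand-in for a real chunk loader: any iterable
    with a __len__ that yields (N, C, T)-shaped tensors will do."""

    def __init__(self, n_chunks=4, n_segments=8, n_channels=2, chunk_size=16384):
        self.n_chunks = n_chunks
        self.n_segments = n_segments
        self.n_channels = n_channels
        self.chunk_size = chunk_size

    def __len__(self):
        return self.n_chunks

    def __iter__(self):
        for _ in range(self.n_chunks):
            yield torch.randn(self.n_segments, self.n_channels, self.chunk_size)


dataset = ChunkedTimeSeriesDataset(
    RandomChunks(),
    kernel_size=2048,     # windows of 2048 samples
    batch_size=32,        # 32 windows per batch
    batches_per_chunk=8,  # 8 batches before the next chunk is drawn
    coincident=True,      # same timesteps across all channels in a window
    device="cpu",         # device on which the sampling indices are built
)

# len(dataset) == len(RandomChunks()) * batches_per_chunk == 32 batches
for X in dataset:
    assert X.shape == (32, 2, 2048)

The design change this diff makes is worth noting: file I/O and weighted filename sampling move out of this class entirely (the deleted ChunkLoader), leaving a pure sampler that turns any chunk iterator into batches of windows via a single flattened-index gather.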