Skip to content

Commit

Permalink
in memory dataset inherits from torch iterable dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanMarx committed May 8, 2024
1 parent 05b38cd commit 7c42f0b
Showing 1 changed file with 21 additions and 34 deletions.
55 changes: 21 additions & 34 deletions ml4gw/dataloading/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ml4gw.utils.slicing import slice_kernels


class InMemoryDataset:
class InMemoryDataset(torch.utils.data.IterableDataset):
"""Dataset for iterating through in-memory multi-channel timeseries
Dataset for arrays of timeseries data which can be stored
Expand Down Expand Up @@ -131,7 +131,6 @@ def __init__(
self.batches_per_epoch = batches_per_epoch
self.shuffle = shuffle
self.coincident = coincident
self._i = self._idx = None

@property
def num_kernels(self) -> int:
Expand All @@ -156,8 +155,8 @@ def __len__(self) -> int:
# support it
num_kernels = self.num_kernels ** len(self.X)
return (num_kernels - 1) // self.batch_size + 1

def __iter__(self):
def init_indices(self):
"""
Initialize arrays of indices we'll use to slice
through X and y at iteration time. This helps by
Expand Down Expand Up @@ -204,36 +203,24 @@ def __iter__(self):
# the simplest case: deteriminstic and coincident
idx = torch.arange(num_kernels, device=device)

self._idx = idx
self._i = 0
return self
return idx

def __next__(
def __iter__(
self,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self._i is None or self._idx is None:
raise TypeError(
"Must initialize InMemoryDataset iteration "
"before calling __next__"
)

# check if we're out of batches, and if so
# make sure to reset before stopping
if self._i >= len(self):
self._i = self._idx = None
raise StopIteration

# slice the array of _indices_ we'll be using to
# slice our timeseries, and scale them by the stride
slc = slice(self._i * self.batch_size, (self._i + 1) * self.batch_size)
idx = self._idx[slc] * self.stride

# slice our timeseries
X = slice_kernels(self.X, idx, self.kernel_size)
if self.y is not None:
y = slice_kernels(self.y, idx, self.kernel_size)

self._i += 1
if self.y is not None:
return X, y
return X

indices = self.init_indices()
for i in range(len(self)):
# slice the array of _indices_ we'll be using to
# slice our timeseries, and scale them by the stride
slc = slice(i * self.batch_size, (i + 1) * self.batch_size)
idx = indices[slc] * self.stride

# slice our timeseries
X = slice_kernels(self.X, idx, self.kernel_size)
if self.y is not None:
y = slice_kernels(self.y, idx, self.kernel_size)

if self.y is not None:
yield X, y
yield X

0 comments on commit 7c42f0b

Please sign in to comment.