From 7c42f0b86bb53276f3c9d32a9382a7fde947b2ea Mon Sep 17 00:00:00 2001
From: Ethan Jacob Marx
Date: Wed, 8 May 2024 05:06:07 -0700
Subject: [PATCH] in memory dataset inherits from torch iterable dataset

---
 ml4gw/dataloading/in_memory_dataset.py | 55 ++++++++++----------------
 1 file changed, 21 insertions(+), 34 deletions(-)

diff --git a/ml4gw/dataloading/in_memory_dataset.py b/ml4gw/dataloading/in_memory_dataset.py
index 0d308e2e..65b5c983 100644
--- a/ml4gw/dataloading/in_memory_dataset.py
+++ b/ml4gw/dataloading/in_memory_dataset.py
@@ -7,7 +7,7 @@
 from ml4gw.utils.slicing import slice_kernels
 
 
-class InMemoryDataset:
+class InMemoryDataset(torch.utils.data.IterableDataset):
     """Dataset for iterating through in-memory multi-channel timeseries
 
     Dataset for arrays of timeseries data which can be stored
@@ -131,7 +131,6 @@ def __init__(
         self.batches_per_epoch = batches_per_epoch
         self.shuffle = shuffle
         self.coincident = coincident
-        self._i = self._idx = None
 
     @property
     def num_kernels(self) -> int:
@@ -156,8 +155,8 @@ def __len__(self) -> int:
             # support it
            num_kernels = self.num_kernels ** len(self.X)
         return (num_kernels - 1) // self.batch_size + 1
-
-    def __iter__(self):
+
+    def init_indices(self):
         """
         Initialize arrays of indices we'll use to slice
         through X and y at iteration time. This helps by
@@ -204,36 +203,24 @@ def __iter__(self):
             # the simplest case: deteriminstic and coincident
             idx = torch.arange(num_kernels, device=device)
 
-        self._idx = idx
-        self._i = 0
-        return self
+        return idx
 
-    def __next__(
+    def __iter__(
         self,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if self._i is None or self._idx is None:
-            raise TypeError(
-                "Must initialize InMemoryDataset iteration "
-                "before calling __next__"
-            )
-
-        # check if we're out of batches, and if so
-        # make sure to reset before stopping
-        if self._i >= len(self):
-            self._i = self._idx = None
-            raise StopIteration
-
-        # slice the array of _indices_ we'll be using to
-        # slice our timeseries, and scale them by the stride
-        slc = slice(self._i * self.batch_size, (self._i + 1) * self.batch_size)
-        idx = self._idx[slc] * self.stride
-
-        # slice our timeseries
-        X = slice_kernels(self.X, idx, self.kernel_size)
-        if self.y is not None:
-            y = slice_kernels(self.y, idx, self.kernel_size)
-
-        self._i += 1
-        if self.y is not None:
-            return X, y
-        return X
+
+        indices = self.init_indices()
+        for i in range(len(self)):
+            # slice the array of _indices_ we'll be using to
+            # slice our timeseries, and scale them by the stride
+            slc = slice(i * self.batch_size, (i + 1) * self.batch_size)
+            idx = indices[slc] * self.stride
+
+            # slice our timeseries
+            X = slice_kernels(self.X, idx, self.kernel_size)
+            if self.y is not None:
+                y = slice_kernels(self.y, idx, self.kernel_size)
+
+            if self.y is not None:
+                yield X, y
+            yield X
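
For context on the pattern this patch adopts: once the class subclasses torch.utils.data.IterableDataset, __iter__ can be written as a plain generator, so the manual _i/_idx bookkeeping and the separate __next__ method above are no longer needed. Below is a minimal, self-contained sketch of that generator-style IterableDataset pattern; the ToyBatchDataset class, its tensor shapes, and the batch size are illustrative assumptions, not part of ml4gw or of this patch. Datasets that yield ready-made batches like this are typically consumed either directly in a for loop or through a DataLoader with batch_size=None, so the loader does not add another batch dimension.

import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyBatchDataset(IterableDataset):
    """Toy dataset illustrating the generator-style __iter__ used above."""

    def __init__(self, X: torch.Tensor, batch_size: int):
        self.X = X
        self.batch_size = batch_size

    def __len__(self) -> int:
        # number of batches per pass through the data
        return (len(self.X) - 1) // self.batch_size + 1

    def __iter__(self):
        # one yield per batch; no __next__ or manual iteration state required
        for i in range(len(self)):
            yield self.X[i * self.batch_size : (i + 1) * self.batch_size]


# hypothetical shapes: 100 samples, 2 channels, 2048 timesteps
dataset = ToyBatchDataset(torch.randn(100, 2, 2048), batch_size=8)
for batch in dataset:
    pass  # batch has shape (8, 2, 2048), except for a smaller final batch

# batch_size=None disables automatic batching, since batches are pre-built
loader = DataLoader(dataset, batch_size=None)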