diff --git a/ml4gw/dataloading/hdf5_dataset.py b/ml4gw/dataloading/hdf5_dataset.py
index 1b4eab82..ea72ef07 100644
--- a/ml4gw/dataloading/hdf5_dataset.py
+++ b/ml4gw/dataloading/hdf5_dataset.py
@@ -1,5 +1,6 @@
 import warnings
-from typing import Sequence, Union
+from contextlib import contextmanager
+from typing import Optional, Sequence, Union
 
 import h5py
 import numpy as np
@@ -12,6 +13,114 @@ class ContiguousHdf5Warning(Warning):
     pass
 
 
+class _Reader:
+    def __new__(cls, fnames, path):
+        if isinstance(fnames, str):
+            cls = _SingleFileReader
+        else:
+            cls = _MultiFileReader
+        return super().__new__(cls)
+
+    def __init__(
+        self, fnames: Union[str, Sequence[str]], path: Optional[str] = None
+    ):
+        self.fnames = fnames
+        if path is not None:
+            self.path = path.split("/")
+        else:
+            self.path = None
+        self.sizes = {}
+
+    def open(self, fname) -> tuple[h5py.File, h5py.Group]:
+        f = group = h5py.File(fname, "r")
+        if self.path is not None:
+            for path in self.path:
+                group = group[path]
+        return f, group
+
+    def _warn_non_contiguous(self, fname, dataset):
+        warnings.warn(
+            "File {} contains datasets at path {} that were generated "
+            "without using chunked storage. This can have "
+            "severe performance impacts at data loading time. "
+            "If you need faster loading, try re-generating "
+            "your dataset with chunked storage turned on.".format(
+                fname, "/".join((self.path or []) + [dataset])
+            ),
+            category=ContiguousHdf5Warning,
+        )
+
+    def get_sizes(self, channel):
+        raise NotImplementedError
+
+    def initialize_probs(self, channel):
+        self.get_sizes(channel)
+        total = sum(self.sizes.values())
+        self.probs = np.array([self.sizes[k] / total for k in self.fnames])
+
+    def sample_fnames(self, size):
+        return np.random.choice(
+            self.fnames,
+            p=self.probs,
+            size=size,
+            replace=True,
+        )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_args):
+        return
+
+    @contextmanager
+    def __call__(self, fname):
+        raise NotImplementedError
+
+
+class _MultiFileReader(_Reader):
+    def get_sizes(self, channel):
+        for fname in self.fnames:
+            with self(fname) as f:
+                dataset = f[channel]
+                if dataset.chunks is None:
+                    self._warn_non_contiguous(fname, channel)
+                self.sizes[fname] = len(dataset)
+
+    @contextmanager
+    def __call__(self, fname):
+        f, group = self.open(fname)
+        with f:
+            yield group
+
+
+class _SingleFileReader(_Reader):
+    _f = _group = None
+
+    def get_sizes(self, channel):
+        fname = self.fnames
+        self.fname = fname
+        with self:
+            for key, group in self._group.items():
+                dataset = group[channel]
+                if dataset.chunks is None:
+                    path = f"{key}/{channel}"
+                    self._warn_non_contiguous(fname, path)
+                self.sizes[key] = len(dataset)
+            self.fnames = sorted(self._group.keys())
+
+    def __enter__(self):
+        self._f, self._group = self.open(self.fname)
+        return self
+
+    def __exit__(self, *exc_args):
+        self._f.close()
+        self._f = self._group = None
+
+    @contextmanager
+    def __call__(self, dataset):
+        yield self._group[dataset]
+
+
 class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
     """
     Iterable dataset that samples and loads windows of
@@ -54,12 +163,13 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
 
     def __init__(
         self,
-        fnames: Sequence[str],
+        fnames: Union[Sequence[str], str],
         channels: Sequence[str],
         kernel_size: int,
         batch_size: int,
         batches_per_epoch: int,
         coincident: Union[bool, str],
+        path: Optional[str] = None,
     ) -> None:
         if not isinstance(coincident, bool) and coincident != "files":
             raise ValueError(
@@ -67,7 +177,8 @@ def __init__(
                 "got unrecognized value {}".format(coincident)
             )
 
-        self.fnames = fnames
+        self.reader = _Reader(fnames, path)
+        self.reader.initialize_probs(channels[0])
         self.channels = channels
         self.num_channels = len(channels)
         self.kernel_size = kernel_size
@@ -75,37 +186,9 @@ def __init__(
         self.batches_per_epoch = batches_per_epoch
         self.coincident = coincident
 
-        self.sizes = {}
-        for fname in self.fnames:
-            with h5py.File(fname, "r") as f:
-                dset = f[channels[0]]
-                if dset.chunks is None:
-                    warnings.warn(
-                        "File {} contains datasets that were generated "
-                        "without using chunked storage. This can have "
-                        "severe performance impacts at data loading time. "
-                        "If you need faster loading, try re-generating "
-                        "your datset with chunked storage turned on.".format(
-                            fname
-                        ),
-                        category=ContiguousHdf5Warning,
-                    )
-
-                self.sizes[fname] = len(dset)
-        total = sum(self.sizes.values())
-        self.probs = np.array([i / total for i in self.sizes.values()])
-
     def __len__(self) -> int:
         return self.batches_per_epoch
 
-    def sample_fnames(self, size) -> np.ndarray:
-        return np.random.choice(
-            self.fnames,
-            p=self.probs,
-            size=size,
-            replace=True,
-        )
-
     def sample_batch(self) -> WaveformTensor:
         """
         Sample a single batch of multichannel timeseries
@@ -120,13 +203,13 @@ def sample_batch(self) -> WaveformTensor:
             size = (self.batch_size,)
         else:
             size = (self.batch_size, self.num_channels)
-        fnames = self.sample_fnames(size)
+        fnames = self.reader.sample_fnames(size)
 
         unique_fnames, inv, counts = np.unique(
             fnames, return_inverse=True, return_counts=True
         )
         for i, (fname, count) in enumerate(zip(unique_fnames, counts)):
-            size = self.sizes[fname]
+            size = self.reader.sizes[fname]
             max_idx = size - self.kernel_size
 
             # figure out which batch indices should be
@@ -156,7 +239,7 @@ def sample_batch(self) -> WaveformTensor:
 
             # open the file and sample a different set of
             # kernels for each batch element it occupies
-            with h5py.File(fname, "r") as f:
+            with self.reader(fname) as f:
                 for b, c, i in zip(batch_indices, channel_indices, idx):
                     x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
         return torch.Tensor(x)
@@ -172,5 +255,6 @@ def __iter__(self) -> torch.Tensor:
             if worker_info.id < remainder:
                 num_batches += 1
 
-        for _ in range(num_batches):
-            yield self.sample_batch()
+        with self.reader:
+            for _ in range(num_batches):
+                yield self.sample_batch()