adding support for single hdf5 files with multiple datasets #77

Open
wants to merge 2 commits into base: dev
156 changes: 120 additions & 36 deletions ml4gw/dataloading/hdf5_dataset.py
@@ -1,5 +1,6 @@
import warnings
from typing import Sequence, Union
from contextlib import contextmanager
from typing import Optional, Sequence, Union

import h5py
import numpy as np
@@ -12,6 +13,114 @@ class ContiguousHdf5Warning(Warning):
pass


class _Reader:
def __new__(cls, fnames, path):
if isinstance(fnames, str):
cls = _SingleFileReader
else:
cls = _MultiFileReader
return super().__new__(cls)

def __init__(
self, fnames: Union[str, Sequence[str]], path: Optional[str] = None
):
self.fnames = fnames
if path is not None:
self.path = path.split("/")
else:
self.path = None
self.sizes = {}

def open(self, fname) -> tuple[h5py.File, h5py.Group]:
f = group = h5py.File(fname, "r")
if self.path is not None:
for path in self.path:
group = group[path]
return f, group

def _warn_non_contiguous(self, fname, dataset):
warnings.warn(
"File {} contains datasets at path {} that were generated "
"without using chunked storage. This can have "
"severe performance impacts at data loading time. "
"If you need faster loading, try re-generating "
"your datset with chunked storage turned on.".format(
fname, "/".join(self.path) + "/" + dataset
),
category=ContiguousHdf5Warning,
)

def get_sizes(self, channel):
raise NotImplementedError

def initialize_probs(self, channel):
self.get_sizes(channel)
total = sum(self.sizes.values())
self.probs = np.array([self.sizes[k] / total for k in self.fnames])

def sample_fnames(self, size):
return np.random.choice(
self.fnames,
p=self.probs,
size=size,
replace=True,
)

def __enter__(self):
return self

def __exit__(self, *exc_args):
return

@contextmanager
def __call__(self, fname):
raise NotImplementedError


class _MultiFileReader(_Reader):
def get_sizes(self, channel):
for fname in self.fnames:
with self(fname) as f:
dataset = f[channel]
if dataset.chunks is None:
self._warn_non_contiguous(fname, channel)
self.sizes[fname] = len(dataset)

@contextmanager
def __call__(self, fname):
f, group = self.open(fname)
with f:
yield group


class _SingleFileReader(_Reader):
_f = _group = None

def get_sizes(self, channel):
fname = self.fnames
self.fname = fname
with self:
for key, group in self._group.items():
dataset = group[channel]
if dataset.chunks is None:
path = f"{key}/{channel}"
self._warn_non_contiguous(fname, path)
self.sizes[key] = len(dataset)
self.fnames = sorted(self._group.keys())

def __enter__(self):
self._f, self._group = self.open(self.fname)
return self

def __exit__(self, *exc_args):
self._f.close()
self._f = self._group = None

@contextmanager
def __call__(self, dataset):
yield self._group[dataset]


class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
"""
Iterable dataset that samples and loads windows of
@@ -54,58 +163,32 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):

def __init__(
self,
fnames: Sequence[str],
fnames: Union[Sequence[str], str],
channels: Sequence[str],
kernel_size: int,
batch_size: int,
batches_per_epoch: int,
coincident: Union[bool, str],
path: Optional[str] = None,
) -> None:
if not isinstance(coincident, bool) and coincident != "files":
raise ValueError(
"coincident must be either a boolean or 'files', "
"got unrecognized value {}".format(coincident)
)

self.fnames = fnames
self.reader = _Reader(fnames, path)
self.reader.initialize_probs(channels[0])
self.channels = channels
self.num_channels = len(channels)
self.kernel_size = kernel_size
self.batch_size = batch_size
self.batches_per_epoch = batches_per_epoch
self.coincident = coincident

self.sizes = {}
for fname in self.fnames:
with h5py.File(fname, "r") as f:
dset = f[channels[0]]
if dset.chunks is None:
warnings.warn(
"File {} contains datasets that were generated "
"without using chunked storage. This can have "
"severe performance impacts at data loading time. "
"If you need faster loading, try re-generating "
"your datset with chunked storage turned on.".format(
fname
),
category=ContiguousHdf5Warning,
)

self.sizes[fname] = len(dset)
total = sum(self.sizes.values())
self.probs = np.array([i / total for i in self.sizes.values()])

def __len__(self) -> int:
return self.batches_per_epoch

def sample_fnames(self, size) -> np.ndarray:
return np.random.choice(
self.fnames,
p=self.probs,
size=size,
replace=True,
)

def sample_batch(self) -> WaveformTensor:
"""
Sample a single batch of multichannel timeseries
@@ -120,13 +203,13 @@ def sample_batch(self) -> WaveformTensor:
size = (self.batch_size,)
else:
size = (self.batch_size, self.num_channels)
fnames = self.sample_fnames(size)
fnames = self.reader.sample_fnames(size)

unique_fnames, inv, counts = np.unique(
fnames, return_inverse=True, return_counts=True
)
for i, (fname, count) in enumerate(zip(unique_fnames, counts)):
size = self.sizes[fname]
size = self.reader.sizes[fname]
max_idx = size - self.kernel_size

# figure out which batch indices should be
@@ -156,7 +239,7 @@ def sample_batch(self) -> WaveformTensor:

# open the file and sample a different set of
# kernels for each batch element it occupies
with h5py.File(fname, "r") as f:
with self.reader(fname) as f:
for b, c, i in zip(batch_indices, channel_indices, idx):
x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
return torch.Tensor(x)
@@ -172,5 +255,6 @@ def __iter__(self) -> torch.Tensor:
if worker_info.id < remainder:
num_batches += 1

for _ in range(num_batches):
yield self.sample_batch()
with self.reader:
for _ in range(num_batches):
yield self.sample_batch()
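
Below is a minimal usage sketch of the single-file mode this PR adds, assuming an HDF5 layout in which each group holds one chunked dataset per channel. The file name, group names, sampling parameters, and the `ml4gw.dataloading` import path used here are illustrative assumptions, not part of the diff.

```python
import h5py
import numpy as np

from ml4gw.dataloading import Hdf5TimeSeriesDataset  # assumed import path

# Build a toy single-file layout: one top-level group per segment,
# each holding one chunked dataset per channel (names are illustrative).
channels = ["H1", "L1"]
with h5py.File("data.hdf5", "w") as f:
    for segment in ["segment0", "segment1"]:
        g = f.create_group(segment)
        for channel in channels:
            g.create_dataset(
                channel, data=np.random.randn(16384), chunks=True
            )

# Passing a single filename (rather than a list) should dispatch to the
# single-file reader, which samples across the groups in the file the way
# the multi-file reader samples across files.
dataset = Hdf5TimeSeriesDataset(
    fnames="data.hdf5",
    channels=channels,
    kernel_size=2048,
    batch_size=32,
    batches_per_epoch=10,
    coincident=True,
)

for X in dataset:
    print(X.shape)  # expected: torch.Size([32, 2, 2048])
```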