diff --git a/dpipe/__version__.py b/dpipe/__version__.py
index d93b5b2..a6587ae 100644
--- a/dpipe/__version__.py
+++ b/dpipe/__version__.py
@@ -1 +1 @@
-__version__ = '0.2.3'
+__version__ = '0.2.4'
diff --git a/dpipe/batch_iter/expiration_pool.py b/dpipe/batch_iter/expiration_pool.py
index c7f715f..659ff61 100644
--- a/dpipe/batch_iter/expiration_pool.py
+++ b/dpipe/batch_iter/expiration_pool.py
@@ -22,15 +22,17 @@ class ExpirationPool(Iterator):
     )
     """
 
-    def __init__(self, pool_size: int, repetitions: int):
-        super().__init__(partial(expiration_pool, pool_size=pool_size, repetitions=repetitions), )
+    def __init__(self, pool_size: int, repetitions: int, iterations: int = 1):
+        super().__init__(partial(expiration_pool, pool_size=pool_size, repetitions=repetitions, iterations=iterations))
 
 
-def expiration_pool(iterable: Iterable, pool_size: int, repetitions: int):
+def expiration_pool(iterable: Iterable, pool_size: int, repetitions: int, iterations: int = 1):
     """
     Caches ``pool_size`` items from ``iterable``.
     The item is removed from cache after it was generated ``repetitions`` times.
     After an item is removed, a new one is extracted from the ``iterable``.
+    Finally, ``iterations`` controls how many values are generated after a new value is added,
+    thus speeding up the pipeline at early stages.
     """
 
     assert pool_size > 0
@@ -51,7 +53,10 @@ def sample_value():
     value_frequency = {}  # i -> [value, frequency]
     for idx, value in iterable:
         value_frequency[idx] = [value, 0]
-        yield sample_value()
+
+        for _ in range(iterations):
+            if value_frequency:
+                yield sample_value()
 
         while len(value_frequency) >= pool_size:
             yield sample_value()
diff --git a/dpipe/batch_iter/pipeline.py b/dpipe/batch_iter/pipeline.py
index 7ad0927..6c571fa 100644
--- a/dpipe/batch_iter/pipeline.py
+++ b/dpipe/batch_iter/pipeline.py
@@ -131,6 +131,9 @@ def close(self):
         """Stop all background processes."""
         self.__exit__(None, None, None)
 
+    def __iter__(self):
+        return self()
+
     def __call__(self):
         if not self.pipeline.pipeline_active:
             self.__enter__()
diff --git a/dpipe/torch/functional.py b/dpipe/torch/functional.py
index 359ea6d..8b05b07 100644
--- a/dpipe/torch/functional.py
+++ b/dpipe/torch/functional.py
@@ -231,7 +231,8 @@ def masked_loss(mask: torch.Tensor, criterion: Callable, prediction: torch.Tenso
     If the ``mask`` is empty - returns 0 wrapped in a torch tensor.
     """
     if not mask.any():
-        return torch.tensor(0., requires_grad=True).to(prediction)
+        # https://github.com/neuro-ml/deep_pipe/issues/75
+        return 0 * prediction.flatten()[0]
 
     return criterion(prediction[mask], target[mask], **kwargs)
 