From 54fc645699baf2f8e03d258af9be0a1bb4c27c73 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 16 May 2023 13:00:01 +0300 Subject: [PATCH 1/4] added arg to exp pool --- dpipe/batch_iter/expiration_pool.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dpipe/batch_iter/expiration_pool.py b/dpipe/batch_iter/expiration_pool.py index c7f715f..659ff61 100644 --- a/dpipe/batch_iter/expiration_pool.py +++ b/dpipe/batch_iter/expiration_pool.py @@ -22,15 +22,17 @@ class ExpirationPool(Iterator): ) """ - def __init__(self, pool_size: int, repetitions: int): - super().__init__(partial(expiration_pool, pool_size=pool_size, repetitions=repetitions), ) + def __init__(self, pool_size: int, repetitions: int, iterations: int = 1): + super().__init__(partial(expiration_pool, pool_size=pool_size, repetitions=repetitions, iterations=iterations)) -def expiration_pool(iterable: Iterable, pool_size: int, repetitions: int): +def expiration_pool(iterable: Iterable, pool_size: int, repetitions: int, iterations: int = 1): """ Caches ``pool_size`` items from ``iterable``. The item is removed from cache after it was generated ``repetitions`` times. After an item is removed, a new one is extracted from the ``iterable``. + Finally, ``iterations`` controls how many values are generated after a new value is added, + thus speeding up the pipeline at early stages. """ assert pool_size > 0 @@ -51,7 +53,10 @@ def sample_value(): value_frequency = {} # i -> [value, frequency] for idx, value in iterable: value_frequency[idx] = [value, 0] - yield sample_value() + + for _ in range(iterations): + if value_frequency: + yield sample_value() while len(value_frequency) >= pool_size: yield sample_value() From e3cf4565f02b6bbbb05234773e42b391ad831fd9 Mon Sep 17 00:00:00 2001 From: Philipenko Vladimir Date: Wed, 17 May 2023 14:00:48 +0300 Subject: [PATCH 2/4] Fix #75 --- dpipe/torch/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dpipe/torch/functional.py b/dpipe/torch/functional.py index 359ea6d..954edf3 100644 --- a/dpipe/torch/functional.py +++ b/dpipe/torch/functional.py @@ -231,7 +231,8 @@ def masked_loss(mask: torch.Tensor, criterion: Callable, prediction: torch.Tenso If the ``mask`` is empty - returns 0 wrapped in a torch tensor. """ if not mask.any(): - return torch.tensor(0., requires_grad=True).to(prediction) + # https://github.com/neuro-ml/deep_pipe/issues/75 + return (prediction * 0).sum() return criterion(prediction[mask], target[mask], **kwargs) From 4a41c58b60b92280cb00c478d595a60bcc6ccb54 Mon Sep 17 00:00:00 2001 From: Philipenko Vladimir Date: Wed, 17 May 2023 15:52:47 +0300 Subject: [PATCH 3/4] Use 1 element from prediction --- dpipe/torch/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpipe/torch/functional.py b/dpipe/torch/functional.py index 954edf3..8b05b07 100644 --- a/dpipe/torch/functional.py +++ b/dpipe/torch/functional.py @@ -232,7 +232,7 @@ def masked_loss(mask: torch.Tensor, criterion: Callable, prediction: torch.Tenso """ if not mask.any(): # https://github.com/neuro-ml/deep_pipe/issues/75 - return (prediction * 0).sum() + return 0 * prediction.flatten()[0] return criterion(prediction[mask], target[mask], **kwargs) From ae554894e53d73b44699533f1c1cff70b46bed42 Mon Sep 17 00:00:00 2001 From: Max Date: Thu, 29 Jun 2023 22:30:05 +0300 Subject: [PATCH 4/4] batch iter is an iterable now --- dpipe/__version__.py | 2 +- dpipe/batch_iter/pipeline.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/dpipe/__version__.py b/dpipe/__version__.py index d93b5b2..a6587ae 100644 --- a/dpipe/__version__.py +++ b/dpipe/__version__.py @@ -1 +1 @@ -__version__ = '0.2.3' +__version__ = '0.2.4' diff --git a/dpipe/batch_iter/pipeline.py b/dpipe/batch_iter/pipeline.py index 7ad0927..6c571fa 100644 --- a/dpipe/batch_iter/pipeline.py +++ b/dpipe/batch_iter/pipeline.py @@ -131,6 +131,9 @@ def close(self): """Stop all background processes.""" self.__exit__(None, None, None) + def __iter__(self): + return self() + def __call__(self): if not self.pipeline.pipeline_active: self.__enter__()