
Commit 0ef97f4

Merge pull request #76 from neuro-ml/develop
Develop
maxme1 authored Jul 10, 2023
2 parents 8244db4 + ae55489 commit 0ef97f4
Showing 4 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dpipe/__version__.py
@@ -1 +1 @@
-__version__ = '0.2.3'
+__version__ = '0.2.4'
13 changes: 9 additions & 4 deletions dpipe/batch_iter/expiration_pool.py
@@ -22,15 +22,17 @@ class ExpirationPool(Iterator):
        )
    """

-    def __init__(self, pool_size: int, repetitions: int):
-        super().__init__(partial(expiration_pool, pool_size=pool_size, repetitions=repetitions), )
+    def __init__(self, pool_size: int, repetitions: int, iterations: int = 1):
+        super().__init__(partial(expiration_pool, pool_size=pool_size, repetitions=repetitions, iterations=iterations))


-def expiration_pool(iterable: Iterable, pool_size: int, repetitions: int):
+def expiration_pool(iterable: Iterable, pool_size: int, repetitions: int, iterations: int = 1):
    """
    Caches ``pool_size`` items from ``iterable``.
    The item is removed from cache after it was generated ``repetitions`` times.
    After an item is removed, a new one is extracted from the ``iterable``.
+    Finally, ``iterations`` controls how many values are generated after a new value is added,
+    thus speeding up the pipeline at early stages.
    """

    assert pool_size > 0
@@ -51,7 +53,10 @@ def sample_value():
    value_frequency = {}  # i -> [value, frequency]
    for idx, value in iterable:
        value_frequency[idx] = [value, 0]
-        yield sample_value()
+
+        for _ in range(iterations):
+            if value_frequency:
+                yield sample_value()

        while len(value_frequency) >= pool_size:
            yield sample_value()
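To make the new ``iterations`` parameter concrete, here is a small self-contained sketch of the expiration-pool idea. It is a simplified re-implementation for illustration only, not the dpipe code itself; ``toy_expiration_pool`` and its internals are made up for this example.

import random

def toy_expiration_pool(values, pool_size, repetitions, iterations=1):
    # pool maps index -> [value, times_yielded]
    pool = {}

    def sample():
        idx = random.choice(list(pool))
        value = pool[idx][0]
        pool[idx][1] += 1
        if pool[idx][1] >= repetitions:
            del pool[idx]          # the item "expires" after `repetitions` draws
        return value

    for idx, value in enumerate(values):
        pool[idx] = [value, 0]

        # `iterations` controls how many values are emitted right after a new
        # item enters the pool, so the pipeline produces samples earlier
        for _ in range(iterations):
            if pool:
                yield sample()

        # once the pool is full, keep sampling until something expires
        while len(pool) >= pool_size:
            yield sample()

    # drain whatever is left in the pool
    while pool:
        yield sample()

print(list(toy_expiration_pool('abcd', pool_size=2, repetitions=2, iterations=3)))

With ``iterations=1`` the behaviour matches the previous code path: exactly one value is emitted per newly added item while the pool is filling up; larger values let the pipeline yield more samples at the early stages.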
3 changes: 3 additions & 0 deletions dpipe/batch_iter/pipeline.py
@@ -131,6 +131,9 @@ def close(self):
"""Stop all background processes."""
self.__exit__(None, None, None)

def __iter__(self):
return self()

def __call__(self):
if not self.pipeline.pipeline_active:
self.__enter__()
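The added ``__iter__`` simply delegates to ``__call__``, so a pipeline instance can now be iterated over directly instead of having to be called first. A minimal sketch of the pattern; ``TinyPipeline`` is a made-up stand-in, not the dpipe class:

class TinyPipeline:
    """Minimal stand-in showing the __iter__/__call__ delegation pattern."""

    def __call__(self):
        yield from range(3)      # pretend these are batches

    def __iter__(self):
        return self()            # same delegation as the added method

batch_iter = TinyPipeline()

print(list(batch_iter()))   # calling still works: [0, 1, 2]
print(list(batch_iter))     # plain iteration now works too: [0, 1, 2]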
3 changes: 2 additions & 1 deletion dpipe/torch/functional.py
@@ -231,7 +231,8 @@ def masked_loss(mask: torch.Tensor, criterion: Callable, prediction: torch.Tensor
    If the ``mask`` is empty - returns 0 wrapped in a torch tensor.
    """
    if not mask.any():
-        return torch.tensor(0., requires_grad=True).to(prediction)
+        # https://github.com/neuro-ml/deep_pipe/issues/75
+        return 0 * prediction.flatten()[0]

    return criterion(prediction[mask], target[mask], **kwargs)

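The new return value keeps the empty-mask zero attached to ``prediction``'s autograd graph (and on its dtype/device), so ``backward()`` still reaches the upstream parameters, which is presumably the motivation behind issue #75. A small sketch in plain PyTorch showing the difference:

import torch

prediction = torch.randn(2, 3, requires_grad=True)

# old behaviour: a zero whose graph never reaches `prediction` or the model parameters
old_zero = torch.tensor(0., requires_grad=True).to(prediction)
old_zero.backward()
print(prediction.grad)   # None - nothing propagated back

# new behaviour: a zero that is still a function of `prediction`
new_zero = 0 * prediction.flatten()[0]
new_zero.backward()
print(prediction.grad)   # a tensor of zeros - gradients flow, just with value 0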
