From 40f666b80efe3bc752495a5483a67a246bbaf8cd Mon Sep 17 00:00:00 2001
From: Philipenko Vladimir
Date: Mon, 30 Oct 2023 18:22:26 +0300
Subject: [PATCH 1/6] Add batched predict in `patches_grid`

---
 dpipe/__version__.py        |  2 +-
 dpipe/predict/shape.py      | 23 +++++++++++++++--------
 tests/predict/test_shape.py | 13 +++++++++----
 3 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/dpipe/__version__.py b/dpipe/__version__.py
index 14e974f..cd9b137 100644
--- a/dpipe/__version__.py
+++ b/dpipe/__version__.py
@@ -1 +1 @@
-__version__ = '0.2.8'
+__version__ = '0.2.9'
diff --git a/dpipe/predict/shape.py b/dpipe/predict/shape.py
index 1f23d20..18ed369 100644
--- a/dpipe/predict/shape.py
+++ b/dpipe/predict/shape.py
@@ -1,5 +1,7 @@
 from functools import wraps
 from typing import Union, Callable, Type
+from more_itertools import batched
+from itertools import chain

 import numpy as np

@@ -81,7 +83,7 @@ def wrapper(x, *args, **kwargs):

 def patches_grid(patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
                  padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5,
-                 combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, **imops_kwargs):
+                 combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, batch_size=1, **imops_kwargs):
     """
     Divide an incoming array into patches of corresponding ``patch_size`` and ``stride`` and then combine
     the predicted patches by aggregating the overlapping regions using the ``combiner`` - Average by default.
@@ -109,13 +111,18 @@ def wrapper(x, *args, **kwargs):
                 x = pad_to_shape(x, new_shape, input_axis, padding_values, ratio, **imops_kwargs)
             elif ((shape - local_size) < 0).any() or ((local_stride - shape + local_size) % local_stride).any():
                 raise ValueError('Input cannot be patched without remainder.')
-
-
-            patches = pmap(
-                predict,
-                divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes),
-                *args, **kwargs
-            )
+            if batch_size == 1:
+                patches = pmap(
+                    predict,
+                    divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes),
+                    *args, **kwargs
+                )
+            else:
+                patches = chain.from_iterable(pmap(
+                    predict,
+                    map(np.array, batched(divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), batch_size)),
+                    *args, **kwargs
+                ))

             prediction = combine(
                 patches, extract(x.shape, input_axis), local_stride, axis, combiner=combiner, get_boxes=get_boxes,
diff --git a/tests/predict/test_shape.py b/tests/predict/test_shape.py
index 69e65bb..96462dd 100644
--- a/tests/predict/test_shape.py
+++ b/tests/predict/test_shape.py
@@ -7,9 +7,14 @@
 assert_eq = np.testing.assert_array_almost_equal


-def test_patches_grid():
+@pytest.fixture(params=[1, 2, 3, 4])
+def batch_size(request):
+    return request.param
+
+
+def test_patches_grid(batch_size):
     def check_equal(**kwargs):
-        assert_eq(x, patches_grid(**kwargs, axis=-1)(identity)(x))
+        assert_eq(x, patches_grid(**kwargs, axis=-1, batch_size=batch_size)(identity)(x))

     x = np.random.randn(3, 23, 20, 27) * 10
     check_equal(patch_size=10, stride=1, padding_values=0)
@@ -27,9 +32,9 @@ def check_equal(**kwargs):
     check_equal(patch_size=15, stride=12, padding_values=None)


-def test_divisible_patches():
+def test_divisible_patches(batch_size):
     def check_equal(**kwargs):
-        assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs)(identity))(x))
+        assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs, batch_size=batch_size)(identity))(x))

     size = [80] * 3
     stride = [20] * 3

From 90a681876d31f88a1df15a69a2ecf7011634db44 Mon Sep 17 00:00:00 2001
From: Philipenko Vladimir
Date: Wed, 1 Nov 2023 16:05:22 +0300
Subject: [PATCH 2/6] Fix

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 5f26c4e..5538af2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ wandb; python_version >= '3.7'
 nibabel
 torch
 imops>=0.8.4,<1.0.0
+more_itertools

From 613afdf05790e19feb05ace68eeb682399f37657 Mon Sep 17 00:00:00 2001
From: Philipenko Vladimir
Date: Wed, 1 Nov 2023 18:27:11 +0300
Subject: [PATCH 3/6] Add `stream` arg to `patches_grid`

---
 dpipe/predict/shape.py      | 15 +++++++--------
 requirements.txt            |  1 -
 tests/predict/test_shape.py | 32 ++++++++++++++++++++++++++++----
 tests/requirements.txt      |  1 +
 4 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/dpipe/predict/shape.py b/dpipe/predict/shape.py
index 18ed369..e074941 100644
--- a/dpipe/predict/shape.py
+++ b/dpipe/predict/shape.py
@@ -83,7 +83,8 @@ def wrapper(x, *args, **kwargs):

 def patches_grid(patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
                  padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5,
-                 combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, batch_size=1, **imops_kwargs):
+                 combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, stream: bool = False,
+                 **imops_kwargs):
     """
     Divide an incoming array into patches of corresponding ``patch_size`` and ``stride`` and then combine
     the predicted patches by aggregating the overlapping regions using the ``combiner`` - Average by default.
@@ -112,18 +112,16 @@ def wrapper(x, *args, **kwargs):
                 x = pad_to_shape(x, new_shape, input_axis, padding_values, ratio, **imops_kwargs)
             elif ((shape - local_size) < 0).any() or ((local_stride - shape + local_size) % local_stride).any():
                 raise ValueError('Input cannot be patched without remainder.')
-            if batch_size == 1:
+
+            if stream:
+                patches = predict(divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), *args, **kwargs)
+            else:
                 patches = pmap(
                     predict,
                     divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes),
                     *args, **kwargs
                 )
-            else:
-                patches = chain.from_iterable(pmap(
-                    predict,
-                    map(np.array, batched(divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), batch_size)),
-                    *args, **kwargs
-                ))
+
             prediction = combine(
                 patches, extract(x.shape, input_axis), local_stride, axis, combiner=combiner, get_boxes=get_boxes,
diff --git a/requirements.txt b/requirements.txt
index 5538af2..5f26c4e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,3 @@ wandb; python_version >= '3.7'
 nibabel
 torch
 imops>=0.8.4,<1.0.0
-more_itertools
diff --git a/tests/predict/test_shape.py b/tests/predict/test_shape.py
index 96462dd..a871a39 100644
--- a/tests/predict/test_shape.py
+++ b/tests/predict/test_shape.py
@@ -1,8 +1,11 @@
+import sys
+
 import pytest
 import numpy as np

 from dpipe.im.utils import identity
 from dpipe.predict.shape import *
+from dpipe.itertools import pmap

 assert_eq = np.testing.assert_array_almost_equal

@@ -12,9 +15,14 @@ def batch_size(request):
     return request.param


-def test_patches_grid(batch_size):
+@pytest.fixture(params=[False, True])
+def stream(request):
+    return request.param
+
+
+def test_patches_grid(stream):
     def check_equal(**kwargs):
-        assert_eq(x, patches_grid(**kwargs, axis=-1, batch_size=batch_size)(identity)(x))
+        assert_eq(x, patches_grid(**kwargs, stream=stream, axis=-1)(identity)(x))

     x = np.random.randn(3, 23, 20, 27) * 10
     check_equal(patch_size=10, stride=1, padding_values=0)
@@ -32,9 +40,9 @@ def check_equal(**kwargs):
     check_equal(patch_size=15, stride=12, padding_values=None)


-def test_divisible_patches(batch_size):
+def test_divisible_patches(stream):
     def check_equal(**kwargs):
-        assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs, batch_size=batch_size)(identity))(x))
+        assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs, stream=stream)(identity))(x))

     size = [80] * 3
     stride = [20] * 3
@@ -42,3 +50,19 @@ def check_equal(**kwargs):
     for shape in [(373, 302, 55), (330, 252, 67)]:
         x = np.random.randn(*shape)
         check_equal(patch_size=size, stride=stride)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason='Requires python3.7 or higher.')
+def test_batched_patches_grid(batch_size):
+    from more_itertools import batched
+    from itertools import chain
+
+    def patch_predict(patch):
+        return patch + 1
+
+    def stream_predict(patches_generator):
+        return chain.from_iterable(pmap(patch_predict, map(np.array, batched(patches_generator, batch_size))))
+
+    x = np.random.randn(3, 23, 20, 27) * 10
+
+    assert_eq(x + 1, patches_grid(patch_size=(6, 8, 9), stride=(4, 3, 2), stream=True, axis=(-1, -2, -3))(stream_predict)(x))
diff --git a/tests/requirements.txt b/tests/requirements.txt
index d29db9b..3f3f0cc 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -1,3 +1,4 @@
 pytest
 pytest-cov
 pytest-subtests
+more_itertools; python_version >= '3.7'

From f4e1d2c023ae582d3aea1a6dbc8c93262a86d781 Mon Sep 17 00:00:00 2001
From: Philipenko Vladimir
Date: Wed, 1 Nov 2023 18:28:16 +0300
Subject: [PATCH 4/6] Remove imports

---
 dpipe/predict/shape.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/dpipe/predict/shape.py b/dpipe/predict/shape.py
index e074941..97186d1 100644
--- a/dpipe/predict/shape.py
+++ b/dpipe/predict/shape.py
@@ -1,7 +1,5 @@
 from functools import wraps
 from typing import Union, Callable, Type
-from more_itertools import batched
-from itertools import chain

 import numpy as np


From 6e03c9d4852e80e8eddcfc7e2dedf121d2b263ae Mon Sep 17 00:00:00 2001
From: Philipenko Vladimir
Date: Wed, 1 Nov 2023 19:00:41 +0300
Subject: [PATCH 5/6] Minor version

---
 dpipe/__version__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dpipe/__version__.py b/dpipe/__version__.py
index cd9b137..0404d81 100644
--- a/dpipe/__version__.py
+++ b/dpipe/__version__.py
@@ -1 +1 @@
-__version__ = '0.2.9'
+__version__ = '0.3.0'

From 186a0cd58e5e4866b5b565e86204fcd1c6c160cc Mon Sep 17 00:00:00 2001
From: Max
Date: Wed, 1 Nov 2023 18:13:50 +0000
Subject: [PATCH 6/6] fixed the destructor in Infinite

---
 dpipe/batch_iter/pipeline.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/dpipe/batch_iter/pipeline.py b/dpipe/batch_iter/pipeline.py
index 6f6966d..8cb2a83 100644
--- a/dpipe/batch_iter/pipeline.py
+++ b/dpipe/batch_iter/pipeline.py
@@ -176,7 +176,8 @@ def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        if self.pipeline is not None:
+        # need getattr here, because the destructor might get called before the field is initialized
+        if getattr(self, 'pipeline', None) is not None:
             self.pipeline, pipeline = None, self.pipeline
             return pipeline.__exit__(exc_type, exc_val, exc_tb)
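
A minimal usage sketch of the new `stream` mode, mirroring `test_batched_patches_grid` above; it is not part of the patch series. Only `patches_grid`, `dpipe.itertools.pmap` and `more_itertools.batched` come from the changes shown here; the model stand-in `predict_batch` and the batch size of 4 are made-up placeholders.

    import numpy as np
    from itertools import chain
    from more_itertools import batched

    from dpipe.itertools import pmap
    from dpipe.predict.shape import patches_grid


    def predict_batch(batch):
        # placeholder for a real model call, e.g. a network forward pass on a stacked batch
        return batch + 1


    # with stream=True the decorated function receives the whole generator of patches
    # and can group them into batches itself before predicting
    @patches_grid(patch_size=(6, 8, 9), stride=(4, 3, 2), axis=(-1, -2, -3), stream=True)
    def predict_stream(patches):
        # stack every 4 consecutive patches into an array, predict, then flatten back into a stream of patches
        return chain.from_iterable(pmap(predict_batch, map(np.array, batched(patches, 4))))


    x = np.random.randn(3, 23, 20, 27)
    y = predict_stream(x)  # same shape as x, every value increased by 1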