diff --git a/dpipe/__version__.py b/dpipe/__version__.py index 14e974f..0404d81 100644 --- a/dpipe/__version__.py +++ b/dpipe/__version__.py @@ -1 +1 @@ -__version__ = '0.2.8' +__version__ = '0.3.0' diff --git a/dpipe/batch_iter/pipeline.py b/dpipe/batch_iter/pipeline.py index 6f6966d..8cb2a83 100644 --- a/dpipe/batch_iter/pipeline.py +++ b/dpipe/batch_iter/pipeline.py @@ -176,7 +176,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - if self.pipeline is not None: + # need getattr here, because the destructor might get called before the field is initialized + if getattr(self, 'pipeline', None) is not None: self.pipeline, pipeline = None, self.pipeline return pipeline.__exit__(exc_type, exc_val, exc_tb) diff --git a/dpipe/predict/shape.py b/dpipe/predict/shape.py index 1f23d20..97186d1 100644 --- a/dpipe/predict/shape.py +++ b/dpipe/predict/shape.py @@ -81,7 +81,8 @@ def wrapper(x, *args, **kwargs): def patches_grid(patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None, padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5, - combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, **imops_kwargs): + combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, stream: bool = False, + **imops_kwargs): """ Divide an incoming array into patches of corresponding ``patch_size`` and ``stride`` and then combine the predicted patches by aggregating the overlapping regions using the ``combiner`` - Average by default. @@ -110,12 +111,15 @@ def wrapper(x, *args, **kwargs): elif ((shape - local_size) < 0).any() or ((local_stride - shape + local_size) % local_stride).any(): raise ValueError('Input cannot be patched without remainder.') + if stream: + patches = predict(divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), *args, **kwargs) + else: + patches = pmap( + predict, + divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), + *args, **kwargs + ) - patches = pmap( - predict, - divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), - *args, **kwargs - ) prediction = combine( patches, extract(x.shape, input_axis), local_stride, axis, combiner=combiner, get_boxes=get_boxes, diff --git a/tests/predict/test_shape.py b/tests/predict/test_shape.py index 69e65bb..a871a39 100644 --- a/tests/predict/test_shape.py +++ b/tests/predict/test_shape.py @@ -1,15 +1,28 @@ +import sys + import pytest import numpy as np from dpipe.im.utils import identity from dpipe.predict.shape import * +from dpipe.itertools import pmap assert_eq = np.testing.assert_array_almost_equal -def test_patches_grid(): +@pytest.fixture(params=[1, 2, 3, 4]) +def batch_size(request): + return request.param + + +@pytest.fixture(params=[False, True]) +def stream(request): + return request.param + + +def test_patches_grid(stream): def check_equal(**kwargs): - assert_eq(x, patches_grid(**kwargs, axis=-1)(identity)(x)) + assert_eq(x, patches_grid(**kwargs, stream=stream, axis=-1)(identity)(x)) x = np.random.randn(3, 23, 20, 27) * 10 check_equal(patch_size=10, stride=1, padding_values=0) @@ -27,9 +40,9 @@ def check_equal(**kwargs): check_equal(patch_size=15, stride=12, padding_values=None) -def test_divisible_patches(): +def test_divisible_patches(stream): def check_equal(**kwargs): - assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs)(identity))(x)) + assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs, stream=stream)(identity))(x)) size = [80] * 3 stride = [20] * 3 @@ -37,3 +50,19 @@ def check_equal(**kwargs): for shape in [(373, 302, 55), (330, 252, 67)]: x = np.random.randn(*shape) check_equal(patch_size=size, stride=stride) + + +@pytest.mark.skipif(sys.version_info < (3, 7), reason='Requires python3.7 or higher.') +def test_batched_patches_grid(batch_size): + from more_itertools import batched + from itertools import chain + + def patch_predict(patch): + return patch + 1 + + def stream_predict(patches_generator): + return chain.from_iterable(pmap(patch_predict, map(np.array, batched(patches_generator, batch_size)))) + + x = np.random.randn(3, 23, 20, 27) * 10 + + assert_eq(x + 1, patches_grid(patch_size=(6, 8, 9), stride=(4, 3, 2), stream=True, axis=(-1, -2, -3))(stream_predict)(x)) diff --git a/tests/requirements.txt b/tests/requirements.txt index d29db9b..3f3f0cc 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,3 +1,4 @@ pytest pytest-cov pytest-subtests +more_itertools; python_version >= '3.7'