Skip to content

Commit

Permalink
Merge pull request #82 from neuro-ml/develop
Browse files Browse the repository at this point in the history
Add batched predict in `patches_grid`
  • Loading branch information
vovaf709 authored Nov 2, 2023
2 parents 804d2bc + 186a0cd commit 0448e2d
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 12 deletions.
2 changes: 1 addition & 1 deletion dpipe/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.8'
__version__ = '0.3.0'
3 changes: 2 additions & 1 deletion dpipe/batch_iter/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.pipeline is not None:
# need getattr here, because the destructor might get called before the field is initialized
if getattr(self, 'pipeline', None) is not None:
self.pipeline, pipeline = None, self.pipeline
return pipeline.__exit__(exc_type, exc_val, exc_tb)

Expand Down
16 changes: 10 additions & 6 deletions dpipe/predict/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def wrapper(x, *args, **kwargs):

def patches_grid(patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5,
combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, **imops_kwargs):
combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, stream: bool = False,
**imops_kwargs):
"""
Divide an incoming array into patches of corresponding ``patch_size`` and ``stride`` and then combine
the predicted patches by aggregating the overlapping regions using the ``combiner`` - Average by default.
Expand Down Expand Up @@ -110,12 +111,15 @@ def wrapper(x, *args, **kwargs):
elif ((shape - local_size) < 0).any() or ((local_stride - shape + local_size) % local_stride).any():
raise ValueError('Input cannot be patched without remainder.')

if stream:
patches = predict(divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), *args, **kwargs)
else:
patches = pmap(
predict,
divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes),
*args, **kwargs
)

patches = pmap(
predict,
divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes),
*args, **kwargs
)
prediction = combine(
patches, extract(x.shape, input_axis), local_stride, axis,
combiner=combiner, get_boxes=get_boxes,
Expand Down
37 changes: 33 additions & 4 deletions tests/predict/test_shape.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
import sys

import pytest
import numpy as np

from dpipe.im.utils import identity
from dpipe.predict.shape import *
from dpipe.itertools import pmap

assert_eq = np.testing.assert_array_almost_equal


def test_patches_grid():
@pytest.fixture(params=[1, 2, 3, 4])
def batch_size(request):
return request.param


@pytest.fixture(params=[False, True])
def stream(request):
return request.param


def test_patches_grid(stream):
def check_equal(**kwargs):
assert_eq(x, patches_grid(**kwargs, axis=-1)(identity)(x))
assert_eq(x, patches_grid(**kwargs, stream=stream, axis=-1)(identity)(x))

x = np.random.randn(3, 23, 20, 27) * 10
check_equal(patch_size=10, stride=1, padding_values=0)
Expand All @@ -27,13 +40,29 @@ def check_equal(**kwargs):
check_equal(patch_size=15, stride=12, padding_values=None)


def test_divisible_patches():
def test_divisible_patches(stream):
def check_equal(**kwargs):
assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs)(identity))(x))
assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs, stream=stream)(identity))(x))

size = [80] * 3
stride = [20] * 3
divisible = [8] * 3
for shape in [(373, 302, 55), (330, 252, 67)]:
x = np.random.randn(*shape)
check_equal(patch_size=size, stride=stride)


@pytest.mark.skipif(sys.version_info < (3, 7), reason='Requires python3.7 or higher.')
def test_batched_patches_grid(batch_size):
from more_itertools import batched
from itertools import chain

def patch_predict(patch):
return patch + 1

def stream_predict(patches_generator):
return chain.from_iterable(pmap(patch_predict, map(np.array, batched(patches_generator, batch_size))))

x = np.random.randn(3, 23, 20, 27) * 10

assert_eq(x + 1, patches_grid(patch_size=(6, 8, 9), stride=(4, 3, 2), stream=True, axis=(-1, -2, -3))(stream_predict)(x))
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest
pytest-cov
pytest-subtests
more_itertools; python_version >= '3.7'

0 comments on commit 0448e2d

Please sign in to comment.