diff --git a/dpipe/__version__.py b/dpipe/__version__.py index 73e3bb4..80eb7f9 100644 --- a/dpipe/__version__.py +++ b/dpipe/__version__.py @@ -1 +1 @@ -__version__ = '0.3.2' +__version__ = '0.3.3' diff --git a/dpipe/predict/shape.py b/dpipe/predict/shape.py index 2120ad0..4a4c245 100644 --- a/dpipe/predict/shape.py +++ b/dpipe/predict/shape.py @@ -123,9 +123,13 @@ def wrapper(x, *args, **kwargs): else: patches = pmap(predict, input_patches, *args, **kwargs) + prediction_kwargs = {'combiner': combiner, 'get_boxes': get_boxes} + if not use_torch: + prediction_kwargs['use_torch'] = use_torch + prediction = combine( patches_wrapper(patches), extract(x.shape, input_axis), local_stride, axis, - combiner=combiner, get_boxes=get_boxes, use_torch=use_torch + **prediction_kwargs ) if valid: