Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
maxme1 committed Oct 18, 2021
2 parents 68870f1 + bdd8e67 commit fa55963
Show file tree
Hide file tree
Showing 25 changed files with 989 additions and 64 deletions.
2 changes: 1 addition & 1 deletion dpipe/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.0'
__version__ = '0.1.1'
49 changes: 37 additions & 12 deletions dpipe/im/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
Function for working with patches from tensors.
See the :doc:`tutorials/patches` tutorial for more details.
"""
from typing import Iterable
from typing import Iterable, Type, Tuple

import numpy as np

from .shape_ops import crop_to_box
from .axes import fill_by_indices, AxesLike, resolve_deprecation, axis_from_dim, broadcast_to_axis
from .box import make_box_, Box
from dpipe.itertools import zip_equal, peek, negate_indices, extract
from dpipe.itertools import zip_equal, peek
from .shape_utils import shape_after_convolution
from .utils import build_slices

__all__ = 'get_boxes', 'divide', 'combine'
__all__ = 'get_boxes', 'divide', 'combine', 'PatchCombiner', 'Average'


def get_boxes(shape: AxesLike, box_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
Expand Down Expand Up @@ -72,11 +72,40 @@ def divide(x: np.ndarray, patch_size: AxesLike, stride: AxesLike, axis: AxesLike
yield crop_to_box(x, box)


class PatchCombiner:
def __init__(self, shape: Tuple[int, ...], dtype: np.dtype):
self.dtype = dtype
self.shape = shape

def update(self, box: Box, patch: np.ndarray):
raise NotImplementedError

def build(self) -> np.ndarray:
raise NotImplementedError


class Average(PatchCombiner):
def __init__(self, shape: Tuple[int, ...], dtype: np.dtype):
super().__init__(shape, dtype)
self._result = np.zeros(shape, dtype)
self._counts = np.zeros(shape, int)

def update(self, box: Box, patch: np.ndarray):
slc = build_slices(*box)
self._result[slc] += patch
self._counts[slc] += 1

def build(self):
np.true_divide(self._result, self._counts, out=self._result, where=self._counts > 0)
return self._result


def combine(patches: Iterable[np.ndarray], output_shape: AxesLike, stride: AxesLike,
axis: AxesLike = None, valid: bool = False) -> np.ndarray:
axis: AxesLike = None, valid: bool = False, combiner: Type[PatchCombiner] = Average) -> np.ndarray:
"""
Build a tensor of shape ``output_shape`` from ``patches`` obtained in a convolution-like approach
with corresponding parameters. The overlapping parts are averaged.
with corresponding parameters.
The overlapping parts are aggregated using the strategy from ``combiner`` - Average by default.
References
----------
Expand All @@ -98,12 +127,8 @@ def combine(patches: Iterable[np.ndarray], output_shape: AxesLike, stride: AxesL
if not np.issubdtype(dtype, np.floating):
dtype = float

result = np.zeros(output_shape, dtype)
counts = np.zeros(output_shape, int)
combiner = combiner(output_shape, dtype)
for box, patch in zip_equal(get_boxes(output_shape, patch_size, stride, valid=valid), patches):
slc = build_slices(*box)
result[slc] += patch
counts[slc] += 1
combiner.update(box, patch)

np.true_divide(result, counts, out=result, where=counts > 0)
return result
return combiner.build()
15 changes: 9 additions & 6 deletions dpipe/im/patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tools for patch extraction and generation.
"""
from functools import partial
from typing import Callable

import numpy as np
Expand All @@ -13,13 +14,15 @@
from dpipe.itertools import squeeze_first, extract, lmap


def sample_box_center_uniformly(shape, box_size: np.array):
"""Returns the center of a sampled uniformly box of size ``box_size``, contained in the array of shape ``shape``."""
return get_random_box(shape, box_size)[0] + box_size // 2
def uniform(shape, random_state: np.random.RandomState = None):
if not isinstance(random_state, np.random.RandomState):
random_state = np.random.RandomState(seed=random_state)
return np.array(lmap(random_state.randint, np.atleast_1d(shape)))


def uniform(shape):
return np.array(lmap(np.random.randint, np.atleast_1d(shape)))
def sample_box_center_uniformly(shape, box_size: np.array, random_state: np.random.RandomState = None):
"""Returns the center of a sampled uniformly box of size ``box_size``, contained in the array of shape ``shape``."""
return get_random_box(shape, box_size, distribution=partial(uniform, random_state=random_state))[0] + box_size // 2


def get_random_patch(*arrays: np.ndarray, patch_size: AxesLike, axis: AxesLike = None,
Expand Down Expand Up @@ -55,10 +58,10 @@ def get_random_patch(*arrays: np.ndarray, patch_size: AxesLike, axis: AxesLike =
return squeeze_first(tuple(crop_to_box(arr, box, axis) for arr in arrays))


# TODO: what to do if axis != None?
@returns_box
def get_random_box(shape: AxesLike, box_shape: AxesLike, axis: AxesLike = None, distribution: Callable = uniform):
"""Get a random box of shape ``box_shape`` that fits in the ``shape`` along the given ``axes``."""
axis = resolve_deprecation(axis, len(shape), box_shape)
start = distribution(shape_after_full_convolution(shape, box_shape, axis))
return start, start + fill_by_indices(shape, box_shape, axis)
# TODO: what to do if axis != None?
16 changes: 14 additions & 2 deletions dpipe/im/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,27 @@ def append_dims(array, ndim=1):
return np.asarray(array)[idx]


def insert_dims(array, index=0, ndim=1):
"""Increase the dimensionality of ``array`` by adding ``ndim`` singleton dimensions before the specified ``index` of its shape."""
array = np.asarray(array)
idx = [(slice(None, None, 1)) for _ in range(array.ndim)]
idx = tuple(idx[:index] + [None]*ndim + idx[index:])
return array[idx]


def shape_after_convolution(shape: AxesLike, kernel_size: AxesLike, stride: AxesLike = 1, padding: AxesLike = 0,
dilation: AxesLike = 1, valid: bool = True) -> tuple:
"""Get the shape of a tensor after applying a convolution with corresponding parameters."""
padding, shape, dilation, kernel_size = map(np.asarray, [padding, shape, dilation, kernel_size])

result = (shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
to_int = np.floor if valid else np.ceil
if valid:
result = np.floor(result).astype(int)
else:
# values <= 0 just mean that the kernel is greater than the input shape
# which is fine for valid=False
result = np.maximum(np.ceil(result).astype(int), 1)

result = to_int(result).astype(int)
new_shape = tuple(result)
if (result <= 0).any():
raise ValueError(f'Such a convolution is not possible. Output shape: {new_shape}.')
Expand Down
2 changes: 2 additions & 0 deletions dpipe/im/tests/test_shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def test_shape_after_convolution(subtests):
else:
assert new_shape == tuple(conv(tensor).shape[2:])

assert shape_after_convolution((10, 20), 11, 1, valid=False) == (1, 10)


def test_shape_after_full_convolution(subtests):
def subtest(shape, real_shape, kernel_size, axes=None):
Expand Down
9 changes: 5 additions & 4 deletions dpipe/im/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ def _get_rows_cols(max_cols, data):


def _slice_base(data: [np.ndarray], axis: int = -1, scale: int = 5, max_columns: int = None, colorbar: bool = False,
show_axes: bool = False, cmap: Union[Colormap, str] = 'gray', vlim: AxesParams = None,
show_axes: bool = False, cmap: Union[Colormap, str, Sequence[Colormap], Sequence[str]] = 'gray', vlim: AxesParams = None,
callback: Callable = None, sliders: dict = None, titles: list = None):
from ipywidgets import interact, IntSlider
check_shape_along_axis(*data, axis=axis)
cmap = np.broadcast_to(cmap, len(data)).tolist()
vlim = np.broadcast_to(vlim, [len(data), 2]).tolist()
rows, columns = _get_rows_cols(max_columns, data)
sliders = sliders or {}
Expand All @@ -40,8 +41,8 @@ def update(idx, **kwargs):
# hide unneeded axes
for ax in axes[len(data):]:
ax.set_visible(False)
for ax, x, (vmin, vmax), title in zip(axes, data, vlim, titles):
im = ax.imshow(x.take(idx, axis=axis), cmap=cmap, vmin=vmin, vmax=vmax)
for ax, x, cmap_, (vmin, vmax), title in zip(axes, data, cmap, vlim, titles):
im = ax.imshow(x.take(idx, axis=axis), cmap=cmap_, vmin=vmin, vmax=vmax)
if colorbar:
fig.colorbar(im, ax=ax, orientation='horizontal')
if not show_axes:
Expand Down Expand Up @@ -153,4 +154,4 @@ def default_clip(image, body_organ='Brain'):
if body_organ == 'Brain':
return np.clip(image, -20, 90)
elif body_organ == 'Lungs':
return np.clip(image, -1250, 250)
return np.clip(image, -1250, 250)
4 changes: 2 additions & 2 deletions dpipe/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def load_experiment_test_pred(identifier, experiment_path):
raise FileNotFoundError('No prediction found')


def load(path: PathLike, **kwargs):
def load(path: PathLike, ext: str = None, **kwargs):
"""
Load a file located at ``path``.
``kwargs`` are format-specific keyword arguments.
Expand All @@ -79,7 +79,7 @@ def load(path: PathLike, **kwargs):
npy, tif, png, jpg, bmp, hdr, img, csv,
dcm, nii, nii.gz, json, mhd, csv, txt, pickle, pkl, config
"""
name = Path(path).name
name = Path(path).name if ext is None else ext

if name.endswith(('.npy', '.npy.gz')):
if name.endswith('.gz'):
Expand Down
9 changes: 7 additions & 2 deletions dpipe/layers/fpn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Callable, Sequence, Union
from warnings import warn

Expand Down Expand Up @@ -133,11 +134,15 @@ def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0,
warn(msg, UserWarning)
warn(msg, DeprecationWarning)
order = mode
mode = None

if isinstance(order, int):
order = order_to_mode(order, len(down.shape) - 2)

if np.not_equal(left.shape, down.shape).any():
down = functional.interpolate(down, size=left.shape[2:], mode=order, align_corners=False)
interpolate = functional.interpolate
if order in ['linear', 'bilinear', ' bicubic', 'trilinear']:
interpolate = partial(interpolate, align_corners=False)

down = interpolate(down, size=left.shape[2:], mode=order)

return left, down
11 changes: 6 additions & 5 deletions dpipe/predict/shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from functools import wraps
from typing import Union, Callable
from typing import Union, Callable, Type

import numpy as np

from ..im.axes import broadcast_to_axis, AxesLike, AxesParams, axis_from_dim, resolve_deprecation
from ..im.grid import divide, combine
from ..im.grid import divide, combine, PatchCombiner, Average
from ..itertools import extract, pmap
from ..im.shape_ops import pad_to_shape, crop_to_shape, pad_to_divisible
from ..im.shape_utils import prepend_dims, extract_dims
Expand Down Expand Up @@ -80,10 +80,11 @@ def wrapper(x, *args, **kwargs):


def patches_grid(patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5):
padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5,
combiner: Type[PatchCombiner] = Average):
"""
Divide an incoming array into patches of corresponding ``patch_size`` and ``stride`` and then combine
predicted patches by averaging the overlapping regions.
the predicted patches by aggregating the overlapping regions using the ``combiner`` - Average by default.
If ``padding_values`` is not None, the array will be padded to an appropriate shape to make a valid division.
Afterwards the padding is removed.
Expand All @@ -107,7 +108,7 @@ def wrapper(x, *args, **kwargs):
x = pad_to_shape(x, new_shape, input_axis, padding_values, ratio)

patches = pmap(predict, divide(x, local_size, local_stride, input_axis), *args, **kwargs)
prediction = combine(patches, extract(x.shape, input_axis), local_stride, axis)
prediction = combine(patches, extract(x.shape, input_axis), local_stride, axis, combiner=combiner)

if valid:
prediction = crop_to_shape(prediction, shape, axis, ratio)
Expand Down
1 change: 1 addition & 0 deletions dpipe/prototypes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .strategy import *
2 changes: 2 additions & 0 deletions dpipe/prototypes/strategy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .train import *
from .optimization import *
93 changes: 93 additions & 0 deletions dpipe/prototypes/strategy/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import torch
from torch.optim import Optimizer
from torch.cuda.amp import GradScaler
from abc import ABCMeta, abstractmethod
from typing import Sequence, Any, Generator, Union, Dict

from dpipe.torch import to_np
from dpipe.torch.utils import set_params
from .policy import PolicyHandler, Policy


class OptimizationPolicy(Policy, metaclass=ABCMeta):
def __init__(self, optimizer: Optimizer, *, optimizer_parameters: Union[Dict, PolicyHandler],
set_to_none=False, scaler: GradScaler = None):
self.scaler = scaler
self.optimizer = optimizer
self.set_to_none = set_to_none

if isinstance(optimizer_parameters, PolicyHandler):
self.optimizer_parameters = optimizer_parameters
else:
self.optimizer_parameters = PolicyHandler(optimizer_parameters)

@abstractmethod
def optimize(self, losses_gen: Generator):
pass

@property
def policies(self):
return self.optimizer_parameters.policies

def epoch_started(self, epoch: int):
self.optimizer_parameters.epoch_started(epoch)
set_params(self.optimizer, **self.optimizer_parameters.current_values)

def epoch_finished(self, epoch: int, train_losses: Sequence, metrics: dict = None, policies: dict = None):
self.optimizer_parameters.epoch_finished(epoch, train_losses, metrics)

def train_step_started(self, epoch: int, iteration: int):
self.optimizer_parameters.train_step_started(epoch, iteration)

def train_step_finished(self, epoch: int, iteration: int, loss: Any):
self.optimizer_parameters.train_step_finished(epoch, iteration, loss)

def validation_started(self, epoch: int, train_losses: Sequence):
self.optimizer_parameters.validation_started(epoch, train_losses)


class GradientsAccumulator(OptimizationPolicy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def optimize(self, losses_gen: Generator):
assert isinstance(losses_gen, Generator)
self.optimizer.zero_grad(set_to_none=self.set_to_none)

total_loss = 0.
if self.scaler is not None:
with torch.cuda.amp.autocast(False):
for loss in losses_gen:
self.scaler.scale(loss).backward()
total_loss += loss

self.scaler.step(self.optimizer)
self.scaler.update()
else:
for loss in losses_gen:
loss.backward()
total_loss += loss
self.optimizer.step()

return to_np(total_loss)


class LossAccumulator(OptimizationPolicy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def optimize(self, losses_gen: Generator):
assert isinstance(losses_gen, Generator)
self.optimizer.zero_grad(set_to_none=self.set_to_none)

total_loss = sum(losses_gen)
if self.scaler is not None:
with torch.cuda.amp.autocast(False):
self.scaler.scale(total_loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
else:
total_loss.backward()
self.optimizer.step()

return to_np(total_loss)
Loading

0 comments on commit fa55963

Please sign in to comment.