Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vovaf709 committed Apr 16, 2024
1 parent 3b9ee95 commit 88d8c45
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
9 changes: 4 additions & 5 deletions imops/crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .backend import BackendLike
from .numeric import _NUMERIC_DEFAULT_NUM_THREADS
from .pad import pad
from .utils import AxesLike, AxesParams, broadcast_axis, fill_by_indices
from .utils import AxesLike, AxesParams, assert_subdtype, broadcast_axis, fill_by_indices


def crop_to_shape(x: np.ndarray, shape: AxesLike, axis: AxesLike = None, ratio: AxesParams = 0.5) -> np.ndarray:
Expand Down Expand Up @@ -36,8 +36,8 @@ def crop_to_shape(x: np.ndarray, shape: AxesLike, axis: AxesLike = None, ratio:
"""
x = np.asarray(x)
shape = np.asarray(shape)
if not np.issubdtype(shape.dtype, np.integer):
raise ValueError(f'`shape` must be of integer dtype, got {shape.dtype}')
assert_subdtype(shape.dtype, np.integer, 'shape')

axis, shape, ratio = broadcast_axis(axis, x.ndim, shape, ratio)

old_shape, new_shape = np.array(x.shape), np.array(fill_by_indices(x.shape, shape, axis))
Expand Down Expand Up @@ -93,8 +93,7 @@ def crop_to_box(
"""
x = np.asarray(x)
box = np.asarray(box)
if not np.issubdtype(box.dtype, np.integer):
raise ValueError(f'`box` must be of integer dtype, got {box.dtype}')
assert_subdtype(box.dtype, np.integer, 'box')

start, stop = box
axis, start, stop = broadcast_axis(axis, x.ndim, start, stop)
Expand Down
5 changes: 5 additions & 0 deletions imops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,8 @@ def check_len(*args) -> None:
lengths = list(map(len, args))
if any(length != lengths[0] for length in lengths):
raise ValueError(f'Arguments of equal length are required: {", ".join(map(str, lengths))}')


def assert_subdtype(dtype, ref_dtype, name):
if not np.issubdtype(dtype, ref_dtype):
raise ValueError(f'`{name}` must be of {ref_dtype.__name__} dtype, got {dtype}')

0 comments on commit 88d8c45

Please sign in to comment.