From eb9182bba6d7a3c39e27867fed8615821751d09b Mon Sep 17 00:00:00 2001 From: Alex Severin Date: Mon, 16 Dec 2024 16:10:24 +0300 Subject: [PATCH] refactor --- src/microwink/common.py | 9 +++++++-- src/microwink/seg.py | 16 +++++----------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/microwink/common.py b/src/microwink/common.py index 5589ed8..394a762 100644 --- a/src/microwink/common.py +++ b/src/microwink/common.py @@ -1,3 +1,4 @@ +import typing import numpy as np from dataclasses import dataclass @@ -5,6 +6,9 @@ from PIL import Image, ImageDraw from PIL.Image import Image as PILImage +if typing.TYPE_CHECKING: + from _typeshed import ConvertibleToFloat + @dataclass class Box: @@ -14,7 +18,7 @@ class Box: w: float @staticmethod - def from_xyxy(box: Iterable[Any]) -> "Box": + def from_xyxy(box: Iterable["ConvertibleToFloat"]) -> "Box": x1, y1, x2, y2 = [float(t) for t in box] h = y2 - y1 w = x2 - x1 @@ -34,6 +38,7 @@ def draw_box( color: tuple[int, ...] | str | float = (255, 0, 0), width: int = 3, ) -> PILImage: + assert width >= 0 image = image.copy() draw = ImageDraw.Draw(image) points = [(box.x, box.y), (box.x + box.w, box.y + box.h)] @@ -51,7 +56,7 @@ def draw_mask( assert 0.0 <= alpha <= 1.0 assert (image.height, image.width) == binary_mask.shape img = np.array(image) - assert len(img.shape) == len(color) + assert img.ndim == len(color) overlay = np.zeros_like(img) overlay[binary_mask] = color assert overlay.shape == img.shape diff --git a/src/microwink/seg.py b/src/microwink/seg.py index afa2a9c..c493f04 100644 --- a/src/microwink/seg.py +++ b/src/microwink/seg.py @@ -94,7 +94,7 @@ def apply( out = [] assert len(result.boxes) == len(result.scores) == len(result.mask_maps) for box, score, mask in zip(result.boxes, result.scores, result.mask_maps): - assert len(mask.shape) == 2 + assert mask.ndim == 2 assert mask.dtype == np.float64 out.append( SegResult( @@ -166,17 +166,11 @@ def postprocess_mask( (ih, iw), (mask_height, mask_width), ) - mask_maps = np.zeros( - ( - len(scaled_boxes), - ih, - iw, - ) - ) + mask_maps = np.zeros((len(scaled_boxes), ih, iw)) assert len(scaled_boxes) == len(masks) assert len(scaled_boxes) == len(boxes) for i, (box, scaled_box, mask) in enumerate(zip(boxes, scaled_boxes, masks)): - assert 2 == len(mask.shape) + assert mask.ndim == 2 scale_x1 = math.floor(scaled_box[0]) scale_y1 = math.floor(scaled_box[1]) @@ -212,7 +206,7 @@ def preprocess(self, image: PILImage) -> np.ndarray: if image.size != size: image = image.resize(size) img = np.array(image).astype(np.float32) - assert len(img.shape) == 3 + assert img.ndim == 3 img /= 255.0 img = img.transpose(2, 0, 1) tensor = img[np.newaxis, :, :, :] @@ -307,5 +301,5 @@ def resize(buf: np.ndarray, size: tuple[W, H]) -> np.ndarray: img = Image.fromarray(buf).resize(size) out = np.array(img) assert out.dtype == buf.dtype - assert len(out.shape) == len(buf.shape) + assert out.ndim == buf.ndim return out