From 3443defb9063ab7b9791ea7f0e08cc551dbb36c9 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 27 Sep 2023 11:44:47 +0200 Subject: [PATCH 1/2] update random apply to work also with targets --- doctr/transforms/modules/base.py | 8 ++++---- references/detection/train_pytorch.py | 4 ++-- references/detection/train_tensorflow.py | 4 ++-- tests/common/test_transforms.py | 3 +++ 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index ae5e784bb6..fc49b09253 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -5,7 +5,7 @@ import math import random -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -206,10 +206,10 @@ def __init__(self, transform: Callable[[Any], Any], p: float = 0.5) -> None: def extra_repr(self) -> str: return f"transform={self.transform}, p={self.p}" - def __call__(self, img: Any) -> Any: + def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]: if random.random() < self.p: - return self.transform(img) - return img + return self.transform(img) if target is None else self.transform(img, target) + return img if target is None else (img, target) class RandomRotate(NestedObject): diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 41f421a341..e6cbf14784 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -193,7 +193,7 @@ def main(args): + ( [ T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad - T.RandomRotate(90, expand=True), + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), ] if args.rotation and not args.eval_straight @@ -286,7 +286,7 @@ def main(args): + ( [ T.Resize(args.input_size, preserve_aspect_ratio=True), - T.RandomRotate(90, expand=True), + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), ] if args.rotation diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index a0f1eeae64..95437b444f 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -155,7 +155,7 @@ def main(args): + ( [ T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad - T.RandomRotate(90, expand=True), + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), ] if args.rotation and not args.eval_straight @@ -240,7 +240,7 @@ def main(args): + ( [ T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad - T.RandomRotate(90, expand=True), + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), ] if args.rotation diff --git a/tests/common/test_transforms.py b/tests/common/test_transforms.py index 2b235377c1..4cc04c32d6 100644 --- a/tests/common/test_transforms.py +++ b/tests/common/test_transforms.py @@ -27,6 +27,9 @@ def test_randomapply(): transfo = T.RandomApply(lambda x: 1 - x) out = transfo(1) assert out == 0 or out == 1 + transfo = T.RandomApply(lambda x, y: (1 - x, 2 * y)) + out = transfo(1, np.array([2])) + assert out == (0, 4) or out == (1, 2) and isinstance(out[1], np.ndarray) assert repr(transfo).endswith(", p=0.5)") From 5e9e428d2b5b60feca6dabd53c421f7b4367e23a Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 27 Sep 2023 12:02:57 +0200 Subject: [PATCH 2/2] ignore mypy warning --- doctr/transforms/modules/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index fc49b09253..d47a838037 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -208,7 +208,7 @@ def extra_repr(self) -> str: def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]: if random.random() < self.p: - return self.transform(img) if target is None else self.transform(img, target) + return self.transform(img) if target is None else self.transform(img, target) # type: ignore[call-arg] return img if target is None else (img, target)