Skip to content

Commit

Permalink
Merge branch 'main' into try_except
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee committed Dec 18, 2024
2 parents 536b93e + 5d4f28e commit 3fda58e
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
rev: v0.8.3
hooks:
- id: ruff
args: [ --fix ]
Expand Down
8 changes: 4 additions & 4 deletions docs/source/using_doctr/using_model_export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ It defines a common format for representing models, including the network struct
from doctr.models import vitstr_small
from doctr.models.utils import export_model_to_onnx
batch_size = 16
batch_size = 1
input_shape = (3, 32, 128)
model = vitstr_small(pretrained=True, exportable=True)
dummy_input = torch.rand((batch_size, input_shape), dtype=torch.float32)
dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
model_path = export_model_to_onnx(
model,
model_name="vitstr.onnx",
Expand All @@ -137,10 +137,10 @@ It defines a common format for representing models, including the network struct
from doctr.models import vitstr_small
from doctr.models.utils import export_model_to_onnx
batch_size = 16
batch_size = 1
input_shape = (32, 128, 3)
model = vitstr_small(pretrained=True, exportable=True)
dummy_input = [tf.TensorSpec([batch_size, input_shape], tf.float32, name="input")]
dummy_input = [tf.TensorSpec([batch_size, *input_shape], tf.float32, name="input")]
model_path, output = export_model_to_onnx(
model,
model_name="vitstr.onnx",
Expand Down
21 changes: 11 additions & 10 deletions doctr/transforms/functional/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from torchvision.transforms import functional as F

from doctr.utils.geometry import rotate_abs_geoms
Expand Down Expand Up @@ -113,24 +114,24 @@ def crop_detection(


def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwargs) -> torch.Tensor:
"""Crop and image and associated bboxes
"""Apply a random shadow effect to an image using NumPy for blurring.
Args:
img: image to modify
opacity_range: the minimum and maximum desired opacity of the shadow
**kwargs: additional arguments to pass to `create_shadow_mask`
img: Image to modify (C, H, W) as a PyTorch tensor.
opacity_range: The minimum and maximum desired opacity of the shadow.
**kwargs: Additional arguments to pass to `create_shadow_mask`.
Returns:
shaded image
Shadowed image as a PyTorch tensor (same shape as input).
"""
shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)

opacity = np.random.uniform(*opacity_range)
shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])

# Add some blur to make it believable
k = 7 + 2 * int(4 * np.random.rand(1))
# Apply Gaussian blur to the shadow mask
sigma = np.random.uniform(0.5, 5.0)
shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma])
blurred_mask = gaussian_filter(shadow_mask, sigma=sigma)

shadow_tensor = 1 - torch.from_numpy(blurred_mask).float()
shadow_tensor = shadow_tensor.to(img.device).unsqueeze(0) # Add channel dimension

return opacity * shadow_tensor * img + (1 - opacity) * img
44 changes: 43 additions & 1 deletion doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@
import numpy as np
import torch
from PIL.Image import Image
from scipy.ndimage import gaussian_filter
from torch.nn.functional import pad
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T

from ..functional.pytorch import random_shadow

__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]
__all__ = [
"Resize",
"GaussianNoise",
"ChannelShuffle",
"RandomHorizontalFlip",
"RandomShadow",
"RandomResize",
"GaussianBlur",
]


class Resize(T.Resize):
Expand Down Expand Up @@ -142,6 +151,39 @@ def extra_repr(self) -> str:
return f"mean={self.mean}, std={self.std}"


class GaussianBlur(torch.nn.Module):
"""Apply Gaussian Blur to the input tensor
>>> import torch
>>> from doctr.transforms import GaussianBlur
>>> transfo = GaussianBlur(sigma=(0.0, 1.0))
Args:
sigma : standard deviation range for the gaussian kernel
"""

def __init__(self, sigma: tuple[float, float]) -> None:
super().__init__()
self.sigma_range = sigma

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Sample a random sigma value within the specified range
sigma = torch.empty(1).uniform_(*self.sigma_range).item()

# Apply Gaussian blur along spatial dimensions only
blurred = torch.tensor(
gaussian_filter(
x.numpy(),
sigma=sigma,
mode="reflect",
truncate=4.0,
),
dtype=x.dtype,
device=x.device,
)
return blurred


class ChannelShuffle(torch.nn.Module):
"""Randomly shuffle channel order of a given image"""

Expand Down
6 changes: 3 additions & 3 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort
from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort
from tqdm.auto import tqdm

from doctr import transforms as T
Expand Down Expand Up @@ -384,12 +384,12 @@ def main(args):
img_transforms = T.OneOf([
Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.2),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2),
]),
Compose([
T.RandomApply(T.RandomShadow(), 0.3),
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.3),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3),
RandomGrayscale(p=0.15),
]),
RandomPhotometricDistort(p=0.3),
Expand Down
4 changes: 2 additions & 2 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,13 @@ def main(args):
img_transforms = T.OneOf([
T.Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.2),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.5, 1.5)), 0.2),
]),
T.Compose([
T.RandomApply(T.RandomJpegQuality(60), 0.15),
# T.RandomApply(T.RandomShadow(), 0.2), # Broken atm on GPU
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.3),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.5, 1.5)), 0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.15),
]),
T.Compose([
Expand Down
33 changes: 33 additions & 0 deletions tests/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from doctr.transforms import (
ChannelShuffle,
ColorInversion,
GaussianBlur,
GaussianNoise,
RandomCrop,
RandomHorizontalFlip,
Expand Down Expand Up @@ -278,6 +279,38 @@ def test_gaussian_noise(input_dtype, input_shape):
assert torch.all(transformed <= 1.0)


@pytest.mark.parametrize(
"input_dtype, input_shape",
[
[torch.float32, (3, 32, 32)],
[torch.uint8, (3, 32, 32)],
],
)
def test_gaussian_blur(input_dtype, input_shape):
sigma_range = (0.0, 1.0)
transform = GaussianBlur(sigma=sigma_range)

input_t = torch.rand(input_shape, dtype=torch.float32)

if input_dtype == torch.uint8:
input_t = (255 * input_t).round().to(dtype=torch.uint8)

blurred = transform(input_t)

assert isinstance(blurred, torch.Tensor)
assert blurred.shape == input_shape
assert blurred.dtype == input_dtype

if input_dtype == torch.uint8:
assert torch.any(blurred != input_t)
assert torch.all(blurred <= 255)
assert torch.all(blurred >= 0)
else:
assert torch.any(blurred != input_t)
assert torch.all(blurred <= 1.0)
assert torch.all(blurred >= 0.0)


@pytest.mark.parametrize(
"p,target",
[
Expand Down

0 comments on commit 3fda58e

Please sign in to comment.