From c10b5b4b3bff017e260382e598eb333511f3d1fb Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 20 Mar 2024 16:12:08 +0400 Subject: [PATCH 01/12] Initial commit --- tests/datasets/test_nasa_marine_debris.py | 6 +- tests/datasets/test_vhr10.py | 14 ++-- torchgeo/datamodules/nasa_marine_debris.py | 11 ++- torchgeo/datamodules/utils.py | 80 ++++------------------ torchgeo/datamodules/vhr10.py | 36 +++++----- torchgeo/datasets/nasa_marine_debris.py | 8 ++- torchgeo/datasets/vhr10.py | 16 ++--- torchgeo/trainers/detection.py | 18 ++++- 8 files changed, 74 insertions(+), 115 deletions(-) diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index e2787b51b7d..e2544951c87 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -28,9 +28,9 @@ def test_getitem(self, dataset: NASAMarineDebris) -> None: x = dataset[0] assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) assert x['image'].shape[0] == 3 - assert x['boxes'].shape[-1] == 4 + assert x['bbox_xyxy'].shape[-1] == 4 def test_len(self, dataset: NASAMarineDebris) -> None: assert len(dataset) == 5 @@ -50,6 +50,6 @@ def test_plot(self, dataset: NASAMarineDebris) -> None: plt.close() dataset.plot(x, show_titles=False) plt.close() - x['prediction_boxes'] = x['boxes'].clone() + x['prediction_boxes'] = x['bbox_xyxy'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index aa0920d69e7..3d04b88f50e 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -41,10 +41,10 @@ def test_getitem(self, dataset: VHR10) -> None: assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) if dataset.split == 'positive': - assert isinstance(x['labels'], torch.Tensor) - assert isinstance(x['boxes'], torch.Tensor) - if 'masks' in x: - assert isinstance(x['masks'], torch.Tensor) + assert isinstance(x['class'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) + if 'mask' in x: + assert isinstance(x['mask'], torch.Tensor) def test_len(self, dataset: VHR10) -> None: if dataset.split == 'positive': @@ -82,10 +82,10 @@ def test_plot(self, dataset: VHR10) -> None: scores = [0.7, 0.3, 0.7] for i in range(3): x = dataset[i] - x['prediction_labels'] = x['labels'] - x['prediction_boxes'] = x['boxes'] + x['prediction_labels'] = x['class'] + x['prediction_boxes'] = x['bbox_xyxy'] x['prediction_scores'] = torch.Tensor([scores[i]]) if 'masks' in x: - x['prediction_masks'] = x['masks'] + x['prediction_masks'] = x['mask'] dataset.plot(x, show_feats='masks') plt.close() diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 33290fa4ac9..2fcba2236c3 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -10,9 +10,8 @@ from torch.utils.data import random_split from ..datasets import NASAMarineDebris -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import AugPipe, collate_fn_detection +from .utils import collate_fn_detection class NASAMarineDebrisDataModule(NonGeoDataModule): @@ -46,12 +45,10 @@ def __init__( self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.aug = AugPipe( - AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'boxes'] - ), - batch_size, + 
self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + self.aug.keepdim = True # type: ignore[attr-defined] self.collate_fn = collate_fn_detection diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 4c3aab63b61..5ae7127ab68 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -4,14 +4,12 @@ """Common datamodule utilities.""" import math -from collections.abc import Callable, Iterable +from collections.abc import Iterable from typing import Any import numpy as np import torch -from einops import rearrange from torch import Tensor -from torch.nn import Module # Based on lightning_lite.utilities.exceptions @@ -19,60 +17,6 @@ class MisconfigurationException(Exception): """Exception used to inform users of misuse with Lightning.""" -class AugPipe(Module): - """Pipeline for applying augmentations sequentially on select data keys. - - .. versionadded:: 0.6 - """ - - def __init__( - self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int - ) -> None: - """Initialize a new AugPipe instance. - - Args: - augs: Augmentations to apply. - batch_size: Batch size - """ - super().__init__() - self.augs = augs - self.batch_size = batch_size - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - """Apply the augmentation. - - Args: - batch: Input batch. - - Returns: - Augmented batch. - """ - batch_len = len(batch['image']) - for bs in range(batch_len): - batch_dict = { - 'image': batch['image'][bs], - 'labels': batch['labels'][bs], - 'boxes': batch['boxes'][bs], - } - - if 'masks' in batch: - batch_dict['masks'] = batch['masks'][bs] - - batch_dict = self.augs(batch_dict) - - batch['image'][bs] = batch_dict['image'] - batch['labels'][bs] = batch_dict['labels'] - batch['boxes'][bs] = batch_dict['boxes'] - - if 'masks' in batch: - batch['masks'][bs] = batch_dict['masks'] - - # Stack images - batch['image'] = rearrange(batch['image'], 'b () c h w -> b c h w') - - return batch - - def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: """Custom collate fn for object detection and instance segmentation. @@ -85,17 +29,23 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: .. 
versionadded:: 0.6 """ output: dict[str, Any] = {} - output['image'] = [sample['image'] for sample in batch] - output['boxes'] = [sample['boxes'].float() for sample in batch] - if 'labels' in batch[0]: - output['labels'] = [sample['labels'] for sample in batch] + output['image'] = torch.stack([sample['image'] for sample in batch]) + # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} + bbox_key = 'boxes' + for key in batch[0].keys(): + if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: + bbox_key = key + + output[bbox_key] = [sample[bbox_key].float() for sample in batch] + if 'class' in batch[0].keys(): + output['class'] = [sample['class'] for sample in batch] else: - output['labels'] = [ - torch.tensor([1] * len(sample['boxes'])) for sample in batch + output['class'] = [ + torch.tensor([1] * len(sample[bbox_key])) for sample in batch ] - if 'masks' in batch[0]: - output['masks'] = [sample['masks'] for sample in batch] + if 'mask' in batch[0]: + output['mask'] = [sample['mask'] for sample in batch] return output diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 1cf8987f007..49e7b90241b 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -11,9 +11,8 @@ from ..datasets import VHR10 from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule -from .utils import AugPipe, collate_fn_detection +from .utils import collate_fn_detection class VHR10DataModule(NonGeoDataModule): @@ -52,25 +51,20 @@ def __init__( self.collate_fn = collate_fn_detection - self.train_aug = AugPipe( - AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - K.Resize(self.patch_size), - K.RandomHorizontalFlip(), - K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.7), - K.RandomVerticalFlip(), - data_keys=['image', 'boxes', 'masks'], - ), - batch_size, + self.train_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.RandomHorizontalFlip(), + K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.7), + K.RandomVerticalFlip(), + data_keys=None, + keepdim=True, ) - self.aug = AugPipe( - AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - K.Resize(self.patch_size), - data_keys=['image', 'boxes', 'masks'], - ), - batch_size, + self.train_aug.keepdim = True # type: ignore[attr-defined] + + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. @@ -78,6 +72,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" + self.kwargs['transforms'] = K.AugmentationSequential( + K.Resize(self.patch_size), data_keys=None, keepdim=True + ) + self.kwargs['transforms'].keepdim = True self.dataset = VHR10(**self.kwargs) generator = torch.Generator().manual_seed(0) self.train_dataset, self.val_dataset, self.test_dataset = random_split( diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 9c57e290406..791c3cae81c 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -105,7 +105,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: indices = w_check & h_check boxes = boxes[indices] - sample = {'image': image, 'boxes': boxes} + sample = {'image': image, 'bbox_xyxy': boxes} if self.transforms is not None: sample = self.transforms(sample) @@ -161,8 +161,10 @@ def plot( sample['image'] = sample['image'].byte() image = sample['image'] - if 'boxes' in sample and len(sample['boxes']): - image = draw_bounding_boxes(image=sample['image'], boxes=sample['boxes']) + if 'bbox_xyxy' in sample and len(sample['bbox_xyxy']): + image = draw_bounding_boxes( + image=sample['image'], boxes=sample['bbox_xyxy'] + ) image_arr = image.permute((1, 2, 0)).numpy() if 'prediction_boxes' in sample and len(sample['prediction_boxes']): diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 9adc2f44e9e..b0f0f2de340 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -248,9 +248,9 @@ def __getitem__(self, index: int) -> dict[str, Any]: if sample['label']['annotations']: sample = self.coco_convert(sample) - sample['labels'] = sample['label']['labels'] - sample['boxes'] = sample['label']['boxes'] - sample['masks'] = sample['label']['masks'] + sample['class'] = sample['label']['labels'] + sample['bbox_xyxy'] = sample['label']['boxes'] + sample['mask'] = sample['label']['masks'].float() del sample['label'] if self.transforms is not None: @@ -400,11 +400,11 @@ def plot( if show_feats != 'boxes': skimage = lazy_import('skimage') - boxes = sample['boxes'].cpu().numpy() - labels = sample['labels'].cpu().numpy() - - if 'masks' in sample: - masks = [mask.squeeze().cpu().numpy() for mask in sample['masks']] + image = sample['image'].permute(1, 2, 0).numpy() + boxes = sample['bbox_xyxy'].cpu().numpy() + labels = sample['class'].cpu().numpy() + if 'mask' in sample: + masks = [mask.squeeze().cpu().numpy() for mask in sample['mask']] n_gt = len(boxes) diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 3d970abdae0..b16dd22ed96 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -238,8 +238,12 @@ def training_step( """ x = batch['image'] batch_size = x.shape[0] + # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} + for key in batch.keys(): + if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: + bbox_key = key y = [ - {'boxes': batch['boxes'][i], 'labels': batch['labels'][i]} + {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} for i in range(batch_size) ] loss_dict = self(x, y) @@ -259,8 +263,12 @@ def validation_step( """ x = batch['image'] batch_size = x.shape[0] + # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} + for key in batch.keys(): + if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: + bbox_key = key y = [ - {'boxes': batch['boxes'][i], 'labels': batch['labels'][i]} + {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} for i in range(batch_size) ] y_hat = self(x) @@ -313,8 +321,12 @@ def 
test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None """ x = batch['image'] batch_size = x.shape[0] + # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} + for key in batch.keys(): + if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: + bbox_key = key y = [ - {'boxes': batch['boxes'][i], 'labels': batch['labels'][i]} + {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} for i in range(batch_size) ] y_hat = self(x) From 7014548bc975a7d55662ae7a3fe03095f2a85eed Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 4 Jul 2024 11:08:24 +0400 Subject: [PATCH 02/12] Add issue link --- torchgeo/datamodules/nasa_marine_debris.py | 1 + torchgeo/datamodules/vhr10.py | 1 + 2 files changed, 2 insertions(+) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 2fcba2236c3..add6b32b063 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -48,6 +48,7 @@ def __init__( self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 self.aug.keepdim = True # type: ignore[attr-defined] self.collate_fn = collate_fn_detection diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 49e7b90241b..8e9d56c7f7a 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -64,6 +64,7 @@ def __init__( self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + # https://github.com/kornia/kornia/issues/2848 self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: From e8e54b00ef5bde13713c107f2a4279b1510d2c5d Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 6 Nov 2024 21:59:52 +0400 Subject: [PATCH 03/12] Remove explicit keepdim --- torchgeo/datamodules/nasa_marine_debris.py | 2 -- torchgeo/datamodules/vhr10.py | 4 ---- 2 files changed, 6 deletions(-) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index add6b32b063..43511e9d5f8 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -48,8 +48,6 @@ def __init__( self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] self.collate_fn = collate_fn_detection diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 8e9d56c7f7a..db4f3437667 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -59,13 +59,10 @@ def __init__( data_keys=None, keepdim=True, ) - self.train_aug.keepdim = True # type: ignore[attr-defined] self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. 
@@ -76,7 +73,6 @@ def setup(self, stage: str) -> None: self.kwargs['transforms'] = K.AugmentationSequential( K.Resize(self.patch_size), data_keys=None, keepdim=True ) - self.kwargs['transforms'].keepdim = True self.dataset = VHR10(**self.kwargs) generator = torch.Generator().manual_seed(0) self.train_dataset, self.val_dataset, self.test_dataset = random_split( From ef6404e1e25a291e352488e578f963b6b8c22d27 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sat, 9 Nov 2024 22:24:01 +0400 Subject: [PATCH 04/12] Fix coverage --- tests/datasets/test_vhr10.py | 2 +- torchgeo/datasets/vhr10.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 3d04b88f50e..95451d2866e 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -85,7 +85,7 @@ def test_plot(self, dataset: VHR10) -> None: x['prediction_labels'] = x['class'] x['prediction_boxes'] = x['bbox_xyxy'] x['prediction_scores'] = torch.Tensor([scores[i]]) - if 'masks' in x: + if 'mask' in x: x['prediction_masks'] = x['mask'] dataset.plot(x, show_feats='masks') plt.close() diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index b0f0f2de340..6f8be71852a 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -459,7 +459,7 @@ def plot( ) # Add masks - if show_feats in {'masks', 'both'} and 'masks' in sample: + if show_feats in {'masks', 'both'} and 'mask' in sample: mask = masks[i] contours = skimage.measure.find_contours(mask, 0.5) for verts in contours: From 24b9f4d6f84e6d0e6a9daf995ab3cf64c0bd9f19 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sat, 28 Dec 2024 18:56:16 +0400 Subject: [PATCH 05/12] test transforms: Switch to kornia AugmentationSequential --- tests/transforms/test_transforms.py | 76 ++++++++++++----------------- 1 file changed, 32 insertions(+), 44 deletions(-) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 1f2071ae812..36e020bb412 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]: return { 'image': torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } @@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]: dtype=torch.float, ), 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } @@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]: dtype=torch.float, ), 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } @@ -79,12 +79,10 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: expected = { 'image': torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[0.0, 0.0, 
2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None) output = augs(batch_gray) assert_matching(output, expected) @@ -102,12 +100,10 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: dtype=torch.float, ), 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None) output = augs(batch_rgb) assert_matching(output, expected) @@ -119,22 +115,20 @@ def test_augmentation_sequential_multispectral( 'image': torch.tensor( [ [ - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], + [[7, 8, 9], [4, 5, 6], [1, 2, 3]], ] ], dtype=torch.float, ), - 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), + 'mask': torch.tensor([[[1, 1, 1], [0, 1, 1], [0, 0, 1]]], dtype=torch.long), + 'bbox_xyxy': torch.tensor([[0.0, 0.0, 1.0, 1.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + augs = K.AugmentationSequential(K.RandomVerticalFlip(p=1.0), data_keys=None) output = augs(batch_multispectral) assert_matching(output, expected) @@ -142,28 +136,22 @@ def test_augmentation_sequential_multispectral( def test_augmentation_sequential_image_only( batch_multispectral: dict[str, Tensor], ) -> None: - expected = { - 'image': torch.tensor( + expected_image = torch.tensor( + [ [ - [ - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - ] - ], - dtype=torch.float, - ), - 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image'] + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + [[3, 2, 1], [6, 5, 4], [9, 8, 7]], + ] + ], + dtype=torch.float, ) - output = augs(batch_multispectral) - assert_matching(output, expected) + + augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=['image']) + aug_image = augs(batch_multispectral['image']) + assert torch.allclose(aug_image, expected_image) def test_sequential_transforms_augmentations( @@ -188,17 +176,17 @@ def test_sequential_transforms_augmentations( dtype=torch.float, ), 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], 
dtype=torch.float), + 'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - train_transforms = transforms.AugmentationSequential( + train_transforms = K.AugmentationSequential( indices.AppendNBR(index_nir=0, index_swir=0), indices.AppendNDBI(index_swir=0, index_nir=0), indices.AppendNDSI(index_green=0, index_swir=0), indices.AppendNDVI(index_red=0, index_nir=0), indices.AppendNDWI(index_green=0, index_nir=0), K.RandomHorizontalFlip(p=1.0), - data_keys=['image', 'mask', 'boxes'], + data_keys=None, ) output = train_transforms(batch_multispectral) assert_matching(output, expected) From 2105789ea00d117c118f8872b737ab0d9f34a910 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sat, 4 Jan 2025 23:50:02 +0400 Subject: [PATCH 06/12] Switch boxes key to bbox_xyxy --- tests/datamodules/test_fair1m.py | 2 +- tests/datasets/test_fair1m.py | 6 +++--- tests/datasets/test_forestdamage.py | 4 ++-- tests/datasets/test_idtrees.py | 14 +++++++------- tests/datasets/test_pastis.py | 2 +- tests/datasets/test_reforestree.py | 6 +++--- tests/datasets/test_zuericrop.py | 6 +++--- torchgeo/datamodules/fair1m.py | 4 ++-- torchgeo/datamodules/utils.py | 10 ++-------- torchgeo/datasets/fair1m.py | 6 +++--- torchgeo/datasets/forestdamage.py | 6 +++--- torchgeo/datasets/idtrees.py | 19 +++++++++++-------- torchgeo/datasets/pastis.py | 2 +- torchgeo/datasets/reforestree.py | 4 ++-- torchgeo/datasets/zuericrop.py | 13 +++++++++---- torchgeo/trainers/detection.py | 21 ++++++--------------- 16 files changed, 59 insertions(+), 66 deletions(-) diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py index 4aa1e16d846..ef1beb7178e 100644 --- a/tests/datamodules/test_fair1m.py +++ b/tests/datamodules/test_fair1m.py @@ -35,7 +35,7 @@ def test_plot(self, datamodule: FAIR1MDataModule) -> None: batch = next(iter(datamodule.val_dataloader())) sample = { 'image': batch['image'][0], - 'boxes': batch['boxes'][0], + 'bbox_xyxy': batch['bbox_xyxy'][0], 'label': batch['label'][0], } datamodule.plot(sample) diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py index 3ff3f66733f..fec6d309ab1 100644 --- a/tests/datasets/test_fair1m.py +++ b/tests/datasets/test_fair1m.py @@ -70,9 +70,9 @@ def test_getitem(self, dataset: FAIR1M) -> None: assert x['image'].shape[0] == 3 if dataset.split != 'test': - assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) - assert x['boxes'].shape[-2:] == (5, 2) + assert x['bbox_xyxy'].shape[-2:] == (5, 2) assert x['label'].ndim == 1 def test_len(self, dataset: FAIR1M) -> None: @@ -124,6 +124,6 @@ def test_plot(self, dataset: FAIR1M) -> None: plt.close() if dataset.split != 'test': - x['prediction_boxes'] = x['boxes'].clone() + x['prediction_boxes'] = x['bbox_xyxy'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py index 64acb20b5b3..8e8fd7c1ab9 100644 --- a/tests/datasets/test_forestdamage.py +++ b/tests/datasets/test_forestdamage.py @@ -36,7 +36,7 @@ def test_getitem(self, dataset: ForestDamage) -> None: assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) - assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) assert x['image'].shape[0] == 3 assert x['image'].ndim == 3 @@ -67,6 +67,6 @@ def test_plot(self, dataset: ForestDamage) -> None: def 
test_plot_prediction(self, dataset: ForestDamage) -> None: x = dataset[0].copy() - x['prediction_boxes'] = x['boxes'].clone() + x['prediction_boxes'] = x['bbox_xyxy'].clone() dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index 4c145b4c006..5a5f8d75fee 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -57,11 +57,11 @@ def test_getitem(self, dataset: IDTReeS) -> None: if 'label' in x: assert isinstance(x['label'], torch.Tensor) - if 'boxes' in x: - assert isinstance(x['boxes'], torch.Tensor) - if x['boxes'].ndim != 1: - assert x['boxes'].ndim == 2 - assert x['boxes'].shape[-1] == 4 + if 'bbox_xyxy' in x: + assert isinstance(x['bbox_xyxy'], torch.Tensor) + if x['bbox_xyxy'].ndim != 1: + assert x['bbox_xyxy'].ndim == 2 + assert x['bbox_xyxy'].shape[-1] == 4 def test_len(self, dataset: IDTReeS) -> None: assert len(dataset) == 3 @@ -87,8 +87,8 @@ def test_plot(self, dataset: IDTReeS) -> None: dataset.plot(x, show_titles=False) plt.close() - if 'boxes' in x: - x['prediction_boxes'] = x['boxes'] + if 'bbox_xyxy' in x: + x['prediction_boxes'] = x['bbox_xyxy'] dataset.plot(x, show_titles=True) plt.close() if 'label' in x: diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py index be327e1628d..30299a56ac5 100644 --- a/tests/datasets/test_pastis.py +++ b/tests/datasets/test_pastis.py @@ -52,7 +52,7 @@ def test_getitem_instance(self, dataset: PASTIS) -> None: assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) assert isinstance(x['mask'], torch.Tensor) - assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) def test_len(self, dataset: PASTIS) -> None: diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py index c0ab375d9f6..e3cf9948d82 100644 --- a/tests/datasets/test_reforestree.py +++ b/tests/datasets/test_reforestree.py @@ -36,11 +36,11 @@ def test_getitem(self, dataset: ReforesTree) -> None: assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) - assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) assert isinstance(x['agb'], torch.Tensor) assert x['image'].shape[0] == 3 assert x['image'].ndim == 3 - assert len(x['boxes']) == 2 + assert len(x['bbox_xyxy']) == 2 def test_len(self, dataset: ReforesTree) -> None: assert len(dataset) == 2 @@ -67,6 +67,6 @@ def test_plot(self, dataset: ReforesTree) -> None: def test_plot_prediction(self, dataset: ReforesTree) -> None: x = dataset[0].copy() - x['prediction_boxes'] = x['boxes'].clone() + x['prediction_boxes'] = x['bbox_xyxy'].clone() dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py index 6d4cdc8844c..14418bb542e 100644 --- a/tests/datasets/test_zuericrop.py +++ b/tests/datasets/test_zuericrop.py @@ -29,7 +29,7 @@ def test_getitem(self, dataset: ZueriCrop) -> None: assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) assert isinstance(x['mask'], torch.Tensor) - assert isinstance(x['boxes'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) # Image tests @@ -40,8 +40,8 @@ def test_getitem(self, dataset: ZueriCrop) -> None: assert x['mask'].shape[-2:] == x['image'].shape[-2:] # Bboxes tests - assert x['boxes'].ndim == 2 - assert 
x['boxes'].shape[1] == 4 + assert x['bbox_xyxy'].ndim == 2 + assert x['bbox_xyxy'].shape[1] == 4 # Labels tests assert x['label'].ndim == 1 diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index 291dd617e04..6eda02dab9b 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -26,8 +26,8 @@ def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: output: dict[str, Any] = {} output['image'] = torch.stack([sample['image'] for sample in batch]) - if 'boxes' in batch[0]: - output['boxes'] = [sample['boxes'] for sample in batch] + if 'bbox_xyxy' in batch[0]: + output['bbox_xyxy'] = [sample['bbox_xyxy'] for sample in batch] if 'label' in batch[0]: output['label'] = [sample['label'] for sample in batch] diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 5ae7127ab68..c420c4811e6 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -30,18 +30,12 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: """ output: dict[str, Any] = {} output['image'] = torch.stack([sample['image'] for sample in batch]) - # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} - bbox_key = 'boxes' - for key in batch[0].keys(): - if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: - bbox_key = key - - output[bbox_key] = [sample[bbox_key].float() for sample in batch] + output['bbox_xyxy'] = [sample['bbox_xyxy'].float() for sample in batch] if 'class' in batch[0].keys(): output['class'] = [sample['class'] for sample in batch] else: output['class'] = [ - torch.tensor([1] * len(sample[bbox_key])) for sample in batch + torch.tensor([1] * len(sample['bbox_xyxy'])) for sample in batch ] if 'mask' in batch[0]: diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index d58968eaa19..d01267d9d85 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -283,7 +283,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: label_path = label_path.replace('.tif', '.xml') voc = parse_pascal_voc(label_path) boxes, labels = self._load_target(voc['points'], voc['labels']) - sample = {'image': image, 'boxes': boxes, 'label': labels} + sample = {'image': image, 'bbox_xyxy': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) @@ -411,10 +411,10 @@ def plot( axs[0].imshow(image) axs[0].axis('off') - if 'boxes' in sample: + if 'bbox_xyxy' in sample: polygons = [ patches.Polygon(points, color='r', fill=False) - for points in sample['boxes'].numpy() + for points in sample['bbox_xyxy'].numpy() ] for polygon in polygons: axs[0].add_patch(polygon) diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 9c3de28a2b5..b7ea83504aa 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -146,7 +146,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: boxes, labels = self._load_target(parsed['bboxes'], parsed['labels']) - sample = {'image': image, 'boxes': boxes, 'label': labels} + sample = {'image': image, 'bbox_xyxy': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) @@ -222,7 +222,7 @@ def _verify(self) -> None: if os.path.isdir(filepath): return - filepath = os.path.join(self.root, self.data_dir + '.zip') + filepath = os.path.join(self.root, f'{self.data_dir}.zip') if os.path.isfile(filepath): if self.checksum and not check_integrity(filepath, self.md5): raise RuntimeError('Dataset found, but corrupted.') @@ 
-284,7 +284,7 @@ def plot( edgecolor='r', facecolor='none', ) - for bbox in sample['boxes'].numpy() + for bbox in sample['bbox_xyxy'].numpy() ] for bbox in bboxes: axs[0].add_patch(bbox) diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 2c5777ba350..343785f3a31 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -212,20 +212,23 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: if self.split == 'test': if self.task == 'task2': - sample['boxes'] = self._load_boxes(path) + sample['bbox_xyxy'] = self._load_boxes(path) h, w = sample['image'].shape[1:] - sample['boxes'], _ = self._filter_boxes( - image_size=(h, w), min_size=1, boxes=sample['boxes'], labels=None + sample['bbox_xyxy'], _ = self._filter_boxes( + image_size=(h, w), + min_size=1, + boxes=sample['bbox_xyxy'], + labels=None, ) else: - sample['boxes'] = self._load_boxes(path) + sample['bbox_xyxy'] = self._load_boxes(path) sample['label'] = self._load_target(path) h, w = sample['image'].shape[1:] - sample['boxes'], sample['label'] = self._filter_boxes( + sample['bbox_xyxy'], sample['label'] = self._filter_boxes( image_size=(h, w), min_size=1, - boxes=sample['boxes'], + boxes=sample['bbox_xyxy'], labels=sample['label'], ) @@ -504,14 +507,14 @@ def normalize(x: Tensor) -> Tensor: hsi = normalize(sample['hsi'][hsi_indices, :, :]).permute((1, 2, 0)).numpy() chm = normalize(sample['chm']).permute((1, 2, 0)).numpy() - if 'boxes' in sample and len(sample['boxes']): + if 'bbox_xyxy' in sample and len(sample['bbox_xyxy']): labels = ( [self.idx2class[int(i)] for i in sample['label']] if 'label' in sample else None ) image = draw_bounding_boxes( - image=sample['image'], boxes=sample['boxes'], labels=labels + image=sample['image'], boxes=sample['bbox_xyxy'], labels=labels ) image = image.permute((1, 2, 0)).numpy() else: diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 06f716a9ffb..fd6628dc25d 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -194,7 +194,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: sample = {'image': image, 'mask': mask} elif self.mode == 'instance': mask, boxes, labels = self._load_instance_targets(index) - sample = {'image': image, 'mask': mask, 'boxes': boxes, 'label': labels} + sample = {'image': image, 'mask': mask, 'bbox_xyxy': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index 1c46c450191..4caced4d477 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -109,7 +109,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: boxes, labels, agb = self._load_target(filepath) - sample = {'image': image, 'boxes': boxes, 'label': labels, 'agb': agb} + sample = {'image': image, 'bbox_xyxy': boxes, 'label': labels, 'agb': agb} if self.transforms is not None: sample = self.transforms(sample) @@ -239,7 +239,7 @@ def plot( edgecolor='r', facecolor='none', ) - for bbox in sample['boxes'].numpy() + for bbox in sample['bbox_xyxy'].numpy() ] for bbox in bboxes: axs[0].add_patch(bbox) diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index 2c9296e05a2..7b6988df6ec 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -11,9 +11,14 @@ from matplotlib.figure import Figure from torch import Tensor -from .errors import DatasetNotFoundError, RGBBandsMissingError -from .geo import NonGeoDataset -from 
.utils import Path, download_url, lazy_import, percentile_normalization +from torchgeo.datasets.errors import DatasetNotFoundError, RGBBandsMissingError +from torchgeo.datasets.geo import NonGeoDataset +from torchgeo.datasets.utils import ( + Path, + download_url, + lazy_import, + percentile_normalization, +) class ZueriCrop(NonGeoDataset): @@ -109,7 +114,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image = self._load_image(index) mask, boxes, label = self._load_target(index) - sample = {'image': image, 'mask': mask, 'boxes': boxes, 'label': label} + sample = {'image': image, 'mask': mask, 'bbox_xyxy': boxes, 'label': label} if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index b16dd22ed96..af052fbe411 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -238,12 +238,9 @@ def training_step( """ x = batch['image'] batch_size = x.shape[0] - # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} - for key in batch.keys(): - if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: - bbox_key = key + assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.' y = [ - {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} + {'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]} for i in range(batch_size) ] loss_dict = self(x, y) @@ -263,12 +260,9 @@ def validation_step( """ x = batch['image'] batch_size = x.shape[0] - # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} - for key in batch.keys(): - if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: - bbox_key = key + assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.' y = [ - {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} + {'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]} for i in range(batch_size) ] y_hat = self(x) @@ -321,12 +315,9 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None """ x = batch['image'] batch_size = x.shape[0] - # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} - for key in batch.keys(): - if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: - bbox_key = key + assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.' 
y = [ - {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} + {'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]} for i in range(batch_size) ] y_hat = self(x) From 32ffe01de745eb3f1f021a271f029031e01d8593 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 6 Jan 2025 14:48:06 +0400 Subject: [PATCH 07/12] Switch class -> label --- tests/datasets/test_vhr10.py | 4 ++-- torchgeo/datamodules/utils.py | 6 +++--- torchgeo/datasets/vhr10.py | 4 ++-- torchgeo/trainers/detection.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 95451d2866e..7ac91c31b20 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -41,7 +41,7 @@ def test_getitem(self, dataset: VHR10) -> None: assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) if dataset.split == 'positive': - assert isinstance(x['class'], torch.Tensor) + assert isinstance(x['label'], torch.Tensor) assert isinstance(x['bbox_xyxy'], torch.Tensor) if 'mask' in x: assert isinstance(x['mask'], torch.Tensor) @@ -82,7 +82,7 @@ def test_plot(self, dataset: VHR10) -> None: scores = [0.7, 0.3, 0.7] for i in range(3): x = dataset[i] - x['prediction_labels'] = x['class'] + x['prediction_labels'] = x['label'] x['prediction_boxes'] = x['bbox_xyxy'] x['prediction_scores'] = torch.Tensor([scores[i]]) if 'mask' in x: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index c420c4811e6..afcef5e9ae1 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -31,10 +31,10 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: output: dict[str, Any] = {} output['image'] = torch.stack([sample['image'] for sample in batch]) output['bbox_xyxy'] = [sample['bbox_xyxy'].float() for sample in batch] - if 'class' in batch[0].keys(): - output['class'] = [sample['class'] for sample in batch] + if 'label' in batch[0].keys(): + output['label'] = [sample['label'] for sample in batch] else: - output['class'] = [ + output['label'] = [ torch.tensor([1] * len(sample['bbox_xyxy'])) for sample in batch ] diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 6f8be71852a..f0a73723e80 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -251,7 +251,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: sample['class'] = sample['label']['labels'] sample['bbox_xyxy'] = sample['label']['boxes'] sample['mask'] = sample['label']['masks'].float() - del sample['label'] + sample['label'] = sample.pop('class') if self.transforms is not None: sample = self.transforms(sample) @@ -402,7 +402,7 @@ def plot( image = sample['image'].permute(1, 2, 0).numpy() boxes = sample['bbox_xyxy'].cpu().numpy() - labels = sample['class'].cpu().numpy() + labels = sample['label'].cpu().numpy() if 'mask' in sample: masks = [mask.squeeze().cpu().numpy() for mask in sample['mask']] diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index af052fbe411..290b6917b04 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -240,7 +240,7 @@ def training_step( batch_size = x.shape[0] assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.' 
y = [ - {'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]} + {'boxes': batch['bbox_xyxy'][i], 'labels': batch['label'][i]} for i in range(batch_size) ] loss_dict = self(x, y) @@ -262,7 +262,7 @@ def validation_step( batch_size = x.shape[0] assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.' y = [ - {'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]} + {'boxes': batch['bbox_xyxy'][i], 'labels': batch['label'][i]} for i in range(batch_size) ] y_hat = self(x) @@ -317,7 +317,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_size = x.shape[0] assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.' y = [ - {'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]} + {'boxes': batch['bbox_xyxy'][i], 'labels': batch['label'][i]} for i in range(batch_size) ] y_hat = self(x) From 65ece96bf54137d079d63bf2f42a75fef267ddcc Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 6 Jan 2025 15:07:10 +0400 Subject: [PATCH 08/12] Switch prediction_boxes -> prediction_bbox_xyxy --- tests/datasets/test_fair1m.py | 2 +- tests/datasets/test_forestdamage.py | 2 +- tests/datasets/test_idtrees.py | 2 +- tests/datasets/test_nasa_marine_debris.py | 2 +- tests/datasets/test_reforestree.py | 2 +- tests/datasets/test_vhr10.py | 2 +- torchgeo/datasets/fair1m.py | 4 ++-- torchgeo/datasets/forestdamage.py | 4 ++-- torchgeo/datasets/idtrees.py | 4 ++-- torchgeo/datasets/nasa_marine_debris.py | 4 ++-- torchgeo/datasets/reforestree.py | 4 ++-- torchgeo/datasets/vhr10.py | 6 +++--- torchgeo/trainers/detection.py | 2 +- 13 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py index fec6d309ab1..ac0a3048bb0 100644 --- a/tests/datasets/test_fair1m.py +++ b/tests/datasets/test_fair1m.py @@ -124,6 +124,6 @@ def test_plot(self, dataset: FAIR1M) -> None: plt.close() if dataset.split != 'test': - x['prediction_boxes'] = x['bbox_xyxy'].clone() + x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py index 8e8fd7c1ab9..7ae2ebb9eba 100644 --- a/tests/datasets/test_forestdamage.py +++ b/tests/datasets/test_forestdamage.py @@ -67,6 +67,6 @@ def test_plot(self, dataset: ForestDamage) -> None: def test_plot_prediction(self, dataset: ForestDamage) -> None: x = dataset[0].copy() - x['prediction_boxes'] = x['bbox_xyxy'].clone() + x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone() dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index 5a5f8d75fee..e92945da3e3 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -88,7 +88,7 @@ def test_plot(self, dataset: IDTReeS) -> None: plt.close() if 'bbox_xyxy' in x: - x['prediction_boxes'] = x['bbox_xyxy'] + x['prediction_bbox_xyxy'] = x['bbox_xyxy'] dataset.plot(x, show_titles=True) plt.close() if 'label' in x: diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index e2544951c87..d5f94a20816 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -50,6 +50,6 @@ def test_plot(self, dataset: NASAMarineDebris) -> None: plt.close() dataset.plot(x, show_titles=False) plt.close() - x['prediction_boxes'] = x['bbox_xyxy'].clone() + x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone() dataset.plot(x) plt.close() diff --git 
a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py index e3cf9948d82..e7c1079b0a3 100644 --- a/tests/datasets/test_reforestree.py +++ b/tests/datasets/test_reforestree.py @@ -67,6 +67,6 @@ def test_plot(self, dataset: ReforesTree) -> None: def test_plot_prediction(self, dataset: ReforesTree) -> None: x = dataset[0].copy() - x['prediction_boxes'] = x['bbox_xyxy'].clone() + x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone() dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 7ac91c31b20..1d3b8eac440 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -83,7 +83,7 @@ def test_plot(self, dataset: VHR10) -> None: for i in range(3): x = dataset[i] x['prediction_labels'] = x['label'] - x['prediction_boxes'] = x['bbox_xyxy'] + x['prediction_bbox_xyxy'] = x['bbox_xyxy'] x['prediction_scores'] = torch.Tensor([scores[i]]) if 'mask' in x: x['prediction_masks'] = x['mask'] diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index d01267d9d85..d8cb5b56d05 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -401,7 +401,7 @@ def plot( image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - if 'prediction_boxes' in sample: + if 'prediction_bbox_xyxy' in sample: ncols += 1 fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) @@ -427,7 +427,7 @@ def plot( axs[1].axis('off') polygons = [ patches.Polygon(points, color='r', fill=False) - for points in sample['prediction_boxes'].numpy() + for points in sample['prediction_bbox_xyxy'].numpy() ] for polygon in polygons: axs[0].add_patch(polygon) diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index b7ea83504aa..d4f8e487b81 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -264,7 +264,7 @@ def plot( image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - showing_predictions = 'prediction_boxes' in sample + showing_predictions = 'prediction_bbox_xyxy' in sample if showing_predictions: ncols += 1 @@ -305,7 +305,7 @@ def plot( edgecolor='r', facecolor='none', ) - for bbox in sample['prediction_boxes'].numpy() + for bbox in sample['prediction_bbox_xyxy'].numpy() ] for bbox in pred_bboxes: axs[1].add_patch(bbox) diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 343785f3a31..62c196aa087 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -520,7 +520,7 @@ def normalize(x: Tensor) -> Tensor: else: image = sample['image'].permute((1, 2, 0)).numpy() - if 'prediction_boxes' in sample and len(sample['prediction_boxes']): + if 'prediction_bbox_xyxy' in sample and len(sample['prediction_bbox_xyxy']): ncols += 1 labels = ( [self.idx2class[int(i)] for i in sample['prediction_label']] @@ -528,7 +528,7 @@ def normalize(x: Tensor) -> Tensor: else None ) preds = draw_bounding_boxes( - image=sample['image'], boxes=sample['prediction_boxes'], labels=labels + image=sample['image'], boxes=sample['prediction_bbox_xyxy'], labels=labels ) preds = preds.permute((1, 2, 0)).numpy() diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 791c3cae81c..62f8df37fc2 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -167,10 +167,10 @@ def plot( ) image_arr = image.permute((1, 2, 0)).numpy() - if 'prediction_boxes' in sample and len(sample['prediction_boxes']): + if 
'prediction_bbox_xyxy' in sample and len(sample['prediction_bbox_xyxy']): ncols += 1 preds = draw_bounding_boxes( - image=sample['image'], boxes=sample['prediction_boxes'] + image=sample['image'], boxes=sample['prediction_bbox_xyxy'] ) preds_arr = preds.permute((1, 2, 0)).numpy() diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index 4caced4d477..0973442791a 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -219,7 +219,7 @@ def plot( """ image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - showing_predictions = 'prediction_boxes' in sample + showing_predictions = 'prediction_bbox_xyxy' in sample if showing_predictions: ncols += 1 @@ -260,7 +260,7 @@ def plot( edgecolor='r', facecolor='none', ) - for bbox in sample['prediction_boxes'].numpy() + for bbox in sample['prediction_bbox_xyxy'].numpy() ] for bbox in pred_bboxes: axs[1].add_patch(bbox) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index f0a73723e80..5628bcd16d6 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -416,8 +416,8 @@ def plot( show_pred_masks = False prediction_labels = sample['prediction_labels'].numpy() prediction_scores = sample['prediction_scores'].numpy() - if 'prediction_boxes' in sample: - prediction_boxes = sample['prediction_boxes'].numpy() + if 'prediction_bbox_xyxy' in sample: + prediction_bbox_xyxy = sample['prediction_bbox_xyxy'].numpy() show_pred_boxes = True if 'prediction_masks' in sample: prediction_masks = sample['prediction_masks'].numpy() @@ -485,7 +485,7 @@ def plot( if show_pred_boxes: # Add bounding boxes - x1, y1, x2, y2 = prediction_boxes[i] + x1, y1, x2, y2 = prediction_bbox_xyxy[i] r = patches.Rectangle( (x1, y1), x2 - x1, diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 290b6917b04..f7f73542933 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -282,7 +282,7 @@ def validation_step( and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule - batch['prediction_boxes'] = [b['boxes'].cpu() for b in y_hat] + batch['prediction_bbox_xyxy'] = [b['boxes'].cpu() for b in y_hat] batch['prediction_labels'] = [b['labels'].cpu() for b in y_hat] batch['prediction_scores'] = [b['scores'].cpu() for b in y_hat] batch['image'] = batch['image'].cpu() From 8a716eb3b982c41aaf5eb45b1ea8267e1a15cc01 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 6 Jan 2025 15:58:40 +0400 Subject: [PATCH 09/12] Switch to relative imports --- torchgeo/datasets/idtrees.py | 4 +++- torchgeo/datasets/zuericrop.py | 11 +++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 62c196aa087..fd24288c85a 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -528,7 +528,9 @@ def normalize(x: Tensor) -> Tensor: else None ) preds = draw_bounding_boxes( - image=sample['image'], boxes=sample['prediction_bbox_xyxy'], labels=labels + image=sample['image'], + boxes=sample['prediction_bbox_xyxy'], + labels=labels, ) preds = preds.permute((1, 2, 0)).numpy() diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index 7b6988df6ec..400f0d52b83 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -11,14 +11,9 @@ from matplotlib.figure import Figure from torch import Tensor -from torchgeo.datasets.errors import DatasetNotFoundError, RGBBandsMissingError -from 
torchgeo.datasets.geo import NonGeoDataset -from torchgeo.datasets.utils import ( - Path, - download_url, - lazy_import, - percentile_normalization, -) +from .errors import DatasetNotFoundError, RGBBandsMissingError +from .geo import NonGeoDataset +from .utils import Path, download_url, lazy_import, percentile_normalization class ZueriCrop(NonGeoDataset): From d588af7b83bea122cc4ee4fba08a6c9698e1ac91 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 5 Feb 2025 21:52:03 +0400 Subject: [PATCH 10/12] Switch to kornia AugmentationSequential - for real --- .pre-commit-config.yaml | 1 + tests/transforms/test_transforms.py | 28 +++++++++++++--------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e1bd1a15d67..f49379ad943 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,6 +29,7 @@ repos: - pillow>=10.4.0 - pytest>=6.1.2 - scikit-image>=0.22.0 + - timm>=1.0.14 - torch>=2.6 - torchmetrics>=0.10 - torchvision>=0.18 diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 36e020bb412..8b47765ffa5 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -6,7 +6,7 @@ import torch from torch import Tensor -from torchgeo.transforms import indices, transforms +from torchgeo.transforms import indices from torchgeo.transforms.transforms import _ExtractPatches # Kornia is very particular about its boxes: @@ -203,12 +203,12 @@ def test_extract_patches() -> None: 'image': torch.randn(size=(b, c, h, w)), 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } - train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p), same_on_batch=True, data_keys=['image', 'mask'] + train_transforms = K.AugmentationSequential( + _ExtractPatches(window_size=p), same_on_batch=True, data_keys=None ) output = train_transforms(batch) - assert batch['image'].shape == (b * num_patches, c, p, p) - assert batch['mask'].shape == (b * num_patches, p, p) + assert output['image'].shape == (b * num_patches, c, p, p) + assert output['mask'].shape == (b * num_patches, 1, p, p) # Test different stride s = 16 @@ -217,14 +217,12 @@ def test_extract_patches() -> None: 'image': torch.randn(size=(b, c, h, w)), 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } - train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p, stride=s), - same_on_batch=True, - data_keys=['image', 'mask'], + train_transforms = K.AugmentationSequential( + _ExtractPatches(window_size=p, stride=s), same_on_batch=True, data_keys=None ) output = train_transforms(batch) - assert batch['image'].shape == (b * num_patches, c, p, p) - assert batch['mask'].shape == (b * num_patches, p, p) + assert output['image'].shape == (b * num_patches, c, p, p) + assert output['mask'].shape == (b * num_patches, 1, p, p) # Test keepdim=False s = p @@ -233,13 +231,13 @@ def test_extract_patches() -> None: 'image': torch.randn(size=(b, c, h, w)), 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } - train_transforms = transforms.AugmentationSequential( + train_transforms = K.AugmentationSequential( _ExtractPatches(window_size=p, stride=s, keepdim=False), same_on_batch=True, - data_keys=['image', 'mask'], + data_keys=None, ) output = train_transforms(batch) for k, v in output.items(): print(k, v.shape, v.dtype) - assert batch['image'].shape == (b, num_patches, c, p, p) - assert batch['mask'].shape == (b, num_patches, 1, p, p) + assert output['image'].shape 
== (b, num_patches, c, p, p) + assert output['mask'].shape == (b, num_patches, 1, p, p) From 90465977ab54d03bc4ad3d68388076bd0e2eb3d6 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 6 Feb 2025 13:10:57 +0400 Subject: [PATCH 11/12] Remove AugmentationSequential --- torchgeo/transforms/transforms.py | 99 +------------------------------ 1 file changed, 3 insertions(+), 96 deletions(-) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index d8f80bdcaac..60e35aa4a88 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -8,106 +8,13 @@ import kornia.augmentation as K import torch from einops import rearrange +from kornia.augmentation import AugmentationSequential from kornia.contrib import extract_tensor_patches from kornia.geometry import crop_by_indices -from kornia.geometry.boxes import Boxes from torch import Tensor -from torch.nn.modules import Module - -# TODO: contribute these to Kornia and delete this file -class AugmentationSequential(Module): - """Wrapper around kornia AugmentationSequential to handle input dicts. - - .. deprecated:: 0.4 - Use :class:`kornia.augmentation.container.AugmentationSequential` instead. - """ - - def __init__( - self, - *args: K.base._AugmentationBase | K.ImageSequential, - data_keys: list[str], - **kwargs: Any, - ) -> None: - """Initialize a new augmentation sequential instance. - - Args: - *args: Sequence of kornia augmentations - data_keys: List of inputs to augment (e.g., ["image", "mask", "boxes"]) - **kwargs: Keyword arguments passed to ``K.AugmentationSequential`` - - .. versionadded:: 0.5 - The ``**kwargs`` parameter. - """ - super().__init__() - self.data_keys = data_keys - - keys: list[str] = [] - for key in data_keys: - if key.startswith('image'): - keys.append('input') - elif key == 'boxes': - keys.append('bbox') - elif key == 'masks': - keys.append('mask') - else: - keys.append(key) - - self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) - - def forward(self, batch: dict[str, Any]) -> dict[str, Any]: - """Perform augmentations and update data dict. 
- - Args: - batch: the input - - Returns: - the augmented input - """ - # Kornia augmentations require all inputs to be float - dtype = {} - for key in self.data_keys: - dtype[key] = batch[key].dtype - batch[key] = batch[key].float() - - # Convert shape of boxes from [N, 4] to [N, 4, 2] - if 'boxes' in batch and ( - isinstance(batch['boxes'], list) or batch['boxes'].ndim == 2 - ): - batch['boxes'] = Boxes.from_tensor(batch['boxes']).data - - # Kornia requires masks to have a channel dimension - if 'mask' in batch and batch['mask'].ndim == 3: - batch['mask'] = rearrange(batch['mask'], 'b h w -> b () h w') - - if 'masks' in batch and batch['masks'].ndim == 3: - batch['masks'] = rearrange(batch['masks'], 'c h w -> () c h w') - - inputs = [batch[k] for k in self.data_keys] - outputs_list: Tensor | list[Tensor] = self.augs(*inputs) - outputs_list = ( - outputs_list if isinstance(outputs_list, list) else [outputs_list] - ) - outputs: dict[str, Tensor] = { - k: v for k, v in zip(self.data_keys, outputs_list) - } - batch.update(outputs) - - # Convert all inputs back to their previous dtype - for key in self.data_keys: - batch[key] = batch[key].to(dtype[key]) - - # Convert boxes to default [N, 4] - if 'boxes' in batch: - batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') - - # Torchmetrics does not support masks with a channel dimension - if 'mask' in batch and batch['mask'].shape[1] == 1: - batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') - if 'masks' in batch and batch['masks'].ndim == 4: - batch['masks'] = rearrange(batch['masks'], '() c h w -> c h w') - - return batch +# Only include import redirects +__all__ = ('AugmentationSequential',) class _RandomNCrop(K.GeometricAugmentationBase2D): From 2fa70b66b65c53975105ab7c359594da8fac1880 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Thu, 6 Feb 2025 16:53:28 +0400 Subject: [PATCH 12/12] Exclude AugmentationSequential --- docs/api/transforms.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api/transforms.rst b/docs/api/transforms.rst index dc2ab4e1d27..568264f0998 100644 --- a/docs/api/transforms.rst +++ b/docs/api/transforms.rst @@ -2,3 +2,4 @@ torchgeo.transforms =================== .. automodule:: torchgeo.transforms + :exclude-members: AugmentationSequential
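
Note (illustrative, not part of the patch series): the commits above replace TorchGeo's custom AugmentationSequential/AugPipe wrappers with kornia.augmentation.AugmentationSequential driven by dict inputs (data_keys=None), and standardize the detection sample keys as 'bbox_xyxy', 'label', and 'mask'. The short sketch below shows the resulting usage pattern; the tensor shapes, box coordinates, normalization values, and augmentation choices are assumptions made up for illustration, not values taken from the patches.

import kornia.augmentation as K
import torch

# Unbatched sample using the key convention adopted in this series:
# 'image' (C, H, W), 'bbox_xyxy' (N, 4) boxes in xyxy format, 'mask' (1, H, W).
# Class labels ('label') are carried alongside by collate_fn_detection and are
# not passed through the augmentations here.
sample = {
    'image': torch.rand(3, 256, 256),
    'bbox_xyxy': torch.tensor([[10.0, 20.0, 50.0, 80.0]]),
    'mask': torch.zeros(1, 256, 256),
}

# With data_keys=None, kornia infers how to transform each entry from the dict
# keys; keepdim=True preserves the unbatched (C, H, W) layout of the sample.
aug = K.AugmentationSequential(
    K.Normalize(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([255.0, 255.0, 255.0])),
    K.RandomHorizontalFlip(p=0.5),
    data_keys=None,
    keepdim=True,
)

out = aug(sample)
# Geometric transforms are applied consistently to the image, boxes, and mask,
# while intensity transforms such as Normalize only touch the image.
print(out['image'].shape, out['bbox_xyxy'].shape, out['mask'].shape)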