From 33e7070194dbbdb57805ef0473321b4f65eae927 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 24 Aug 2024 13:26:51 +0200 Subject: [PATCH] More fixes --- tests/datasets/test_geo.py | 5 +- tests/datasets/test_splits.py | 4 +- tests/datasets/test_utils.py | 18 +++---- tests/samplers/test_batch.py | 4 +- tests/samplers/test_single.py | 4 +- torchgeo/datasets/enviroatlas.py | 26 ++++++---- torchgeo/datasets/eurocrops.py | 6 +-- torchgeo/datasets/fair1m.py | 2 +- torchgeo/datasets/gid15.py | 2 +- torchgeo/datasets/skippd.py | 8 +-- torchgeo/datasets/sustainbench_crop_yield.py | 4 +- torchgeo/datasets/utils.py | 54 ++++++++++++++++++-- 12 files changed, 93 insertions(+), 44 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 8ff30531fee..77dcd0f8d9e 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -26,6 +26,7 @@ NonGeoClassificationDataset, NonGeoDataset, RasterDataset, + Sample, Sentinel2, UnionDataset, VectorDataset, @@ -46,7 +47,7 @@ def __init__( self.res = res self.paths = paths or [] - def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: + def __getitem__(self, query: BoundingBox) -> Sample: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) bounds = BoundingBox(*hit.bounds) @@ -77,7 +78,7 @@ class CustomSentinelDataset(Sentinel2): class CustomNonGeoDataset(NonGeoDataset): - def __getitem__(self, index: int) -> dict[str, int]: + def __getitem__(self, index: int) -> Sample: return {'index': index} def __len__(self) -> int: diff --git a/tests/datasets/test_splits.py b/tests/datasets/test_splits.py index 2977586ddb4..1ed16487d6b 100644 --- a/tests/datasets/test_splits.py +++ b/tests/datasets/test_splits.py @@ -3,7 +3,6 @@ from collections.abc import Sequence from math import floor, isclose -from typing import Any import pytest from rasterio.crs import CRS @@ -11,6 +10,7 @@ from torchgeo.datasets import ( BoundingBox, GeoDataset, + Sample, 
random_bbox_assignment, random_bbox_splitting, random_grid_cell_assignment, @@ -49,7 +49,7 @@ def __init__( self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: BoundingBox) -> Sample: hits = self.index.intersection(tuple(query), objects=True) hit = next(iter(hits)) return {'content': hit.object} diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 57866c8838d..4e4b6c33d5c 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -15,7 +15,7 @@ import torch from rasterio.crs import CRS -from torchgeo.datasets import BoundingBox, DependencyNotFoundError +from torchgeo.datasets import BoundingBox, DependencyNotFoundError, Sample from torchgeo.datasets.utils import ( Executable, array_to_tensor, @@ -381,13 +381,13 @@ def test_disambiguate_timestamp( class TestCollateFunctionsMatchingKeys: @pytest.fixture(scope='class') - def samples(self) -> list[dict[str, Any]]: + def samples(self) -> list[Sample]: return [ {'image': torch.tensor([1, 2, 0]), 'crs': CRS.from_epsg(2000)}, {'image': torch.tensor([0, 0, 3]), 'crs': CRS.from_epsg(2001)}, ] - def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: + def test_stack_unbind_samples(self, samples: list[Sample]) -> None: sample = stack_samples(samples) assert sample['image'].size() == torch.Size([2, 3]) assert torch.allclose(sample['image'], torch.tensor([[1, 2, 0], [0, 0, 3]])) @@ -398,13 +398,13 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: assert torch.allclose(samples[i]['image'], new_samples[i]['image']) assert samples[i]['crs'] == new_samples[i]['crs'] - def test_concat_samples(self, samples: list[dict[str, Any]]) -> None: + def test_concat_samples(self, samples: list[Sample]) -> None: sample = concat_samples(samples) assert sample['image'].size() == torch.Size([6]) assert torch.allclose(sample['image'], torch.tensor([1, 2, 0, 0, 0, 3])) assert 
sample['crs'] == CRS.from_epsg(2000) - def test_merge_samples(self, samples: list[dict[str, Any]]) -> None: + def test_merge_samples(self, samples: list[Sample]) -> None: sample = merge_samples(samples) assert sample['image'].size() == torch.Size([3]) assert torch.allclose(sample['image'], torch.tensor([1, 2, 3])) @@ -413,13 +413,13 @@ def test_merge_samples(self, samples: list[dict[str, Any]]) -> None: class TestCollateFunctionsDifferingKeys: @pytest.fixture(scope='class') - def samples(self) -> list[dict[str, Any]]: + def samples(self) -> list[Sample]: return [ {'image': torch.tensor([1, 2, 0]), 'crs1': CRS.from_epsg(2000)}, {'mask': torch.tensor([0, 0, 3]), 'crs2': CRS.from_epsg(2001)}, ] - def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: + def test_stack_unbind_samples(self, samples: list[Sample]) -> None: sample = stack_samples(samples) assert sample['image'].size() == torch.Size([1, 3]) assert sample['mask'].size() == torch.Size([1, 3]) @@ -434,7 +434,7 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None: assert torch.allclose(samples[1]['mask'], new_samples[0]['mask']) assert samples[1]['crs2'] == new_samples[0]['crs2'] - def test_concat_samples(self, samples: list[dict[str, Any]]) -> None: + def test_concat_samples(self, samples: list[Sample]) -> None: sample = concat_samples(samples) assert sample['image'].size() == torch.Size([3]) assert sample['mask'].size() == torch.Size([3]) @@ -443,7 +443,7 @@ def test_concat_samples(self, samples: list[dict[str, Any]]) -> None: assert sample['crs1'] == CRS.from_epsg(2000) assert sample['crs2'] == CRS.from_epsg(2001) - def test_merge_samples(self, samples: list[dict[str, Any]]) -> None: + def test_merge_samples(self, samples: list[Sample]) -> None: sample = merge_samples(samples) assert sample['image'].size() == torch.Size([3]) assert sample['mask'].size() == torch.Size([3]) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 
59c8aaa00be..2ad5e947a4a 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -10,7 +10,7 @@ from rasterio.crs import CRS from torch.utils.data import DataLoader -from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples +from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units @@ -32,7 +32,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: + def __getitem__(self, query: BoundingBox) -> Sample: return {'index': query} diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 1416368098a..57ea9539edf 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -10,7 +10,7 @@ from rasterio.crs import CRS from torch.utils.data import DataLoader -from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples +from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples from torchgeo.samplers import ( GeoSampler, GridGeoSampler, @@ -39,7 +39,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: self._crs = crs self.res = res - def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: + def __getitem__(self, query: BoundingBox) -> Sample: return {'index': query} diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index e41c75f1847..1712b2e8568 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -6,7 +6,7 @@ import os import sys from collections.abc import Callable, Sequence -from typing import ClassVar, cast +from typing import Any, ClassVar, cast import fiona import matplotlib.pyplot as plt @@ -347,8 +347,8 @@ def __getitem__(self, query: BoundingBox) -> Sample: """ hits = self.index.intersection(tuple(query), objects=True) filepaths = 
cast(list[dict[str, str]], [hit.object for hit in hits]) - - sample: Sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query} + images: list[np.typing.NDArray[Any]] = [] + masks: list[np.typing.NDArray[Any]] = [] if len(filepaths) == 0: raise IndexError( @@ -389,23 +389,27 @@ def __getitem__(self, query: BoundingBox) -> Sample: 'waterbodies', 'water', ]: - sample['image'].append(data) + images.append(data) elif layer in ['prior', 'prior_no_osm_no_buildings']: if self.prior_as_input: - sample['image'].append(data) + images.append(data) else: - sample['mask'].append(data) + masks.append(data) elif layer in ['lc']: data = self.raw_enviroatlas_to_idx_map[data] - sample['mask'].append(data) + masks.append(data) else: raise IndexError(f'query: {query} spans multiple tiles which is not valid') - sample['image'] = np.concatenate(sample['image'], axis=0) - sample['mask'] = np.concatenate(sample['mask'], axis=0) + image = torch.from_numpy(np.concatenate(images, axis=0)) + mask = torch.from_numpy(np.concatenate(masks, axis=0)) - sample['image'] = torch.from_numpy(sample['image']) - sample['mask'] = torch.from_numpy(sample['mask']) + sample: Sample = { + 'image': image, + 'mask': mask, + 'crs': self.crs, + 'bounds': query, + } if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index 93b829d1a99..20584a1ead9 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -7,13 +7,13 @@ import os import pathlib from collections.abc import Callable, Iterable -from typing import Any import fiona import matplotlib.pyplot as plt import numpy as np from matplotlib.figure import Figure from rasterio.crs import CRS +from torch import Tensor from .errors import DatasetNotFoundError from .geo import VectorDataset @@ -247,9 +247,7 @@ def plot( fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4)) - def apply_cmap( - arr: 'np.typing.NDArray[Any]', - ) -> 
'np.typing.NDArray[np.float64]': + def apply_cmap(arr: Tensor) -> 'np.typing.NDArray[np.float64]': # Color 0 as black, while applying default color map for the class indices. cmap = plt.get_cmap('viridis') im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map)) diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index 0cb8de4430e..5ceadde0918 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -283,7 +283,7 @@ def __getitem__(self, index: int) -> Sample: label_path = label_path.replace('.tif', '.xml') voc = parse_pascal_voc(label_path) boxes, labels = self._load_target(voc['points'], voc['labels']) - sample: Sample = {'image': image, 'boxes': boxes, 'label': labels} + sample = {'image': image, 'boxes': boxes, 'label': labels} if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 589f3c1115a..89a6dcfa4f7 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -139,7 +139,7 @@ def __getitem__(self, index: int) -> Sample: mask = self._load_target(files['mask']) sample: Sample = {'image': image, 'mask': mask} else: - sample: Sample = {'image': image} + sample = {'image': image} if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 171cd5e8ce1..3393d55de0d 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -134,7 +134,7 @@ def __len__(self) -> int: return num_datapoints - def __getitem__(self, index: int) -> dict[str, str | Tensor]: + def __getitem__(self, index: int) -> Sample: """Return an index within the dataset. 
Args: @@ -143,7 +143,7 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]: Returns: data and label at that index """ - sample: dict[str, str | Tensor] = {'image': self._load_image(index)} + sample: Sample = {'image': self._load_image(index)} sample.update(self._load_features(index)) if self.transforms is not None: @@ -176,7 +176,7 @@ def _load_image(self, index: int) -> Tensor: tensor = torch.from_numpy(arr).to(torch.float32) return tensor - def _load_features(self, index: int) -> dict[str, str | Tensor]: + def _load_features(self, index: int) -> Sample: """Load label. Args: @@ -194,7 +194,7 @@ def _load_features(self, index: int) -> dict[str, str | Tensor]: path = os.path.join(self.root, f'times_{self.split}_{self.task}.npy') datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat) - features: dict[str, str | Tensor] = { + features: Sample = { 'label': torch.tensor(label, dtype=torch.float32), 'date': datestring, } diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 6dc041540fa..237d7adb8a7 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -98,7 +98,7 @@ def __init__( self._verify() self.images = [] - self.features = [] + self.features: list[Sample] = [] for country in self.countries: image_file_path = os.path.join( @@ -122,7 +122,7 @@ def __init__( year = year_npz_file[idx] ndvi = ndvi_npz_file[idx] - features = { + features: Sample = { 'label': torch.tensor(target).to(torch.float32), 'year': torch.tensor(int(year)), 'ndvi': torch.from_numpy(ndvi).to(dtype=torch.float32), diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index c31d9e6ee4f..e6f85413c70 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -66,8 +66,54 @@ class Sample(TypedDict, total=False): bounds: BoundingBox crs: CRS - # TODO: remove + # TODO: Additional dataset-specific keys that should be subclasses 
+ images: Tensor + input: Tensor boxes: Tensor + bboxes: Tensor + masks: Tensor + labels: Tensor + prediction_masks: Tensor + prediction_boxes: Tensor + prediction_labels: Tensor + prediction_label: Tensor + prediction_scores: Tensor + audio: Tensor + points: Tensor + x: Tensor + y: Tensor + relative_time: Tensor + ocean: Tensor + array: Tensor + chm: Tensor + hsi: Tensor + las: Tensor + image1: Tensor + image2: Tensor + crs1: CRS + crs2: CRS + magnitude: Tensor + agb: Tensor + key: Tensor + patch: Tensor + geometry: Tensor + properties: Tensor + id: int + centroid_lat: Tensor + centroid_lon: Tensor + content: Tensor + year: Tensor + ndvi: Tensor + filename: str + category: str + field_ids: Tensor + tile_index: Tensor + transform: Tensor + src: Tensor + dst: Tensor + input_size: Tensor + output_size: Tensor + index: BoundingBox | int class Batch(Sample): @@ -456,7 +502,7 @@ def stack_samples(samples: Iterable[Sample]) -> Batch: .. versionadded:: 0.2 """ - collated: dict[Any, Any] = _list_dict_to_dict_list(samples) + collated: Batch = _list_dict_to_dict_list(samples) for key, value in collated.items(): if isinstance(value[0], Tensor): collated[key] = torch.stack(value) @@ -476,7 +522,7 @@ def concat_samples(samples: Iterable[Sample]) -> Batch: .. versionadded:: 0.2 """ - collated: dict[Any, Any] = _list_dict_to_dict_list(samples) + collated: Batch = _list_dict_to_dict_list(samples) for key, value in collated.items(): if isinstance(value[0], Tensor): collated[key] = torch.cat(value) @@ -498,7 +544,7 @@ def merge_samples(samples: Iterable[Sample]) -> Batch: .. versionadded:: 0.2 """ - collated: dict[Any, Any] = {} + collated: Batch = {} for sample in samples: for key, value in sample.items(): if key in collated and isinstance(value, Tensor):