From ba2ccc93277e351f481e8ff6a5269c3b599a0dbc Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Mon, 22 Jul 2024 02:34:21 +0000 Subject: [PATCH 1/5] remove logging --- torchgeo/trainers/detection.py | 32 ------------------------------- torchgeo/trainers/regression.py | 30 ----------------------------- torchgeo/trainers/segmentation.py | 27 -------------------------- 3 files changed, 89 deletions(-) diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index d13d84dcf15..71b72b3454a 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -270,38 +270,6 @@ def validation_step( self.log_dict(metrics, batch_size=batch_size) - if ( - batch_idx < 10 - and hasattr(self.trainer, 'datamodule') - and hasattr(self.trainer.datamodule, 'plot') - and self.logger - and hasattr(self.logger, 'experiment') - and hasattr(self.logger.experiment, 'add_figure') - ): - datamodule = self.trainer.datamodule - batch['prediction_boxes'] = [b['boxes'].cpu() for b in y_hat] - batch['prediction_labels'] = [b['labels'].cpu() for b in y_hat] - batch['prediction_scores'] = [b['scores'].cpu() for b in y_hat] - batch['image'] = batch['image'].cpu() - sample = unbind_samples(batch)[0] - # Convert image to uint8 for plotting - if torch.is_floating_point(sample['image']): - sample['image'] *= 255 - sample['image'] = sample['image'].to(torch.uint8) - - fig: Figure | None = None - try: - fig = datamodule.plot(sample) - except RGBBandsMissingError: - pass - - if fig: - summary_writer = self.logger.experiment - summary_writer.add_figure( - f'image/{batch_idx}', fig, global_step=self.global_step - ) - plt.close() - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test metrics. diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 86c3423c656..0c2f4a5fc76 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -194,36 +194,6 @@ def validation_step( self.val_metrics(y_hat, y) self.log_dict(self.val_metrics, batch_size=batch_size) - if ( - batch_idx < 10 - and hasattr(self.trainer, 'datamodule') - and hasattr(self.trainer.datamodule, 'plot') - and self.logger - and hasattr(self.logger, 'experiment') - and hasattr(self.logger.experiment, 'add_figure') - ): - datamodule = self.trainer.datamodule - if self.target_key == 'mask': - y = y.squeeze(dim=1) - y_hat = y_hat.squeeze(dim=1) - batch['prediction'] = y_hat - for key in ['image', self.target_key, 'prediction']: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - - fig: Figure | None = None - try: - fig = datamodule.plot(sample) - except RGBBandsMissingError: - pass - - if fig: - summary_writer = self.logger.experiment - summary_writer.add_figure( - f'image/{batch_idx}', fig, global_step=self.global_step - ) - plt.close() - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test loss and additional metrics. diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index afd71521002..3841274b9f1 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -254,33 +254,6 @@ def validation_step( self.val_metrics(y_hat, y) self.log_dict(self.val_metrics, batch_size=batch_size) - if ( - batch_idx < 10 - and hasattr(self.trainer, 'datamodule') - and hasattr(self.trainer.datamodule, 'plot') - and self.logger - and hasattr(self.logger, 'experiment') - and hasattr(self.logger.experiment, 'add_figure') - ): - datamodule = self.trainer.datamodule - batch['prediction'] = y_hat.argmax(dim=1) - for key in ['image', 'mask', 'prediction']: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - - fig: Figure | None = None - try: - fig = datamodule.plot(sample) - except RGBBandsMissingError: - pass - - if fig: - summary_writer = self.logger.experiment - summary_writer.add_figure( - f'image/{batch_idx}', fig, global_step=self.global_step - ) - plt.close() - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test loss and additional metrics. From 016951b54b6d37ecddf085d0dd2d051c79fa43ec Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Mon, 22 Jul 2024 02:50:53 +0000 Subject: [PATCH 2/5] remove unused imports --- torchgeo/trainers/detection.py | 3 --- torchgeo/trainers/regression.py | 3 --- torchgeo/trainers/segmentation.py | 3 --- 3 files changed, 9 deletions(-) diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 71b72b3454a..ee7cb4b3e50 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -6,10 +6,8 @@ from functools import partial from typing import Any -import matplotlib.pyplot as plt import torch import torchvision.models.detection -from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection from torchmetrics.detection.mean_ap import MeanAveragePrecision @@ -19,7 +17,6 @@ from torchvision.models.detection.rpn import AnchorGenerator from torchvision.ops import MultiScaleRoIAlign, feature_pyramid_network, misc -from ..datasets import RGBBandsMissingError, unbind_samples from .base import BaseTask BACKBONE_LAT_DIM_MAP = { diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 0c2f4a5fc76..8b249ce2081 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -6,17 +6,14 @@ import os from typing import Any -import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import timm import torch import torch.nn as nn -from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection from torchvision.models._api import WeightsEnum -from ..datasets import RGBBandsMissingError, unbind_samples from ..models import FCN, get_weight from . import utils from .base import BaseTask diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 3841274b9f1..74af908c5ee 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -6,16 +6,13 @@ import os from typing import Any -import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch.nn as nn -from matplotlib.figure import Figure from torch import Tensor from torchmetrics import MetricCollection from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex from torchvision.models._api import WeightsEnum -from ..datasets import RGBBandsMissingError, unbind_samples from ..models import FCN, get_weight from . import utils from .base import BaseTask From 22df2c479fa92b3e5c05494922683ec4705f0f28 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Mon, 22 Jul 2024 14:24:58 +0000 Subject: [PATCH 3/5] Remove plot --- torchgeo/datamodules/geo.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 5f77c0c4d6b..e34ab519b43 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -9,7 +9,6 @@ import kornia.augmentation as K import torch from lightning.pytorch import LightningDataModule -from matplotlib.figure import Figure from torch import Tensor from torch.utils.data import DataLoader, Dataset, Subset, default_collate @@ -142,28 +141,6 @@ def on_after_batch_transfer( return batch - def plot(self, *args: Any, **kwargs: Any) -> Figure | None: - """Run the plot method of the validation dataset if one exists. - - Should only be called during 'fit' or 'validate' stages as ``val_dataset`` - may not exist during other stages. - - Args: - *args: Arguments passed to plot method. - **kwargs: Keyword arguments passed to plot method. - - Returns: - A matplotlib Figure with the image, ground truth, and predictions. - """ - fig: Figure | None = None - dataset = self.dataset or self.val_dataset - if isinstance(dataset, Subset): - dataset = dataset.dataset - if dataset is not None: - if hasattr(dataset, 'plot'): - fig = dataset.plot(*args, **kwargs) - return fig - class GeoDataModule(BaseDataModule): """Base class for data modules containing geospatial information. From 57485f55a2a65dac739e06bb0e6bc5d7f69ddc3f Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Mon, 22 Jul 2024 14:39:07 +0000 Subject: [PATCH 4/5] remove plot tests --- tests/datamodules/test_fair1m.py | 12 ----- tests/datamodules/test_geo.py | 18 ------- tests/datamodules/test_usavars.py | 8 --- tests/datamodules/test_xview2.py | 8 --- tests/trainers/test_classification.py | 70 +-------------------------- tests/trainers/test_detection.py | 38 +-------------- tests/trainers/test_regression.py | 38 +-------------- tests/trainers/test_segmentation.py | 42 +--------------- 8 files changed, 4 insertions(+), 230 deletions(-) diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py index 4aa1e16d846..e5c3a4614ce 100644 --- a/tests/datamodules/test_fair1m.py +++ b/tests/datamodules/test_fair1m.py @@ -3,7 +3,6 @@ import os -import matplotlib.pyplot as plt import pytest from torchgeo.datamodules import FAIR1MDataModule @@ -29,14 +28,3 @@ def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: def test_predict_dataloader(self, datamodule: FAIR1MDataModule) -> None: datamodule.setup('predict') next(iter(datamodule.predict_dataloader())) - - def test_plot(self, datamodule: FAIR1MDataModule) -> None: - datamodule.setup('validate') - batch = next(iter(datamodule.val_dataloader())) - sample = { - 'image': batch['image'][0], - 'boxes': batch['boxes'][0], - 'label': batch['label'][0], - } - datamodule.plot(sample) - plt.close() diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index 8380ce242b8..6c63dacdb08 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -3,12 +3,10 @@ from typing import Any -import matplotlib.pyplot as plt import pytest import torch from _pytest.fixtures import SubRequest from lightning.pytorch import Trainer -from matplotlib.figure import Figure from rasterio.crs import CRS from torch import Tensor @@ -34,9 +32,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: image = torch.arange(3 * 2 * 2).view(3, 2, 2) return {'image': image, 'crs': CRS.from_epsg(4326), 'bbox': query} - def plot(self, *args: Any, **kwargs: Any) -> Figure: - return plt.figure() - class CustomGeoDataModule(GeoDataModule): def __init__(self) -> None: @@ -73,9 +68,6 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: def __len__(self) -> int: return self.length - def plot(self, *args: Any, **kwargs: Any) -> Figure: - return plt.figure() - class CustomNonGeoDataModule(NonGeoDataModule): def __init__(self) -> None: @@ -133,11 +125,6 @@ def test_predict(self, datamodule: CustomGeoDataModule) -> None: batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1) batch = datamodule.on_after_batch_transfer(batch, 0) - def test_plot(self, datamodule: CustomGeoDataModule) -> None: - datamodule.setup('validate') - datamodule.plot() - plt.close() - def test_no_datasets(self) -> None: dm = CustomGeoDataModule() msg = r'CustomGeoDataModule\.setup must define one of ' @@ -235,11 +222,6 @@ def test_predict(self, datamodule: CustomNonGeoDataModule) -> None: batch = next(iter(datamodule.predict_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) - def test_plot(self, datamodule: CustomNonGeoDataModule) -> None: - datamodule.setup('validate') - datamodule.plot() - plt.close() - def test_no_datasets(self) -> None: dm = CustomNonGeoDataModule() msg = r'CustomNonGeoDataModule\.setup must define one of ' diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index c6d5b7c77bf..92a88b7289e 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -3,7 +3,6 @@ import os -import matplotlib.pyplot as plt import pytest from _pytest.fixtures import SubRequest @@ -41,10 +40,3 @@ def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None: assert len(datamodule.test_dataloader()) == 1 batch = next(iter(datamodule.test_dataloader())) assert batch['image'].shape[0] == datamodule.batch_size - - def test_plot(self, datamodule: USAVarsDataModule) -> None: - datamodule.setup('validate') - batch = next(iter(datamodule.val_dataloader())) - sample = unbind_samples(batch)[0] - datamodule.plot(sample) - plt.close() diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index 53a230b9627..67db1208f78 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -3,7 +3,6 @@ import os -import matplotlib.pyplot as plt import pytest from torchgeo.datamodules import XView2DataModule @@ -33,10 +32,3 @@ def test_val_dataloader(self, datamodule: XView2DataModule) -> None: def test_test_dataloader(self, datamodule: XView2DataModule) -> None: datamodule.setup('test') next(iter(datamodule.test_dataloader())) - - def test_plot(self, datamodule: XView2DataModule) -> None: - datamodule.setup('validate') - batch = next(iter(datamodule.val_dataloader())) - sample = unbind_samples(batch)[0] - datamodule.plot(sample) - plt.close() diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index cd437f9faed..01e06751a8e 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -20,7 +20,7 @@ EuroSATDataModule, MisconfigurationException, ) -from torchgeo.datasets import BigEarthNet, EuroSAT, RGBBandsMissingError +from torchgeo.datasets import BigEarthNet, EuroSAT from torchgeo.main import main from torchgeo.models import ResNet18_Weights from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask @@ -61,14 +61,6 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: return state_dict -def plot(*args: Any, **kwargs: Any) -> None: - return None - - -def plot_missing_bands(*args: Any, **kwargs: Any) -> None: - raise RGBBandsMissingError() - - class TestClassificationTask: @pytest.mark.parametrize( 'name', @@ -186,34 +178,6 @@ def test_invalid_loss(self) -> None: with pytest.raises(ValueError, match=match): ClassificationTask(model='resnet18', loss='invalid_loss') - def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(EuroSATDataModule, 'plot', plot) - datamodule = EuroSATDataModule( - root='tests/data/eurosat', batch_size=1, num_workers=0 - ) - model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - - def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(EuroSATDataModule, 'plot', plot_missing_bands) - datamodule = EuroSATDataModule( - root='tests/data/eurosat', batch_size=1, num_workers=0 - ) - model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictClassificationDataModule( root='tests/data/eurosat', batch_size=1, num_workers=0 @@ -277,38 +241,6 @@ def test_invalid_loss(self) -> None: with pytest.raises(ValueError, match=match): MultiLabelClassificationTask(model='resnet18', loss='invalid_loss') - def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot) - datamodule = BigEarthNetDataModule( - root='tests/data/bigearthnet', batch_size=1, num_workers=0 - ) - model = MultiLabelClassificationTask( - model='resnet18', in_channels=14, num_classes=19, loss='bce' - ) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - - def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot_missing_bands) - datamodule = BigEarthNetDataModule( - root='tests/data/bigearthnet', batch_size=1, num_workers=0 - ) - model = MultiLabelClassificationTask( - model='resnet18', in_channels=14, num_classes=19, loss='bce' - ) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictMultiLabelClassificationDataModule( root='tests/data/bigearthnet', batch_size=1, num_workers=0 diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 035bdacc260..ecedbfd847e 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -13,7 +13,7 @@ from torch.nn.modules import Module from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule -from torchgeo.datasets import NASAMarineDebris, RGBBandsMissingError +from torchgeo.datasets import NASAMarineDebris from torchgeo.main import main from torchgeo.trainers import ObjectDetectionTask @@ -26,10 +26,6 @@ def setup(self, stage: str) -> None: self.predict_dataset = NASAMarineDebris(**self.kwargs) -def plot_missing_bands(*args: Any, **kwargs: Any) -> None: - raise RGBBandsMissingError() - - class ObjectDetectionTestModel(Module): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__() @@ -62,10 +58,6 @@ def forward(self, images: Any, targets: Any = None) -> Any: return output -def plot(*args: Any, **kwargs: Any) -> None: - return None - - class TestObjectDetectionTask: @pytest.mark.parametrize('name', ['nasa_marine_debris', 'vhr10']) @pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet']) @@ -120,34 +112,6 @@ def test_invalid_backbone(self) -> None: def test_pretrained_backbone(self) -> None: ObjectDetectionTask(backbone='resnet18', weights=True) - def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(NASAMarineDebrisDataModule, 'plot', plot) - datamodule = NASAMarineDebrisDataModule( - root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0 - ) - model = ObjectDetectionTask(backbone='resnet18', num_classes=2) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - - def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(NASAMarineDebrisDataModule, 'plot', plot_missing_bands) - datamodule = NASAMarineDebrisDataModule( - root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0 - ) - model = ObjectDetectionTask(backbone='resnet18', num_classes=2) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictObjectDetectionDataModule( root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0 diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index c62c808c72f..0c3d03e1aae 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -17,7 +17,7 @@ from torchvision.models._api import WeightsEnum from torchgeo.datamodules import MisconfigurationException, TropicalCycloneDataModule -from torchgeo.datasets import RGBBandsMissingError, TropicalCyclone +from torchgeo.datasets import TropicalCyclone from torchgeo.main import main from torchgeo.models import ResNet18_Weights from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask @@ -51,14 +51,6 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: return state_dict -def plot(*args: Any, **kwargs: Any) -> None: - return None - - -def plot_missing_bands(*args: Any, **kwargs: Any) -> None: - raise RGBBandsMissingError() - - class TestRegressionTask: @classmethod def create_model(*args: Any, **kwargs: Any) -> Module: @@ -156,34 +148,6 @@ def test_weight_str_download(self, weights: WeightsEnum) -> None: in_channels=weights.meta['in_chans'], ) - def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(TropicalCycloneDataModule, 'plot', plot) - datamodule = TropicalCycloneDataModule( - root='tests/data/cyclone', batch_size=1, num_workers=0 - ) - model = RegressionTask(model='resnet18') - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - - def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(TropicalCycloneDataModule, 'plot', plot_missing_bands) - datamodule = TropicalCycloneDataModule( - root='tests/data/cyclone', batch_size=1, num_workers=0 - ) - model = RegressionTask(model='resnet18') - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, fast_dev_run: bool) -> None: datamodule = PredictRegressionDataModule( root='tests/data/cyclone', batch_size=1, num_workers=0 diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index d8b207d5d2d..d82d771b522 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -17,7 +17,7 @@ from torchvision.models._api import WeightsEnum from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule -from torchgeo.datasets import LandCoverAI, RGBBandsMissingError +from torchgeo.datasets import LandCoverAI from torchgeo.main import main from torchgeo.models import ResNet18_Weights from torchgeo.trainers import SemanticSegmentationTask @@ -43,14 +43,6 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: return state_dict -def plot(*args: Any, **kwargs: Any) -> None: - return None - - -def plot_missing_bands(*args: Any, **kwargs: Any) -> None: - raise RGBBandsMissingError() - - class TestSemanticSegmentationTask: @pytest.mark.parametrize( 'name', @@ -189,38 +181,6 @@ def test_invalid_loss(self) -> None: with pytest.raises(ValueError, match=match): SemanticSegmentationTask(loss='invalid_loss') - def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(SEN12MSDataModule, 'plot', plot) - datamodule = SEN12MSDataModule( - root='tests/data/sen12ms', batch_size=1, num_workers=0 - ) - model = SemanticSegmentationTask( - backbone='resnet18', in_channels=15, num_classes=6 - ) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - - def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: - monkeypatch.setattr(SEN12MSDataModule, 'plot', plot_missing_bands) - datamodule = SEN12MSDataModule( - root='tests/data/sen12ms', batch_size=1, num_workers=0 - ) - model = SemanticSegmentationTask( - backbone='resnet18', in_channels=15, num_classes=6 - ) - trainer = Trainer( - accelerator='cpu', - fast_dev_run=fast_dev_run, - log_every_n_steps=1, - max_epochs=1, - ) - trainer.validate(model=model, datamodule=datamodule) - @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+']) @pytest.mark.parametrize( 'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0'] From 5df210ec735935944684d3cc97f6fc4dafea7488 Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Mon, 22 Jul 2024 14:42:13 +0000 Subject: [PATCH 5/5] tidy imports --- tests/datamodules/test_usavars.py | 1 - tests/datamodules/test_xview2.py | 1 - tests/trainers/test_segmentation.py | 3 +-- torchgeo/datamodules/geo.py | 2 +- 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index 92a88b7289e..c9a1d47bad0 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -7,7 +7,6 @@ from _pytest.fixtures import SubRequest from torchgeo.datamodules import USAVarsDataModule -from torchgeo.datasets import unbind_samples class TestUSAVarsDataModule: diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index 67db1208f78..22780ab3849 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -6,7 +6,6 @@ import pytest from torchgeo.datamodules import XView2DataModule -from torchgeo.datasets import unbind_samples class TestXView2DataModule: diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index d82d771b522..ffc0f46a6ad 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -11,12 +11,11 @@ import torch import torch.nn as nn import torchvision -from lightning.pytorch import Trainer from pytest import MonkeyPatch from torch.nn.modules import Module from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule +from torchgeo.datamodules import MisconfigurationException from torchgeo.datasets import LandCoverAI from torchgeo.main import main from torchgeo.models import ResNet18_Weights diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index e34ab519b43..c8b9fd3fbb7 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -10,7 +10,7 @@ import torch from lightning.pytorch import LightningDataModule from torch import Tensor -from torch.utils.data import DataLoader, Dataset, Subset, default_collate +from torch.utils.data import DataLoader, Dataset, default_collate from ..datasets import GeoDataset, NonGeoDataset, stack_samples from ..samplers import (