From c253d7ad24cd0514ecf37e99495ad9f58ef640d3 Mon Sep 17 00:00:00 2001 From: Pedro Henrique Conrado Date: Wed, 12 Feb 2025 09:18:04 -0500 Subject: [PATCH] multicrop HF version --- .../datasets/multi_temporal_crop_classification.py | 10 +++++----- tests/test_datasets.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/terratorch/datasets/multi_temporal_crop_classification.py b/terratorch/datasets/multi_temporal_crop_classification.py index 709800d4..457b5385 100644 --- a/terratorch/datasets/multi_temporal_crop_classification.py +++ b/terratorch/datasets/multi_temporal_crop_classification.py @@ -56,7 +56,7 @@ class MultiTemporalCropClassification(NonGeoDataset): num_classes = 13 time_steps = 3 splits = {"train": "training", "val": "validation"} # Only train and val splits available - metadata_file_name = "chip_df_final.csv" + metadata_file_name = "chips_df.csv" col_name = "chip_id" date_columns = ["first_img_date", "middle_img_date", "last_img_date"] @@ -106,7 +106,7 @@ def __init__( data_dir = self.data_root / f"{split_name}_chips" self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_merged.tif"))) self.segmentation_mask_files = sorted(glob.glob(os.path.join(data_dir, "*.mask.tif"))) - split_file = data_dir / f"{split_name}_data.txt" + split_file = self.data_root / f"{split_name}_data.txt" with open(split_file) as f: split = f.readlines() @@ -263,9 +263,9 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure if "prediction" in sample: prediction = sample["prediction"] - ax[self.time_steps + 1].axis("off") - ax[self.time_steps+2].title.set_text("Predicted Mask") - ax[self.time_steps+2].imshow(prediction, cmap="jet", norm=norm) + ax[self.time_steps + 2].axis("off") + ax[self.time_steps + 2].title.set_text("Predicted Mask") + ax[self.time_steps + 2].imshow(prediction, cmap="jet", norm=norm) cmap = plt.get_cmap("jet") legend_data = [[i, cmap(norm(i)), self.class_names[i]] for i in range(self.num_classes)] diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 7609cacb..e613a9df 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -531,10 +531,10 @@ def crop_classification_data_root(tmp_path): img_data.rio.to_raster(str(image_path)) mask_data.rio.to_raster(str(mask_path)) - with open(training_dir / "training_data.txt", "w") as f: + with open(data_root / "training_data.txt", "w") as f: f.write("\n".join([f"chip_{i}" for i in range(2)])) - with open(validation_dir / "validation_data.txt", "w") as f: + with open(data_root / "validation_data.txt", "w") as f: f.write("\n".join([f"chip_{i}" for i in range(2)])) metadata = pd.DataFrame({ @@ -543,7 +543,7 @@ def crop_classification_data_root(tmp_path): "middle_img_date": ["2021-01-15", "2021-01-16"], "last_img_date": ["2021-02-01", "2021-02-02"], }) - metadata.to_csv(data_root / "chip_df_final.csv", index=False) + metadata.to_csv(data_root / "chips_df.csv", index=False) return str(data_root)