Skip to content

Commit

Permalink
multicrop HF version
Browse files Browse the repository at this point in the history
  • Loading branch information
PedroConrado committed Feb 12, 2025
1 parent 2cd15d8 commit c253d7a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
10 changes: 5 additions & 5 deletions terratorch/datasets/multi_temporal_crop_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class MultiTemporalCropClassification(NonGeoDataset):
num_classes = 13
time_steps = 3
splits = {"train": "training", "val": "validation"} # Only train and val splits available
metadata_file_name = "chip_df_final.csv"
metadata_file_name = "chips_df.csv"
col_name = "chip_id"
date_columns = ["first_img_date", "middle_img_date", "last_img_date"]

Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
data_dir = self.data_root / f"{split_name}_chips"
self.image_files = sorted(glob.glob(os.path.join(data_dir, "*_merged.tif")))
self.segmentation_mask_files = sorted(glob.glob(os.path.join(data_dir, "*.mask.tif")))
split_file = data_dir / f"{split_name}_data.txt"
split_file = self.data_root / f"{split_name}_data.txt"

with open(split_file) as f:
split = f.readlines()
Expand Down Expand Up @@ -263,9 +263,9 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure

if "prediction" in sample:
prediction = sample["prediction"]
ax[self.time_steps + 1].axis("off")
ax[self.time_steps+2].title.set_text("Predicted Mask")
ax[self.time_steps+2].imshow(prediction, cmap="jet", norm=norm)
ax[self.time_steps + 2].axis("off")
ax[self.time_steps + 2].title.set_text("Predicted Mask")
ax[self.time_steps + 2].imshow(prediction, cmap="jet", norm=norm)

cmap = plt.get_cmap("jet")
legend_data = [[i, cmap(norm(i)), self.class_names[i]] for i in range(self.num_classes)]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,10 @@ def crop_classification_data_root(tmp_path):
img_data.rio.to_raster(str(image_path))
mask_data.rio.to_raster(str(mask_path))

with open(training_dir / "training_data.txt", "w") as f:
with open(data_root / "training_data.txt", "w") as f:
f.write("\n".join([f"chip_{i}" for i in range(2)]))

with open(validation_dir / "validation_data.txt", "w") as f:
with open(data_root / "validation_data.txt", "w") as f:
f.write("\n".join([f"chip_{i}" for i in range(2)]))

metadata = pd.DataFrame({
Expand All @@ -543,7 +543,7 @@ def crop_classification_data_root(tmp_path):
"middle_img_date": ["2021-01-15", "2021-01-16"],
"last_img_date": ["2021-02-01", "2021-02-02"],
})
metadata.to_csv(data_root / "chip_df_final.csv", index=False)
metadata.to_csv(data_root / "chips_df.csv", index=False)

return str(data_root)

Expand Down

0 comments on commit c253d7a

Please sign in to comment.