From 485f791ad9c24e1d4aac5b0e63c8aa82c6f715f4 Mon Sep 17 00:00:00 2001 From: Benjamin Cretois Date: Fri, 29 Nov 2024 10:47:04 +0100 Subject: [PATCH] [FIX] migrate rasterize_mask and metrics to Hydra --- src/dataset/get_imgs_to_predict.py | 85 +++++++++++++ src/dataset/rasterize_masks.py | 17 ++- src/metrics.py | 50 ++------ src/predict_crop.py | 190 +++++++++++++++++++++++++++++ src/utils/transforms_crop.py | 29 +++++ 5 files changed, 324 insertions(+), 47 deletions(-) create mode 100644 src/dataset/get_imgs_to_predict.py create mode 100644 src/predict_crop.py create mode 100644 src/utils/transforms_crop.py diff --git a/src/dataset/get_imgs_to_predict.py b/src/dataset/get_imgs_to_predict.py new file mode 100644 index 0000000..998a6e5 --- /dev/null +++ b/src/dataset/get_imgs_to_predict.py @@ -0,0 +1,85 @@ +#!/usr/env/bin python3 + +import os +from io import BytesIO + +import hydra +import backoff +import geopandas as gpd +import numpy as np +import requests +from PIL import Image +from rasterio.warp import Resampling, calculate_default_transform, reproject +from tqdm import tqdm + + +@backoff.on_exception(backoff.expo, requests.exceptions.HTTPError, max_tries=5) +def get_image(gdf, item_id, wms_url, crs, output_path): + """ + Get the image from WMS and returns a numpy array + """ + + gdf_item = gdf[gdf["id"] == item_id] + minx, miny, maxx, maxy = gdf_item.total_bounds + + # Fetch ortofoto + wms_url = wms_url # "https://wms.geonorge.no/skwms1/wms.nib" + params = { + "SERVICE": "WMS", + "VERSION": "1.3.0", + "REQUEST": "GetMap", + "LAYERS": "ortofoto", + "BBOX": f"{minx},{miny},{maxx},{maxy}", + "WIDTH": int(11**5 * abs(maxx - minx)), + "HEIGHT": int(11**5 * abs(maxy - miny)), + "FORMAT": "image/png", + "SRS": "EPSG:4326", + } + + response = requests.get(wms_url, params=params) + response.raise_for_status() + + # Open image and convert to numpy array + img = Image.open(BytesIO(response.content)) + img_np = np.array(img) + + # Calculate the transform and shape for the new coordinate system + transform, width, height = calculate_default_transform( + crs, crs, img_np.shape[1], img_np.shape[0], minx, miny, maxx, maxy + ) + + # Reproject the image data to the new CRS + warped_img = np.empty((height, width, img_np.shape[2]), dtype=np.uint8) + for i in range(img_np.shape[2]): # Loop through the RGB bands + reproject( + source=img_np[:, :, i], + destination=warped_img[:, :, i], + src_transform=transform, # Use the affine transform from before + src_crs=crs, + dst_transform=transform, # The new transformation + dst_crs=crs, + resampling=Resampling.nearest, + ) + + output_filename = os.path.join(output_path, f"image_{item_id}.png") + Image.fromarray(warped_img).save(output_filename) + + return warped_img + + +@hydra.main(version_base=None, config_path="../../configs", config_name="config") +def main(cfg): + gdf = gpd.read_file(cfg.paths.BOUNDING_BOXES) + + for item_id in tqdm(gdf["id"].unique(), desc="Saving Images"): + get_image( + gdf, + item_id, + cfg.dataset.WMS_URL, + cfg.dataset.CRS, + cfg.path.IMG_DIR_PREDICT, + ) + + +if __name__ == "__main__": + main() diff --git a/src/dataset/rasterize_masks.py b/src/dataset/rasterize_masks.py index 0f69bd2..1c4a8e8 100644 --- a/src/dataset/rasterize_masks.py +++ b/src/dataset/rasterize_masks.py @@ -2,10 +2,10 @@ import os +import hydra import geopandas as gpd import numpy as np import pandas as pd -import yaml from PIL import Image from rasterio.features import rasterize from rasterio.transform import from_origin @@ -101,10 +101,15 @@ def rasterize_masks( print(f"Saved mask: {mask_filename}") -if __name__ == "__main__": - with open("config.yaml") as f: - cfg = yaml.load(f, Loader=yaml.FullLoader) - +@hydra.main(version_base=None, config_path="../../configs", config_name="config") +def main(cfg): rasterize_masks( - cfg["MASK"], cfg["IMG_DIR"], cfg["MASKS_DIR"], label_column="labelTekst" + cfg.paths.MASK, + cfg.paths.IMG_DIR, + cfg.paths.MASKS_DIR, + label_column="labelTekst", ) + + +if __name__ == "__main__": + main() diff --git a/src/metrics.py b/src/metrics.py index ead6ab7..2b7be42 100644 --- a/src/metrics.py +++ b/src/metrics.py @@ -1,10 +1,12 @@ import glob import os +import hydra import numpy as np -import yaml from PIL import Image +# TODO: Get an estimate of % or error on each class + def pixel_accuracy(pred, label): correct = (pred == label).sum().item() @@ -131,12 +133,10 @@ def aggregate_metrics(metrics_list): return avg_metrics -if __name__ == "__main__": - with open("./config.yaml") as f: - cfg = yaml.load(f, Loader=yaml.FullLoader) - - pred_folder = cfg["PREDICTED_MASKS"] - gt_folder = cfg["MASKS_DIR"] +@hydra.main(version_base=None, config_path="../configs", config_name="config") +def main(cfg): + pred_folder = cfg.paths.PRED_TEST_MASKS + gt_folder = cfg.paths.GT_TEST_MASKS num_classes = 8 ignore_value = -1 @@ -148,37 +148,5 @@ def aggregate_metrics(metrics_list): print(f"{metric}: {value:.4f}") -# Without augmentation - -# Average Metrics for all masks: -# Pixel Accuracy: 0.9000 -# Mean IoU: 0.6252 -# Mean Dice Coefficient: 0.7721 -# recision: 0.7907 -# Recall: 0.7270 - -# Without augmentation but albumentations: -# Pixel Accuracy: 0.8778 -# Mean IoU: 0.5277 -# Mean Dice Coefficient: 0.6926 -# Precision: 0.7484 -# Recall: 0.6418 - - -# With augmentation - -# Pixel Accuracy: 0.6825 -# Mean IoU: 0.1765 -# Mean Dice Coefficient: 0.3935 -# Precision: 0.5590 -# Recall: 0.2953 - - -# With only normalisation - -# Average Metrics for all masks: -# Pixel Accuracy: 0.7626 -# Mean IoU: 0.2935 -# Mean Dice Coefficient: 0.4953 -# Precision: 0.5157 -# Recall: 0.4339 +if __name__ == "__main__": + main() diff --git a/src/predict_crop.py b/src/predict_crop.py new file mode 100644 index 0000000..55937c3 --- /dev/null +++ b/src/predict_crop.py @@ -0,0 +1,190 @@ +#!/usr/env/bin python3 + +import os +import sys + +import numpy as np +import torch +import torchvision.transforms as T +import yaml +from PIL import Image +from torchvision.models.segmentation import deeplabv3_resnet50 + + +def load_model(checkpoint_path, num_classes=8): + model = deeplabv3_resnet50(pretrained=False) + model.classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=(1, 1)) + + # Load the saved state dict + checkpoint = torch.load( + checkpoint_path, + map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + + # Remove the 'model.' prefix from the state_dict keys + state_dict = { + k.replace("model.", ""): v for k, v in checkpoint["state_dict"].items() + } + + # Filter out the auxiliary classifier keys + state_dict = {k: v for k, v in state_dict.items() if "aux_classifier" not in k} + + # Load the modified state_dict into the model + model.load_state_dict(state_dict, strict=False) + model.eval() # Set the model to evaluation mode + return model + + +def preprocess_image(image_path, patch_size=(512, 512)): + # Ensure patch_size is a tuple of integers + if isinstance(patch_size, tuple) and len(patch_size) == 2: + pass # Valid patch_size + else: + raise ValueError( + f"Invalid patch_size: {patch_size}. Expected a tuple (height, width)." + ) + + # Open the image file using PIL + image = Image.open(image_path).convert("RGB") + + # Save the original image size + original_size = image.size # (width, height) + + # Calculate padding required to make dimensions divisible by patch size + width, height = original_size + pad_width = (patch_size[0] - (width % patch_size[0])) % patch_size[0] + pad_height = (patch_size[1] - (height % patch_size[1])) % patch_size[1] + + # Apply padding + padded_image = Image.new("RGB", (width + pad_width, height + pad_height), (0, 0, 0)) + padded_image.paste(image, (0, 0)) + + # Convert to tensor + preprocess = T.Compose([T.ToTensor()]) + input_tensor = preprocess(padded_image).unsqueeze(0) + + return input_tensor, original_size, (width + pad_width, height + pad_height) + + +def get_corresponding_mask(image_path, mask_dir): + """ + Finds the corresponding ground truth mask for a given image. + + Args: + image_path (str): Path to the input image. + mask_dir (str): Directory containing ground truth masks. + + Returns: + str: Path to the corresponding ground truth mask. + """ + image_filename = os.path.basename(image_path) + identifier = image_filename.split("image_")[-1].replace(".png", "") + mask_filename = f"mask_{identifier}.tif" + + mask_path = os.path.join(mask_dir, mask_filename) + + if not os.path.exists(mask_path): + raise FileNotFoundError( + f"Mask file not found for {image_path}. Expected: {mask_path}" + ) + + return mask_path + + +def apply_ignore_index(predicted_mask, ground_truth_mask, ignore_value=-1): + # Ensure both predicted_mask and ground_truth_mask are numpy arrays + if isinstance(predicted_mask, Image.Image): + predicted_mask = np.array(predicted_mask) + + if isinstance(ground_truth_mask, Image.Image): + ground_truth_mask = np.array(ground_truth_mask) + + # Set the values in the predicted mask to -1 wherever the ground truth mask is -1 + predicted_mask[ground_truth_mask == ignore_value] = ignore_value + + return predicted_mask + + +def predict_image( + image_path, + mask_dir, + model, + output_mask_path, + patch_size=(512, 512), + ignore_value=-1, +): + """ + Predicts segmentation mask for a single image, applies ignore index based on ground truth, and saves the output. + + Args: + image_path (str): Path to the input image. + mask_dir (str): Directory containing the ground truth masks. + model (torch.nn.Module): Pretrained segmentation model. + output_mask_path (str): Directory to save the predicted mask. + patch_size (tuple): Size of the patch for padding (default: (512, 512)). + ignore_value (int): Value to ignore in the ground truth mask. + """ + # Preprocess the image (padding applied) + input_tensor, original_size, padded_size = preprocess_image(image_path, patch_size) + original_width, original_height = original_size + + with torch.no_grad(): + output = model(input_tensor)["out"] + + # Get predicted mask + predicted_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy() + + # Crop the predicted mask back to the original image size + predicted_mask_cropped = predicted_mask[:original_height, :original_width] + + # Find the corresponding ground truth mask + gt_mask_path = get_corresponding_mask(image_path, mask_dir) + if not os.path.exists(gt_mask_path): + raise FileNotFoundError(f"Ground truth mask not found: {gt_mask_path}") + ground_truth_mask = Image.open(gt_mask_path) + + # Apply ignore index to predicted mask + final_mask_with_ignore = apply_ignore_index( + predicted_mask_cropped, ground_truth_mask, ignore_value + ) + + # Ensure the output directory exists + os.makedirs(output_mask_path, exist_ok=True) + + # Save the final mask + pred_mask_name = os.path.join( + output_mask_path, + "predmask_" + os.path.basename(image_path).split(".")[0] + ".tif", + ) + final_mask = Image.fromarray(final_mask_with_ignore.astype(np.int16)) + final_mask.save(pred_mask_name) + print(f"Predicted mask saved at {pred_mask_name}") + + +if __name__ == "__main__": + with open("./config.yaml") as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + if not os.path.exists(cfg["PRED_TEST_MASKS"]): + os.makedirs(cfg["PRED_TEST_MASKS"]) + + if len(sys.argv) != 2: + print("Usage: python predict.py ") + sys.exit(1) + + input_image_path = sys.argv[1] + + # Load the model + model = load_model(cfg["MODEL_PATH"]) + + # Run prediction and save the mask + predict_image( + image_path=input_image_path, + model=model, + mask_dir=cfg["GT_TEST_MASKS"], + output_mask_path=cfg["PRED_TEST_MASKS"], + patch_size=( + 512, + 512, + ), # Use the default explicitly or omit this if the default is fine + ) diff --git a/src/utils/transforms_crop.py b/src/utils/transforms_crop.py new file mode 100644 index 0000000..c5f6299 --- /dev/null +++ b/src/utils/transforms_crop.py @@ -0,0 +1,29 @@ +import albumentations as A +from albumentations.pytorch import ToTensorV2 +from torchvision import transforms + +# Define a cropping transform +albumentations_transform = A.Compose( + [ + A.PadIfNeeded( + min_height=512, + min_width=512, + border_mode=0, + value=0, + ), + A.RandomCrop(width=512, height=512), + A.RandomRotate90(p=0.5), + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.5), + A.RandomBrightnessContrast(p=0.5), + A.GaussNoise(p=0.5), + A.GridDistortion(p=0.5), + A.ElasticTransform(p=0.5), + A.Normalize(mean=(0, 0, 0), std=(1, 1, 1)), + ToTensorV2(), + ], + additional_targets={"mask": "mask"}, +) + +# Torchvision resize transformation with InterpolationMode.BILINEAR +resize_transform = transforms.Compose([])