Skip to content

Commit

Permalink
Merge pull request #29 from hotosm/feature/multimasks
Browse files Browse the repository at this point in the history
Feature : Multimask Training
  • Loading branch information
kshitijrajsharma authored Sep 3, 2024
2 parents 93debb4 + 3ddb004 commit c876f9e
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 104 deletions.
141 changes: 141 additions & 0 deletions hot_fair_utilities/preprocessing/multimasks_from_polygons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Patched from ramp-code.scripts.multi_masks_from_polygons created for ramp project by [email protected]

# Standard library imports
from pathlib import Path

# Third party imports
import geopandas as gpd
import rasterio as rio
from ramp.data_mgmt.chip_label_pairs import (
construct_mask_filepath,
get_tq_chip_label_pairs,
)
from ramp.utils.img_utils import to_channels_first
from ramp.utils.multimask_utils import df_to_px_mask, multimask_to_sparse_multimask
from solaris.utils.core import _check_rasterio_im_load
from solaris.utils.geo import get_crs
from solaris.vector.mask import crs_is_metric
from tqdm import tqdm


def get_rasterio_shape_and_transform(image_path):
    """Read the shape and affine transform of the raster at ``image_path``.

    The raster is opened with rasterio just long enough to read both
    attributes; the values are intended to be passed on to df_to_px_mask.

    Args:
        image_path: Path of the raster file to inspect.

    Returns:
        A ``(shape, transform)`` tuple taken from the opened dataset.
    """
    with rio.open(image_path) as dataset:
        return dataset.shape, dataset.transform


def multimasks_from_polygons(
    in_poly_dir,
    in_chip_dir,
    out_mask_dir,
    input_contact_spacing=8,
    input_boundary_width=3,
):
    """
    Create multichannel building footprint masks from a folder of geojson files.

    This also requires the path to the matching image chips directory.

    input_contact_spacing and input_boundary_width are expressed in pixels.
    Meters are not used because a fixed metric width would not stay consistent
    across zoom levels:

        Real-world width (in meters) = Pixel width × Resolution (meters per pixel)

    When the label CRS is metric, the pixel values are converted to meters
    using the smallest pixel resolution of the chip before being handed to
    df_to_px_mask.

    Masks that already exist in out_mask_dir are skipped, so an interrupted
    run can be resumed. If no chip/label pairs are found, no masks are written.

    Args:
        in_poly_dir (str): Path to directory containing geojson files.
        in_chip_dir (str): Path to directory containing image chip files with names matching geojson files.
        out_mask_dir (str): Path to directory that will contain the output multimasks.
        input_contact_spacing (int, optional): Pixels that are closer to two different polygons than contact_spacing will be labeled with the contact mask.
        input_boundary_width (int, optional): Width in pixel of boundary inner buffer around building footprints

    Example:
        multimasks_from_polygons(
            "data/preprocessed/labels",
            "data/preprocessed/chips",
            "data/preprocessed/multimasks"
        )
    """

    # If output mask directory doesn't exist, try to create it.
    Path(out_mask_dir).mkdir(parents=True, exist_ok=True)

    # Materialise the (chip, label) pairs: an empty result then simply skips
    # the loop instead of failing on tuple unpacking, and tqdm can read the
    # length to display a proper total.
    chip_label_pairs = list(get_tq_chip_label_pairs(in_chip_dir, in_poly_dir))

    for chip_path, json_path in tqdm(chip_label_pairs, desc="Multimasks for input"):
        # The mask gets the same base filename as the chip,
        # with a mask.tif extension in place of the .tif extension.
        mask_path = construct_mask_filepath(out_mask_dir, chip_path)

        # We will run this on very large directories, and some label files might fail to process.
        # We want to be able to resume mask creation from where we left off.
        if Path(mask_path).is_file():
            continue

        # workaround for bug in solaris: pass the mask shape explicitly.
        mask_shape, _ = get_rasterio_shape_and_transform(chip_path)

        gdf = gpd.read_file(json_path)

        # remove empty and null geometries
        gdf = gdf[~gdf["geometry"].isna()]
        gdf = gdf[~gdf.is_empty]

        # Open the reference chip in a context manager so that its file handle
        # is released after every iteration (important on large directories).
        with _check_rasterio_im_load(chip_path) as reference_im:
            reference_crs = get_crs(reference_im)
            if get_crs(gdf) != reference_crs:
                # BUGFIX: if crs's don't match, reproject the geodataframe
                gdf = gdf.to_crs(reference_crs)

            if crs_is_metric(gdf):
                # Convert the pixel-based inputs to meters using the chip resolution.
                meters = True
                pixel_size = min(reference_im.res)
                boundary_width = pixel_size * input_boundary_width
                contact_spacing = pixel_size * input_contact_spacing
            else:
                meters = False
                boundary_width = input_boundary_width
                contact_spacing = input_contact_spacing

            # NOTE: solaris does not support multipolygon geodataframes
            # So first we call explode() to turn multipolygons into polygon dataframes
            # ignore_index=True prevents polygons from the same multipolygon from being grouped into a series.
            gdf_poly = gdf.explode(ignore_index=True)

            # NOTE(review): out_file=mask_path is passed here; if df_to_px_mask
            # writes the one-hot mask to that path and a later step fails, the
            # resume check above would skip this chip on the next run - confirm.
            # multi_mask is a one-hot, channels-last encoded mask
            onehot_multi_mask = df_to_px_mask(
                df=gdf_poly,
                out_file=mask_path,
                shape=mask_shape,
                do_transform=True,
                affine_obj=None,
                channels=["footprint", "boundary", "contact"],
                reference_im=reference_im,
                boundary_width=boundary_width,
                contact_spacing=contact_spacing,
                out_type="uint8",
                meters=meters,
            )

            # Reuse the already opened chip for the output metadata instead of
            # opening the same file again.
            meta = reference_im.meta.copy()

        # convert onehot_multi_mask to a sparse encoded mask
        # of shape (1,H,W) for compatibility with rasterio writer
        sparse_multi_mask = multimask_to_sparse_multimask(onehot_multi_mask)
        sparse_multi_mask = to_channels_first(sparse_multi_mask)

        # write out sparse mask file with rasterio.
        meta.update(count=sparse_multi_mask.shape[0])
        meta.update(dtype="uint8")
        meta.update(nodata=None)
        with rio.open(mask_path, "w", **meta) as dst:
            dst.write(sparse_multi_mask)
24 changes: 24 additions & 0 deletions hot_fair_utilities/preprocessing/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..georeferencing import georeference
from .clip_labels import clip_labels
from .fix_labels import fix_labels
from .multimasks_from_polygons import multimasks_from_polygons
from .reproject_labels import reproject_labels_to_epsg3857


Expand All @@ -13,6 +14,9 @@ def preprocess(
rasterize=False,
rasterize_options=None,
georeference_images=False,
multimasks=False,
input_contact_spacing=8, # only required if multimasks is set to true
input_boundary_width=3, # only required if multimasks is set to true
) -> None:
"""Fully preprocess the input data.
Expand All @@ -29,6 +33,7 @@ def preprocess(
(if georeference_images=True), and the directories
"binarymasks" and "grayscale_labels" if the corresponding
rasterizing options are chosen.
"multimasks" - for the multimasks labels (if multimasks=True)
rasterize: Whether to create the raster labels.
rasterize_options: A list with options how to rasterize the
label, if rasterize=True. Possible options: "grayscale"
Expand All @@ -37,6 +42,13 @@ def preprocess(
for the ramp model).
If rasterize=False, rasterize_options will be ignored.
georeference_images: Whether to georeference the OAM images.
multimasks: Whether to additionally output multimask labels.
input_contact_spacing (int, optional): Pixels that are closer to two different polygons than contact_spacing will be labeled with the contact mask.
input_boundary_width (int, optional): Width in pixel of boundary inner buffer around building footprints
Both input_contact_spacing and input_boundary_width are given in pixels. Meters are not used because the pixel resolution differs between zoom levels, so a fixed metric width would not be consistent:
Real-world width (in meters) = Pixel width × Resolution (meters per pixel)
Example::
Expand Down Expand Up @@ -82,3 +94,15 @@ def preprocess(

os.remove(f"{output_path}/corrected_labels.geojson")
os.remove(f"{output_path}/labels_epsg3857.geojson")

if multimasks:
assert os.path.isdir(
f"{output_path}/chips"
), "Chips do not exist. Set georeference_images=True."
multimasks_from_polygons(
f"{output_path}/labels",
f"{output_path}/chips",
f"{output_path}/multimasks",
input_contact_spacing=input_contact_spacing,
input_boundary_width=input_boundary_width,
)
13 changes: 10 additions & 3 deletions hot_fair_utilities/training/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, message):
self.message = message


def split_training_2_validation(input_path, output_path):
def split_training_2_validation(input_path, output_path, multimasks=False):
"""Converts training 2 validation
Currently supported for ramp. It converts the training dataset provided by the preprocessing script to the validation datasets required by ramp
Expand Down Expand Up @@ -101,14 +101,21 @@ def split_training_2_validation(input_path, output_path):
raise ex

try:
if multimasks:
sd = f"{dst_path}/multimasks"
td = f"{dst_path}/val-multimasks"
else:
sd = f"{dst_path}/binarymasks"
td = f"{dst_path}/val-binarymasks"

subprocess.check_output(
[
python_exec,
f"{RAMP_HOME}/ramp-code/scripts/move_chips_from_csv.py",
"-sd",
f"{dst_path}/binarymasks",
sd,
"-td",
f"{dst_path}/val-binarymasks",
td,
"-csv",
f"{dst_path}/fair_split_val.csv",
"-mv",
Expand Down
Loading

0 comments on commit c876f9e

Please sign in to comment.