-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e2ec93a
commit a672551
Showing
4 changed files
with
271 additions
and
103 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import json | ||
import os | ||
from collections import defaultdict | ||
|
||
import numpy as np | ||
import pycocotools.mask as maskUtils | ||
from detectron2.data import DatasetCatalog, MetadataCatalog | ||
from detectron2.structures import BoxMode | ||
from PIL import Image | ||
from pycocotools.coco import COCO | ||
from shapely.geometry import Polygon | ||
|
||
|
||
def split_coco(coco_dict, val_percentage=0.2): | ||
# Extract images and annotations | ||
images = coco_dict["images"] | ||
annotations = coco_dict["annotations"] | ||
|
||
# Create a mapping from image_id to image metadata | ||
image_id_to_image = {img["id"]: img for img in images} | ||
|
||
# Group annotations by image_id | ||
image_to_annotations = defaultdict(list) | ||
for annotation in annotations: | ||
# Adjust category_id to be zero-indexed | ||
annotation["category_id"] -= 1 | ||
|
||
# If bounding box is missing, compute it from the segmentation mask | ||
if "bbox" not in annotation or not annotation["bbox"]: | ||
if "segmentation" in annotation: | ||
# Flatten the segmentation polygon and create a shapely Polygon | ||
segmentation = annotation["segmentation"][ | ||
0 | ||
] # First polygon if multiple | ||
poly = Polygon( | ||
[ | ||
(segmentation[i], segmentation[i + 1]) | ||
for i in range(0, len(segmentation), 2) | ||
] | ||
) | ||
|
||
# Calculate bounding box (x_min, y_min, x_max, y_max) | ||
x_min, y_min, x_max, y_max = poly.bounds | ||
annotation["bbox"] = [ | ||
x_min, | ||
y_min, | ||
x_max - x_min, | ||
y_max - y_min, | ||
] # Convert to [x, y, width, height] | ||
|
||
# Add bbox_mode for Detectron2 | ||
annotation["bbox_mode"] = BoxMode.XYWH_ABS | ||
|
||
# Add the annotation to the image_id group | ||
image_to_annotations[annotation["image_id"]].append(annotation) | ||
|
||
# Create a list of restructured image-level dictionaries | ||
dataset = [] | ||
for image_id, image_data in image_id_to_image.items(): | ||
image_dict = { | ||
"file_name": image_data["file_name"], | ||
"image_id": image_id, | ||
"height": image_data["height"], | ||
"width": image_data["width"], | ||
"annotations": image_to_annotations.get(image_id, []), | ||
} | ||
dataset.append(image_dict) | ||
|
||
# Split the dataset into training and validation sets | ||
split_index = int(len(dataset) * (1 - val_percentage)) | ||
train_dataset = dataset[:split_index] | ||
val_dataset = dataset[split_index:] | ||
|
||
with open("train_set.json", "w") as f: | ||
json.dump(train_dataset, f, indent=4) | ||
|
||
return train_dataset, val_dataset | ||
|
||
|
||
def register_coco(annotations_file, img_dir, val_percentage=0.2): | ||
""" | ||
Load the COCO dataset and split it in memory into training and validation sets. | ||
Args: | ||
- annotations_file: Path to the COCO JSON annotations file. | ||
- img_dir: Directory containing the images. | ||
- val_percentage: Percentage of images to allocate to validation (default 20%). | ||
Returns: | ||
- None | ||
""" | ||
# Load the COCO JSON in memory | ||
with open(annotations_file, "r") as f: | ||
coco_dict = json.load(f) | ||
|
||
# Split the dataset in memory | ||
train_dict, val_dict = split_coco(coco_dict, val_percentage) | ||
|
||
# Register the training dataset | ||
DatasetCatalog.register("my_segmentation_train", lambda: train_dict) | ||
MetadataCatalog.get("my_segmentation_train").set( | ||
thing_classes=[ | ||
"innend\u00f8rs", | ||
"parkering/sykkelstativ", | ||
"asfalt/betong", | ||
"gummifelt/kunstgress", | ||
"sand/stein", | ||
"gress", | ||
"tr\u00e6r", | ||
] | ||
) | ||
|
||
# Register the validation dataset | ||
DatasetCatalog.register("my_segmentation_val", lambda: val_dict) | ||
MetadataCatalog.get("my_segmentation_val").set( | ||
thing_classes=[ | ||
"innend\u00f8rs", | ||
"parkering/sykkelstativ", | ||
"asfalt/betong", | ||
"gummifelt/kunstgress", | ||
"sand/stein", | ||
"gress", | ||
"tr\u00e6r", | ||
] | ||
) | ||
|
||
print("Datasets successfully registered!") | ||
|
||
|
||
def convert_coco_to_masks(coco_annotation_path, output_mask_dir, output_image_dir): | ||
# Load COCO annotations | ||
coco = COCO(coco_annotation_path) | ||
|
||
# Get all image IDs | ||
image_ids = coco.getImgIds() | ||
|
||
for img_id in image_ids: | ||
# Load the corresponding image information | ||
img_info = coco.loadImgs(img_id)[0] | ||
img_filename = img_info["file_name"] | ||
|
||
# Load annotations (segmentations) for the image | ||
ann_ids = coco.getAnnIds(imgIds=img_id) | ||
anns = coco.loadAnns(ann_ids) | ||
|
||
# Create an empty mask | ||
height, width = img_info["height"], img_info["width"] | ||
mask = np.zeros((height, width), dtype=np.uint8) | ||
|
||
# For each annotation, fill in the mask with the class ID | ||
for ann in anns: | ||
category_id = ann["category_id"] | ||
# Get the segmentation mask for the annotation | ||
rle = coco.annToRLE(ann) | ||
binary_mask = maskUtils.decode(rle) | ||
|
||
# Set the class ID in the mask | ||
mask[binary_mask == 1] = category_id | ||
|
||
# Save the mask as a grayscale image | ||
mask_img = Image.fromarray(mask) | ||
mask_filename = os.path.splitext(img_filename)[0] + ".png" | ||
mask_img.save(os.path.join(output_mask_dir, mask_filename)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import os | ||
|
||
import geopandas as gpd | ||
import numpy as np | ||
import pandas as pd | ||
import yaml | ||
from PIL import Image | ||
from rasterio.features import rasterize | ||
from rasterio.transform import from_origin | ||
from shapely.geometry import MultiPolygon, Polygon | ||
|
||
|
||
def get_image_dimensions(image_path): | ||
""" | ||
Get the dimensions (width, height) of an image using PIL. | ||
""" | ||
with Image.open(image_path) as img: | ||
width, height = img.size | ||
return height, width | ||
|
||
|
||
def rasterize_masks( | ||
geojson_path, image_dir, output_mask_dir, label_column="labelTekst" | ||
): | ||
""" | ||
Convert GeoJSON polygons to raster masks and save them as PNGs. | ||
Parameters: | ||
- geojson_path: Path to the GeoJSON file with polygon geometries. | ||
- image_dir: Directory containing the corresponding images. | ||
- output_mask_dir: Directory to save the output rasterized masks. | ||
- label_column: The column in the GeoDataFrame that contains the text labels. | ||
- resolution: The resolution of the output mask (default is 1024x1024). | ||
""" | ||
# Load the GeoJSON data | ||
gdf = gpd.read_file(geojson_path) | ||
|
||
# Factorize the text labels to get numeric labels | ||
gdf["label_encoded"], _ = pd.factorize(gdf[label_column]) | ||
|
||
# Ensure the output directory exists | ||
os.makedirs(output_mask_dir, exist_ok=True) | ||
|
||
# Loop through each unique 'id' or corresponding image | ||
for image_id in gdf["id"].unique(): | ||
# Get all geometries and corresponding numeric labels for this image | ||
gdf_image = gdf[gdf["id"] == image_id] | ||
|
||
# Get the corresponding image file and extract dimensions | ||
image_filename = f"image_{image_id}.png" | ||
image_path = os.path.join(image_dir, image_filename) | ||
if not os.path.exists(image_path): | ||
print(f"Image not found: {image_path}") | ||
continue | ||
|
||
# Get the original image dimensions | ||
img_height, img_width = get_image_dimensions(image_path) | ||
|
||
# Prepare a list of (geometry, class_id) tuples and compute bounds | ||
shapes = [] | ||
x_min, y_min, x_max, y_max = gdf_image.total_bounds | ||
|
||
# Calculate pixel size based on image's geometry bounds and output resolution | ||
pixel_size_x = (x_max - x_min) / img_width | ||
pixel_size_y = (y_max - y_min) / img_height | ||
|
||
# Define the transform for the rasterization | ||
transform = from_origin(x_min, y_max, pixel_size_x, pixel_size_y) | ||
|
||
# Prepare a list of (geometry, class_id) tuples for rasterization | ||
for _, row in gdf_image.iterrows(): | ||
geometry = row.geometry | ||
if isinstance(geometry, Polygon): | ||
# If it's a single Polygon, add it to shapes | ||
shapes.append((geometry, row["label_encoded"])) | ||
elif isinstance(geometry, MultiPolygon): | ||
# If it's a MultiPolygon, loop through each Polygon | ||
for poly in geometry.geoms: | ||
shapes.append((poly, row["label_encoded"])) | ||
|
||
# Create an empty mask with the specified output size (resolution) | ||
mask = np.zeros((img_height, img_width), dtype=np.uint8) | ||
|
||
# Rasterize the geometries into the mask | ||
mask = rasterize( | ||
shapes=shapes, | ||
out_shape=(img_height, img_width), | ||
transform=transform, | ||
fill=0, | ||
dtype="uint8", | ||
) | ||
|
||
# Save the mask as a grayscale PNG | ||
mask_filename = f"mask_{image_id}.png" | ||
mask_img = Image.fromarray(mask) | ||
mask_img.save(os.path.join(output_mask_dir, mask_filename)) | ||
|
||
print(f"Saved mask: {mask_filename}") | ||
|
||
|
||
if __name__ == "__main__": | ||
with open("config.yaml") as f: | ||
cfg = yaml.load(f, Loader=yaml.FullLoader) | ||
|
||
rasterize_masks( | ||
cfg["MASK"], cfg["IMG_DIR"], cfg["MASKS_DIR"], label_column="labelTekst" | ||
) |