Skip to content

Commit

Permalink
Added yolo dir to the dataset file paths
Browse files (browse the repository at this point in the history)
  • Loading branch information
kshitijrajsharma committed Oct 31, 2024
1 parent 6312326 commit 392e1be
Showing 1 changed file with 58 additions and 23 deletions.
81 changes: 58 additions & 23 deletions hot_fair_utilities/preprocessing/yolo_v8_v1/yolo_format.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
# Standard library imports
import concurrent.futures
import cv2
import numpy as np
import yaml
import rasterio
import random
import warnings
import traceback
import warnings
from pathlib import Path

# Third party imports
import cv2
import numpy as np
import rasterio
import yaml

# Mask types from https://rampml.global/data-preparation/
CLASS_NAMES = ["footprint", "boundary", "contact"]


def yolo_format(preprocessed_dirs, yolo_dir, val_dirs=None, multimask=False, p_val=None):
def yolo_format(
preprocessed_dirs, yolo_dir, val_dirs=None, multimask=False, p_val=None
):
"""
Creates ultralytics YOLOv5 format dataset from RAMP preprocessed data.
Supports either single data directory or multiple directories.
Expand Down Expand Up @@ -53,22 +57,37 @@ def yolo_format(preprocessed_dirs, yolo_dir, val_dirs=None, multimask=False, p_v
yolo_dir_suffixes = ["_train", "_val"] if p_val else [""]

# Save image symlinks and labels
for dname, dname_stem in zip(preprocessed_dirs + val_dirs, preprocessed_dirs_stems + val_dirs_stems):
for dname, dname_stem in zip(
preprocessed_dirs + val_dirs, preprocessed_dirs_stems + val_dirs_stems
):
img_dir = dname / "chips" if (dname / "chips").is_dir() else dname / "source"
mask_dir = dname / mask_dirname
yolo_img_dir, yolo_label_dir = yolo_dir / "images" / dname_stem, yolo_dir / "labels" / dname_stem
yolo_img_dir, yolo_label_dir = (
yolo_dir / "images" / dname_stem,
yolo_dir / "labels" / dname_stem,
)

for dir in [yolo_img_dir, yolo_label_dir]:
for suf in yolo_dir_suffixes:
Path(str(dir) + suf).mkdir(parents=True, exist_ok=True)

files = list(img_dir.iterdir())
random.shuffle(files)
_image_iteration(files[0], img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, 1.0 if p_val else None)
_image_iteration(
files[0],
img_dir,
mask_dir,
yolo_img_dir,
yolo_label_dir,
classes,
1.0 if p_val else None,
)
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
executor.map(
lambda x: __image_iteration_func(x, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, p_val),
files[1:]
lambda x: __image_iteration_func(
x, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, p_val
),
files[1:],
)

if p_val:
Expand All @@ -77,19 +96,30 @@ def yolo_format(preprocessed_dirs, yolo_dir, val_dirs=None, multimask=False, p_v

# Save dataset.yaml
dataset = {
"names": {i-1: name for i, name in zip(classes, CLASS_NAMES[:len(classes)])},
"names": {i - 1: name for i, name in zip(classes, CLASS_NAMES[: len(classes)])},
"path": str(yolo_dir.absolute()),
"train": f"./images/{str(preprocessed_dirs_stems[0])}/" if len(preprocessed_dirs) == 1 else \
[f"./images/{str(d)}" for d in preprocessed_dirs_stems],
"train": (
f"{yolo_dir.absolute()}/images/{str(preprocessed_dirs_stems[0])}/"
if len(preprocessed_dirs) == 1
else [
f"{yolo_dir.absolute()}/images/{str(d)}"
for d in preprocessed_dirs_stems
]
),
}
if len(val_dirs_stems) > 0:
dataset["val"] = f"./images/{str(val_dirs_stems[0])}/" if len(val_dirs_stems) == 1 else \
[f"./images/{str(d)}" for d in val_dirs_stems]
with open(yolo_dir / "dataset.yaml", 'w') as handle:
dataset["val"] = (
f"./images/{str(val_dirs_stems[0])}/"
if len(val_dirs_stems) == 1
else [f"{yolo_dir.absolute()}/images/{str(d)}" for d in val_dirs_stems]
)
with open(yolo_dir / "dataset.yaml", "w") as handle:
yaml.dump(dataset, handle, default_flow_style=False)


def _image_iteration(img, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, p_val):
def _image_iteration(
img, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, p_val
):
if p_val:
if random.uniform(0, 1) > p_val:
yolo_img_dir = Path(str(yolo_img_dir) + "_train")
Expand All @@ -111,20 +141,26 @@ def _image_iteration(img, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, class
data = handle.read()
h, w = data.shape[1:]
label = str(img)[:-4] + ".txt"
with open(yolo_label_dir / label, 'w') as handle:
with open(yolo_label_dir / label, "w") as handle:
for cls in classes:
x = np.where(data == cls, 255, 0).squeeze().astype("uint8")
contours, _ = cv2.findContours(x, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_TC89_KCOS)
contours, _ = cv2.findContours(
x, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_TC89_KCOS
)
for contour in contours: # contour (n, 1, 2)
if contour.shape[0] > 2: # at least 3-point polygon
contour = contour / [w, h]
line = f"{cls - 1} {' '.join([str(c) for c in contour.flatten().tolist()])}\n"
handle.write(line)


def __image_iteration_func(img, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, p_val):
def __image_iteration_func(
img, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, p_val
):
try:
_image_iteration(img, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, p_val)
_image_iteration(
img, img_dir, mask_dir, yolo_img_dir, yolo_label_dir, classes, p_val
)
except Exception as e:
full_trace = "\n" + " ".join(traceback.format_exception(e))
warnings.warn(f"Image {img.name} caused {full_trace}")
Expand All @@ -135,4 +171,3 @@ def __image_iteration_func(img, img_dir, mask_dir, yolo_img_dir, yolo_label_dir,
# root = "/tf/ramp-data/sample_119"
root = "/home/powmol/wip/hotosm/fAIr-utilities/ramp-data/sample_119"
yolo_format([root + "/preprocessed"], root + "/yolo", multimask=False, p_val=0.05)

0 comments on commit 392e1be

Please sign in to comment.