Skip to content

Commit

Permalink
Fix tile polygon transform function (#3283)
Browse files Browse the repository at this point in the history
Co-authored-by: Vinnam Kim <[email protected]>
* Fix tile polygon function
  • Loading branch information
eugene123tw authored Apr 8, 2024
1 parent 66d51f8 commit 7dfda8f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 4 deletions.
48 changes: 44 additions & 4 deletions src/otx/core/data/dataset/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
from __future__ import annotations

import logging as log
import operator
from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING, Callable

import numpy as np
import shapely.geometry as sg
import torch
from datumaro import Bbox, DatasetItem, DatasetSubset, Image, Polygon
from datumaro import Dataset as DmDataset
from datumaro.components.annotation import AnnotationType
from datumaro.plugins.tiling import Tile
from datumaro.plugins.tiling.tile import _apply_offset
from datumaro.plugins.tiling.util import (
clip_x1y1x2y2,
cxcywh_to_x1y1x2y2,
Expand Down Expand Up @@ -76,10 +81,45 @@ def __init__(
threshold_drop_ann=threshold_drop_ann,
)
self._tile_size = tile_size
# TODO (Eugene): Bug found in original Datumaro tile polygon function.
# https://github.com/eugene123tw/training_extensions/tree/eugene/fix-tile-polygon-func
# It lacks polygon validation, potentially leading to GeometryCollection or MultiPolygon results,
# which the current function doesn't handle.
self._tile_ann_func_map[AnnotationType.polygon] = OTXTileTransform._tile_polygon

@staticmethod
def _tile_polygon(
ann: Polygon,
roi_box: sg.Polygon,
threshold_drop_ann: float = 0.8,
*args, # noqa: ARG004
**kwargs, # noqa: ARG004
) -> Polygon | None:
polygon = sg.Polygon(ann.get_points())

# NOTE: polygon may be invalid, e.g. self-intersecting
if not roi_box.intersects(polygon) or not polygon.is_valid:
return None

# NOTE: intersection may return a GeometryCollection or MultiPolygon
inter = polygon.intersection(roi_box)
if isinstance(inter, (sg.GeometryCollection, sg.MultiPolygon)):
shapes = [(geom, geom.area) for geom in list(inter.geoms) if geom.is_valid]
if not shapes:
return None

inter, _ = max(shapes, key=operator.itemgetter(1))

if not isinstance(inter, sg.Polygon) and not inter.is_valid:
return None

prop_area = inter.area / polygon.area

if prop_area < threshold_drop_ann:
return None

inter = _apply_offset(inter, roi_box)

return ann.wrap(
points=[p for xy in inter.exterior.coords for p in xy],
attributes=deepcopy(ann.attributes),
)

def _extract_rois(self, image: Image) -> list[BboxIntCoords]:
"""Extracts Tile ROIs from the given image.
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/core/data/test_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import numpy as np
import pytest
import shapely.geometry as sg
import torch
from datumaro import Dataset as DmDataset
from datumaro import Polygon
from omegaconf import DictConfig, OmegaConf
from otx.core.config.data import (
DataModuleConfig,
Expand Down Expand Up @@ -136,6 +138,22 @@ def test_tile_transform(self):
num_tile_cols = (width + w_stride - 1) // w_stride
assert len(tiled_dataset) == (num_tile_rows * num_tile_cols * len(dataset)), "Incorrect number of tiles"

def test_tile_polygon_func(self):
points = np.array([(1, 2), (3, 5), (4, 2), (4, 6), (1, 6)])
polygon = Polygon(points=points.flatten().tolist())
roi = sg.Polygon([(0, 0), (5, 0), (5, 5), (0, 5)])

inter_polygon = OTXTileTransform._tile_polygon(polygon, roi, threshold_drop_ann=0.0)
assert isinstance(inter_polygon, Polygon), "Intersection should be a Polygon"
assert inter_polygon.get_area() > 0, "Intersection area should be greater than 0"

assert (
OTXTileTransform._tile_polygon(polygon, roi, threshold_drop_ann=1.0) is None
), "Intersection should be None"

invalid_polygon = Polygon(points=[0, 0, 5, 0, 5, 5, 5, 0])
assert OTXTileTransform._tile_polygon(invalid_polygon, roi) is None, "Invalid polygon should be None"

def test_adaptive_tiling(self, fxt_det_data_config):
# Enable tile adapter
fxt_det_data_config.tile_config.enable_tiler = True
Expand Down

0 comments on commit 7dfda8f

Please sign in to comment.