Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Job splitters should retain original geometries #153

Merged
merged 3 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/openeo_gfmap/manager/job_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,23 @@ def split_job_s2grid(
if polygons.crs is None:
raise ValueError("The GeoDataFrame must contain a CRS")

polygons = polygons.to_crs(epsg=4326)
if polygons.geometry.geom_type[0] != "Point":
polygons["geometry"] = polygons.geometry.centroid
original_crs = polygons.crs

# Transform to web mercator, to calculate the centroid
polygons = polygons.to_crs(epsg=3857)

polygons["centroid"] = polygons.geometry.centroid

# Dataset containing all the S2 tiles, find the nearest S2 tile for each point
s2_grid = load_s2_grid()
s2_grid = s2_grid.to_crs(epsg=3857)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this not a slow step for the entire s2 grid?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VincentVerelst Maybe if this is a speed issue you could upload a second s2 grid in the artifactory with the CRS reprojected, and then change the load_s2_grid() function in gfmap to accept a parameter web_mercator: bool = False

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be worth the effort indeed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, I'll already update that.

I do have to note that the largest bottleneck is typically not the CRS conversion of the S2 grid, but rather the CRS conversion of the dataframe that is being split itself.

For example, for 2018_BEL_LPIS-Flanders_POLY_110 the old job_splitter takes 10s to run, while the new one takes 20s.

s2_grid["geometry"] = s2_grid.geometry.centroid

polygons = gpd.sjoin_nearest(polygons, s2_grid[["tile", "geometry"]]).drop(
columns=["index_right"]
)
polygons = gpd.sjoin_nearest(
polygons.set_geometry("centroid"), s2_grid[["tile", "geometry"]]
).drop(columns=["index_right", "centroid"])

polygons = polygons.set_geometry("geometry").to_crs(original_crs)

split_datasets = []
for _, sub_gdf in polygons.groupby("tile"):
Expand All @@ -86,10 +92,13 @@ def append_h3_index(
polygons: gpd.GeoDataFrame, grid_resolution: int = 3
) -> gpd.GeoDataFrame:
"""Append the H3 index to the polygons."""
if polygons.geometry.geom_type[0] != "Point":
geom_col = polygons.geometry.centroid
else:
geom_col = polygons.geometry

# Project to Web mercator to calculate centroids
polygons = polygons.to_crs(epsg=3857)
geom_col = polygons.geometry.centroid
# Project to lat lon to calculate the h3 index
geom_col = geom_col.to_crs(epsg=4326)

polygons["h3index"] = geom_col.apply(
lambda pt: h3.geo_to_h3(pt.y, pt.x, grid_resolution)
)
Expand Down Expand Up @@ -127,12 +136,13 @@ def split_job_hex(
if polygons.crs is None:
raise ValueError("The GeoDataFrame must contain a CRS")

# Project to lat/lon positions
polygons = polygons.to_crs(epsg=4326)
original_crs = polygons.crs

# Split the polygons into multiple jobs
polygons = append_h3_index(polygons, grid_resolution)

polygons = polygons.to_crs(original_crs)

split_datasets = []
for _, sub_gdf in polygons.groupby("h3index"):
if len(sub_gdf) > max_points:
Expand Down
98 changes: 83 additions & 15 deletions tests/test_openeo_gfmap/test_managers.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,94 @@
"""Test the job splitters and managers of OpenEO GFMAP."""

from pathlib import Path

import geopandas as gpd
from shapely.geometry import Point, Polygon

from openeo_gfmap.manager.job_splitters import split_job_hex
from openeo_gfmap.manager.job_splitters import split_job_hex, split_job_s2grid


# TODO can we instead assert on exact numbers ?
# would remove the print statement
def test_split_jobs():
dataset_path = Path(__file__).parent / "resources/wc_extraction_dataset.gpkg"
def test_split_job_s2grid():
    """Splitting on the S2 grid should group points by tile while keeping
    the original CRS and the original (Point) geometries intact."""
    # Mock dataset: five points located in two different S2 tiles.
    data = {
        "id": [1, 2, 3, 4, 5],
        "geometry": [
            Point(60.02, 4.57),
            Point(59.6, 5.04),
            Point(59.92, 3.37),
            Point(59.07, 4.11),
            Point(58.77, 4.87),
        ],
    }
    polygons = gpd.GeoDataFrame(data, crs="EPSG:4326")

    # Maximum number of points allowed per job.
    max_points = 2

    result = split_job_s2grid(polygons, max_points)

    assert (
        len(result) == 3
    ), "The number of GeoDataFrames returned should match the number of splits needed."

    # The splitter must hand back the original geometries, not centroids,
    # and must not leave the data in an intermediate projection.
    for gdf in result:
        assert (
            "geometry" in gdf.columns
        ), "Each GeoDataFrame should have a geometry column."
        assert gdf.crs == 4326, "The original CRS should be preserved."
        assert all(
            gdf.geometry.geom_type == "Point"
        ), "Original geometries should be preserved."


def test_split_job_hex():
    """Splitting by H3 hexagon should group geometries per hex cell while
    keeping the original CRS and the original geometry types intact."""
    # Mock dataset: five points and one polygon spread over three
    # different H3 hexagons of resolution 3.
    data = {
        "id": [1, 2, 3, 4, 5, 6],
        "geometry": [
            Point(60.02, 4.57),
            Point(58.34, 5.06),
            Point(59.92, 3.37),
            Point(58.85, 4.90),
            Point(58.77, 4.87),
            Polygon(
                [
                    (58.78, 4.88),
                    (58.78, 4.86),
                    (58.76, 4.86),
                    (58.76, 4.88),
                    (58.78, 4.88),
                ]
            ),
        ],
    }
    polygons = gpd.GeoDataFrame(data, crs="EPSG:4326")

    # Maximum number of geometries allowed per job.
    max_points = 3

    result = split_job_hex(polygons, max_points)

    assert (
        len(result) == 4
    ), "The number of GeoDataFrames returned should match the number of splits needed."

    for gdf in result:
        assert (
            "geometry" in gdf.columns
        ), "Each GeoDataFrame should have a geometry column."
        assert gdf.crs == 4326, "The CRS should be preserved."

    # Check geometry preservation without depending on the (implementation
    # defined) order of the split groups: exactly the original five points
    # and the single polygon must survive the split.
    geom_types = [gt for gdf in result for gt in gdf.geometry.geom_type]
    assert geom_types.count("Point") == 5, "Original points should be preserved."
    assert geom_types.count("Polygon") == 1, "Original polygon should be preserved."

    assert (
        len(result[0]) == 3
    ), "The number of geometries in the first split should be 3."
Loading