Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Job splitters should retain original geometries #153

Merged
merged 3 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions src/openeo_gfmap/manager/job_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,27 @@
from openeo_gfmap.manager import _log


def load_s2_grid(web_mercator: bool = False) -> gpd.GeoDataFrame:
    """Return the Sentinel-2 tiling grid as a GeoDataFrame.

    The grid is cached as a geoparquet file under ``~/.openeo-gfmap``;
    on first use it is downloaded from the VITO artifactory.

    Parameters
    ----------
    web_mercator:
        If True, load the grid variant with bounds in EPSG:3857
        (Web Mercator); otherwise the EPSG:4326 (lat/lon) variant is used.

    Returns
    -------
    gpd.GeoDataFrame
        The S2 grid read from the cached geoparquet file.
    """
    # Pick the cache location and download URL matching the requested CRS.
    if not web_mercator:
        gdf_path = Path.home() / ".openeo-gfmap" / "s2grid_bounds_4326.geoparquet"
        url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/s2grid_bounds_4326.geoparquet"
    else:
        gdf_path = Path.home() / ".openeo-gfmap" / "s2grid_bounds_3857.geoparquet"
        url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/s2grid_bounds_3857.geoparquet"

    if not gdf_path.exists():
        _log.info("S2 grid not found, downloading it from artifactory.")
        # Downloads the file from the artifactory URL
        gdf_path.parent.mkdir(exist_ok=True)
        response = requests.get(
            url,
            timeout=180,  # 3mins
        )
        # Fail fast on HTTP errors (404/5xx) instead of silently caching an
        # error page as a corrupt parquet file.
        response.raise_for_status()
        with open(gdf_path, "wb") as f:
            f.write(response.content)
    return gpd.read_parquet(gdf_path)


def _resplit_group(
Expand All @@ -38,7 +44,7 @@ def _resplit_group(


def split_job_s2grid(
polygons: gpd.GeoDataFrame, max_points: int = 500
polygons: gpd.GeoDataFrame, max_points: int = 500, web_mercator: bool = False
) -> List[gpd.GeoDataFrame]:
"""Split a job into multiple jobs from the position of the polygons/points. The centroid of
the geometries to extract are used to select tile in the Sentinel-2 tile grid.
Expand All @@ -60,17 +66,23 @@ def split_job_s2grid(
if polygons.crs is None:
raise ValueError("The GeoDataFrame must contain a CRS")

polygons = polygons.to_crs(epsg=4326)
if polygons.geometry.geom_type[0] != "Point":
polygons["geometry"] = polygons.geometry.centroid
epsg = 3857 if web_mercator else 4326

original_crs = polygons.crs

polygons = polygons.to_crs(epsg=epsg)

polygons["centroid"] = polygons.geometry.centroid

# Dataset containing all the S2 tiles, find the nearest S2 tile for each point
s2_grid = load_s2_grid()
s2_grid = load_s2_grid(web_mercator)
s2_grid["geometry"] = s2_grid.geometry.centroid

polygons = gpd.sjoin_nearest(polygons, s2_grid[["tile", "geometry"]]).drop(
columns=["index_right"]
)
polygons = gpd.sjoin_nearest(
polygons.set_geometry("centroid"), s2_grid[["tile", "geometry"]]
).drop(columns=["index_right", "centroid"])

polygons = polygons.set_geometry("geometry").to_crs(original_crs)

split_datasets = []
for _, sub_gdf in polygons.groupby("tile"):
Expand All @@ -86,10 +98,13 @@ def append_h3_index(
polygons: gpd.GeoDataFrame, grid_resolution: int = 3
) -> gpd.GeoDataFrame:
"""Append the H3 index to the polygons."""
if polygons.geometry.geom_type[0] != "Point":
geom_col = polygons.geometry.centroid
else:
geom_col = polygons.geometry

# Project to Web mercator to calculate centroids
polygons = polygons.to_crs(epsg=3857)
geom_col = polygons.geometry.centroid
# Project to lat lon to calculate the h3 index
geom_col = geom_col.to_crs(epsg=4326)

polygons["h3index"] = geom_col.apply(
lambda pt: h3.geo_to_h3(pt.y, pt.x, grid_resolution)
)
Expand Down Expand Up @@ -127,12 +142,13 @@ def split_job_hex(
if polygons.crs is None:
raise ValueError("The GeoDataFrame must contain a CRS")

# Project to lat/lon positions
polygons = polygons.to_crs(epsg=4326)
original_crs = polygons.crs

# Split the polygons into multiple jobs
polygons = append_h3_index(polygons, grid_resolution)

polygons = polygons.to_crs(original_crs)

split_datasets = []
for _, sub_gdf in polygons.groupby("h3index"):
if len(sub_gdf) > max_points:
Expand Down
98 changes: 83 additions & 15 deletions tests/test_openeo_gfmap/test_managers.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,94 @@
"""Test the job splitters and managers of OpenEO GFMAP."""

from pathlib import Path

import geopandas as gpd
from shapely.geometry import Point, Polygon

from openeo_gfmap.manager.job_splitters import split_job_hex
from openeo_gfmap.manager.job_splitters import split_job_hex, split_job_s2grid


# TODO can we instead assert on exact numbers ?
# would remove the print statement
def test_split_jobs():
dataset_path = Path(__file__).parent / "resources/wc_extraction_dataset.gpkg"
def test_split_job_s2grid():
    """Five points spread over two S2 tiles must be split into three jobs,
    with the original point geometries and CRS preserved."""
    coords = [
        (60.02, 4.57),
        (59.6, 5.04),
        (59.92, 3.37),
        (59.07, 4.11),
        (58.77, 4.87),
    ]
    gdf = gpd.GeoDataFrame(
        {"id": [1, 2, 3, 4, 5]},
        geometry=[Point(lon, lat) for lon, lat in coords],
        crs="EPSG:4326",
    )

    # At most two points per job.
    limit = 2

    splits = split_job_s2grid(gdf, limit)

    assert (
        len(splits) == 3
    ), "The number of GeoDataFrames returned should match the number of splits needed."

    # The splitter must hand the original geometries back untouched.
    for part in splits:
        assert (
            "geometry" in part.columns
        ), "Each GeoDataFrame should have a geometry column."
        assert part.crs == 4326, "The original CRS should be preserved."
        assert all(
            part.geometry.geom_type == "Point"
        ), "Original geometries should be preserved."


def test_split_job_hex():
    """Geometries spread over three size-3 H3 hexes are split per hex,
    with the original point/polygon geometries left intact."""
    square = Polygon(
        [
            (58.78, 4.88),
            (58.78, 4.86),
            (58.76, 4.86),
            (58.76, 4.88),
            (58.78, 4.88),
        ]
    )
    geometries = [
        Point(60.02, 4.57),
        Point(58.34, 5.06),
        Point(59.92, 3.37),
        Point(58.85, 4.90),
        Point(58.77, 4.87),
        square,
    ]
    gdf = gpd.GeoDataFrame(
        {"id": [1, 2, 3, 4, 5, 6]}, geometry=geometries, crs="EPSG:4326"
    )

    limit = 3

    splits = split_job_hex(gdf, limit)

    assert (
        len(splits) == 4
    ), "The number of GeoDataFrames returned should match the number of splits needed."

    for position, part in enumerate(splits):
        assert (
            "geometry" in part.columns
        ), "Each GeoDataFrame should have a geometry column."
        assert part.crs == 4326, "The CRS should be preserved."
        # The second split carries the polygon; every other one holds points.
        expected_type = "Polygon" if position == 1 else "Point"
        assert all(
            part.geometry.geom_type == expected_type
        ), "Original geometries should be preserved."

    assert (
        len(splits[0]) == 3
    ), "The number of geometries in the first split should be 3."
Loading