Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/Open-EO/openeo-gfmap into s…
Browse files Browse the repository at this point in the history
…1-extraction-fixes-PR2

Conflicts:
	src/openeo_gfmap/manager/job_splitters.py
  • Loading branch information
GriffinBabe committed Sep 4, 2024
2 parents bec9976 + 1110e4a commit 386aebf
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 28 deletions.
53 changes: 40 additions & 13 deletions src/openeo_gfmap/manager/job_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@
from openeo_gfmap.manager import _log


@lru_cache(maxsize=1)
def load_s2_grid() -> gpd.GeoDataFrame:
def load_s2_grid(web_mercator: bool = False) -> gpd.GeoDataFrame:
"""Returns a geo data frame from the S2 grid."""
# Builds the path where the geodataframe should be
gdf_path = Path.home() / ".openeo-gfmap" / "s2grid_bounds_v2.geojson"
if not web_mercator:
gdf_path = Path.home() / ".openeo-gfmap" / "s2grid_bounds_4326.geoparquet"
url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/s2grid_bounds_4326.geoparquet"
else:
gdf_path = Path.home() / ".openeo-gfmap" / "s2grid_bounds_3857.geoparquet"
url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/s2grid_bounds_3857.geoparquet"

if not gdf_path.exists():
_log.info("S2 grid not found, downloading it from artifactory.")
# Downloads the file from the artifactory URL
gdf_path.parent.mkdir(exist_ok=True)
response = requests.get(
"https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/s2grid_bounds_v2.geojson",
url,
timeout=180, # 3mins
)
if response.status_code != 200:
Expand All @@ -33,7 +38,7 @@ def load_s2_grid() -> gpd.GeoDataFrame:
)
with open(gdf_path, "wb") as f:
f.write(response.content)
return gpd.read_file(gdf_path)
return gpd.read_parquet(gdf_path)


def _resplit_group(
Expand All @@ -45,7 +50,7 @@ def _resplit_group(


def split_job_s2grid(
polygons: gpd.GeoDataFrame, max_points: int = 500
polygons: gpd.GeoDataFrame, max_points: int = 500, web_mercator: bool = False
) -> List[gpd.GeoDataFrame]:
"""Split a job into multiple jobs from the position of the polygons/points. The centroid of
the geometries to extract are used to select tile in the Sentinel-2 tile grid.
Expand All @@ -67,19 +72,37 @@ def split_job_s2grid(
if polygons.crs is None:
raise ValueError("The GeoDataFrame must contain a CRS")

<<<<<<< HEAD
polygons = polygons.to_crs(epsg=4326)
polygons["geometry"] = polygons.geometry.centroid
=======
epsg = 3857 if web_mercator else 4326

original_crs = polygons.crs

polygons = polygons.to_crs(epsg=epsg)

polygons["centroid"] = polygons.geometry.centroid
>>>>>>> 1110e4aa35cfbe72a9dbd9b56e40048ea40ca2d8

# Dataset containing all the S2 tiles, find the nearest S2 tile for each point
s2_grid = load_s2_grid()
s2_grid = load_s2_grid(web_mercator)
s2_grid["geometry"] = s2_grid.geometry.centroid

<<<<<<< HEAD
# Filter tiles on CDSE availability
s2_grid = s2_grid[s2_grid.cdse_valid]

polygons = gpd.sjoin_nearest(polygons, s2_grid[["tile", "geometry"]]).drop(
columns=["index_right"]
)
=======
polygons = gpd.sjoin_nearest(
polygons.set_geometry("centroid"), s2_grid[["tile", "geometry"]]
).drop(columns=["index_right", "centroid"])

polygons = polygons.set_geometry("geometry").to_crs(original_crs)
>>>>>>> 1110e4aa35cfbe72a9dbd9b56e40048ea40ca2d8

split_datasets = []
for _, sub_gdf in polygons.groupby("tile"):
Expand All @@ -95,10 +118,13 @@ def append_h3_index(
polygons: gpd.GeoDataFrame, grid_resolution: int = 3
) -> gpd.GeoDataFrame:
"""Append the H3 index to the polygons."""
if polygons.geometry.geom_type[0] != "Point":
geom_col = polygons.geometry.centroid
else:
geom_col = polygons.geometry

# Project to Web mercator to calculate centroids
polygons = polygons.to_crs(epsg=3857)
geom_col = polygons.geometry.centroid
# Project to lat lon to calculate the h3 index
geom_col = geom_col.to_crs(epsg=4326)

polygons["h3index"] = geom_col.apply(
lambda pt: h3.geo_to_h3(pt.y, pt.x, grid_resolution)
)
Expand Down Expand Up @@ -136,12 +162,13 @@ def split_job_hex(
if polygons.crs is None:
raise ValueError("The GeoDataFrame must contain a CRS")

# Project to lat/lon positions
polygons = polygons.to_crs(epsg=4326)
original_crs = polygons.crs

# Split the polygons into multiple jobs
polygons = append_h3_index(polygons, grid_resolution)

polygons = polygons.to_crs(original_crs)

split_datasets = []
for _, sub_gdf in polygons.groupby("h3index"):
if len(sub_gdf) > max_points:
Expand Down
98 changes: 83 additions & 15 deletions tests/test_openeo_gfmap/test_managers.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,94 @@
"""Test the job splitters and managers of OpenEO GFMAP."""

from pathlib import Path

import geopandas as gpd
from shapely.geometry import Point, Polygon

from openeo_gfmap.manager.job_splitters import split_job_hex
from openeo_gfmap.manager.job_splitters import split_job_hex, split_job_s2grid


# TODO can we instead assert on exact numbers ?
# would remove the print statement
def test_split_jobs():
dataset_path = Path(__file__).parent / "resources/wc_extraction_dataset.gpkg"
def test_split_job_s2grid():
# Create a mock GeoDataFrame with points
# The points are located in two different S2 tiles
data = {
"id": [1, 2, 3, 4, 5],
"geometry": [
Point(60.02, 4.57),
Point(59.6, 5.04),
Point(59.92, 3.37),
Point(59.07, 4.11),
Point(58.77, 4.87),
],
}
polygons = gpd.GeoDataFrame(data, crs="EPSG:4326")

# Load the dataset
dataset = gpd.read_file(dataset_path)
# Define expected number of split groups
max_points = 2

# Split the dataset
split_dataset = split_job_hex(dataset, max_points=500)
# Call the function
result = split_job_s2grid(polygons, max_points)

# Check the number of splits
assert len(split_dataset) > 1
assert (
len(result) == 3
), "The number of GeoDataFrames returned should match the number of splits needed."

for ds in split_dataset:
print(len(ds))
assert len(ds) <= 500
# Check if the geometries are preserved
for gdf in result:
assert (
"geometry" in gdf.columns
), "Each GeoDataFrame should have a geometry column."
assert gdf.crs == 4326, "The original CRS should be preserved."
assert all(
gdf.geometry.geom_type == "Point"
), "Original geometries should be preserved."


def test_split_job_hex():
    # Mock input: five points plus one small polygon, spread over three
    # different H3 hexagons of resolution 3.
    square = Polygon(
        [
            (58.78, 4.88),
            (58.78, 4.86),
            (58.76, 4.86),
            (58.76, 4.88),
            (58.78, 4.88),
        ]
    )
    geometries = [
        Point(60.02, 4.57),
        Point(58.34, 5.06),
        Point(59.92, 3.37),
        Point(58.85, 4.90),
        Point(58.77, 4.87),
        square,
    ]
    source_gdf = gpd.GeoDataFrame(
        {"id": [1, 2, 3, 4, 5, 6], "geometry": geometries}, crs="EPSG:4326"
    )

    result = split_job_hex(source_gdf, 3)

    # One hexagon holds more geometries than max_points, so it gets re-split.
    assert (
        len(result) == 4
    ), "The number of GeoDataFrames returned should match the number of splits needed."

    for idx, sub_gdf in enumerate(result):
        assert (
            "geometry" in sub_gdf.columns
        ), "Each GeoDataFrame should have a geometry column."
        assert sub_gdf.crs == 4326, "The CRS should be preserved."
        # The polygon is expected to land in the second split; every other
        # split contains only points.
        expected_type = "Polygon" if idx == 1 else "Point"
        assert all(
            sub_gdf.geometry.geom_type == expected_type
        ), "Original geometries should be preserved."

    assert (
        len(result[0]) == 3
    ), "The number of geometries in the first split should be 3."

0 comments on commit 386aebf

Please sign in to comment.