Skip to content

Commit

Permalink
Merge pull request #152 from Open-EO/update-stac-splitter
Browse files Browse the repository at this point in the history
Update stac splitter
  • Loading branch information
VincentVerelst authored Aug 28, 2024
2 parents 8d9a808 + cebdbe9 commit 27a239e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 69 deletions.
2 changes: 1 addition & 1 deletion src/openeo_gfmap/utils/catalogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import requests
from pyproj.crs import CRS
from rasterio.warp import transform_bounds
from shapely import unary_union
from shapely.geometry import box, shape
from shapely.ops import unary_union

from openeo_gfmap import (
Backend,
Expand Down
122 changes: 54 additions & 68 deletions src/openeo_gfmap/utils/split_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from pathlib import Path
from typing import Union
from typing import Iterator, Union

import pystac

Expand All @@ -30,92 +30,59 @@ def _extract_epsg_from_stac_item(stac_item: pystac.Item) -> int:
raise KeyError("The 'proj:epsg' property is missing from the STAC item.")


def _create_item_by_epsg_dict(collection: pystac.Collection) -> dict:
def _get_items_by_epsg(
collection: pystac.Collection,
) -> Iterator[tuple[int, pystac.Item]]:
"""
Create a dictionary that groups items by their EPSG code.
Generator function that yields items grouped by their EPSG code.
Parameters:
collection (pystac.Collection): The STAC collection.
Returns:
dict: A dictionary that maps EPSG codes to lists of items.
Yields:
tuple[int, pystac.Item]: EPSG code and corresponding STAC item.
"""
# Dictionary to store items grouped by their EPSG codes
items_by_epsg = {}

# Iterate through items and group them
for item in collection.get_items():
for item in collection.get_all_items():
epsg = _extract_epsg_from_stac_item(item)
if epsg not in items_by_epsg:
items_by_epsg[epsg] = []
items_by_epsg[epsg].append(item)

return items_by_epsg
yield epsg, item


def _create_new_epsg_collection(
epsg: int, items: list, collection: pystac.Collection
def _create_collection_skeleton(
collection: pystac.Collection, epsg: int
) -> pystac.Collection:
"""
Create a new STAC collection with a given EPSG code.
Create a skeleton for a new STAC collection with a given EPSG code.
Parameters:
epsg (int): The EPSG code.
items (list): The list of items.
collection (pystac.Collection): The original STAC collection.
epsg (int): The EPSG code.
Returns:
pystac.Collection: The new STAC collection.
pystac.Collection: The skeleton of the new STAC collection.
"""
new_collection = collection.clone()
new_collection.id = f"{collection.id}_{epsg}"
new_collection.description = (
f"{collection.description} Containing only items with EPSG code {epsg}"
new_collection = pystac.Collection(
id=f"{collection.id}_{epsg}",
description=f"{collection.description} Containing only items with EPSG code {epsg}",
extent=collection.extent.clone(),
summaries=collection.summaries,
license=collection.license,
stac_extensions=collection.stac_extensions,
)
new_collection.clear_items()
for item in items:
new_collection.add_item(item)

new_collection.update_extent_from_items()

if "item_assets" in collection.extra_fields:
item_assets_extension = pystac.extensions.item_assets.ItemAssetsExtension.ext(
collection
)

new_item_assets_extension = (
pystac.extensions.item_assets.ItemAssetsExtension.ext(
new_collection, add_if_missing=True
)
)

new_item_assets_extension.item_assets = item_assets_extension.item_assets
return new_collection


def _create_collection_by_epsg_dict(collection: pystac.Collection) -> dict:
"""
Create a dictionary that groups collections by their EPSG code.
Parameters:
collection (pystac.Collection): The STAC collection.
Returns:
dict: A dictionary that maps EPSG codes to STAC collections.
"""
items_by_epsg = _create_item_by_epsg_dict(collection)
collections_by_epsg = {}
for epsg, items in items_by_epsg.items():
new_collection = _create_new_epsg_collection(epsg, items, collection)
collections_by_epsg[epsg] = new_collection

return collections_by_epsg


def _write_collection_dict(collection_dict: dict, output_dir: Union[str, Path]):
"""
Write the collection dictionary to disk.
Parameters:
collection_dict (dict): The dictionary that maps EPSG codes to STAC collections.
output_dir (str): The output directory.
"""
output_dir = Path(output_dir)
os.makedirs(output_dir, exist_ok=True)

for epsg, collection in collection_dict.items():
collection.normalize_hrefs(os.path.join(output_dir, f"collection-{epsg}"))
collection.save()


def split_collection_by_epsg(path: Union[str, Path], output_dir: Union[str, Path]):
"""
Split a STAC collection into multiple STAC collections based on EPSG code.
Expand All @@ -124,10 +91,29 @@ def split_collection_by_epsg(path: Union[str, Path], output_dir: Union[str, Path
path (str): The path to the STAC collection.
output_dir (str): The output directory.
"""

path = Path(path)
output_dir = Path(output_dir)
os.makedirs(output_dir, exist_ok=True)

try:
collection = pystac.read_file(path)
except pystac.STACError:
print("Please provide a path to a valid STAC collection.")
collection_dict = _create_collection_by_epsg_dict(collection)
_write_collection_dict(collection_dict, output_dir)
return

collections_by_epsg = {}

for epsg, item in _get_items_by_epsg(collection):
if epsg not in collections_by_epsg:
collections_by_epsg[epsg] = _create_collection_skeleton(collection, epsg)

# Add item to the corresponding collection
collections_by_epsg[epsg].add_item(item)

# Write each collection to disk
for epsg, new_collection in collections_by_epsg.items():
new_collection.update_extent_from_items() # Update extent based on added items
collection_path = output_dir / f"collection-{epsg}"
new_collection.normalize_hrefs(str(collection_path))
new_collection.save()

0 comments on commit 27a239e

Please sign in to comment.