Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update stac splitter #152

Merged
merged 4 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/openeo_gfmap/utils/catalogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import requests
from pyproj.crs import CRS
from rasterio.warp import transform_bounds
from shapely import unary_union
from shapely.geometry import box, shape
from shapely.ops import unary_union

from openeo_gfmap import (
Backend,
Expand Down
122 changes: 54 additions & 68 deletions src/openeo_gfmap/utils/split_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from pathlib import Path
from typing import Union
from typing import Iterator, Union

import pystac

Expand All @@ -30,92 +30,59 @@ def _extract_epsg_from_stac_item(stac_item: pystac.Item) -> int:
raise KeyError("The 'proj:epsg' property is missing from the STAC item.")


def _create_item_by_epsg_dict(collection: pystac.Collection) -> dict:
def _get_items_by_epsg(
collection: pystac.Collection,
) -> Iterator[tuple[int, pystac.Item]]:
"""
Create a dictionary that groups items by their EPSG code.
Generator function that yields items grouped by their EPSG code.

Parameters:
collection (pystac.Collection): The STAC collection.

Returns:
dict: A dictionary that maps EPSG codes to lists of items.
Yields:
tuple[int, pystac.Item]: EPSG code and corresponding STAC item.
"""
# Dictionary to store items grouped by their EPSG codes
items_by_epsg = {}

# Iterate through items and group them
for item in collection.get_items():
for item in collection.get_all_items():
epsg = _extract_epsg_from_stac_item(item)
if epsg not in items_by_epsg:
items_by_epsg[epsg] = []
items_by_epsg[epsg].append(item)

return items_by_epsg
yield epsg, item


def _create_new_epsg_collection(
epsg: int, items: list, collection: pystac.Collection
def _create_collection_skeleton(
collection: pystac.Collection, epsg: int
) -> pystac.Collection:
"""
Create a new STAC collection with a given EPSG code.
Create a skeleton for a new STAC collection with a given EPSG code.

Parameters:
epsg (int): The EPSG code.
items (list): The list of items.
collection (pystac.Collection): The original STAC collection.
epsg (int): The EPSG code.

Returns:
pystac.Collection: The new STAC collection.
pystac.Collection: The skeleton of the new STAC collection.
"""
new_collection = collection.clone()
new_collection.id = f"{collection.id}_{epsg}"
new_collection.description = (
f"{collection.description} Containing only items with EPSG code {epsg}"
new_collection = pystac.Collection(
id=f"{collection.id}_{epsg}",
description=f"{collection.description} Containing only items with EPSG code {epsg}",
extent=collection.extent.clone(),
summaries=collection.summaries,
license=collection.license,
stac_extensions=collection.stac_extensions,
)
new_collection.clear_items()
for item in items:
new_collection.add_item(item)

new_collection.update_extent_from_items()

if "item_assets" in collection.extra_fields:
item_assets_extension = pystac.extensions.item_assets.ItemAssetsExtension.ext(
collection
)

new_item_assets_extension = (
pystac.extensions.item_assets.ItemAssetsExtension.ext(
new_collection, add_if_missing=True
)
)

new_item_assets_extension.item_assets = item_assets_extension.item_assets
return new_collection


def _create_collection_by_epsg_dict(collection: pystac.Collection) -> dict:
"""
Create a dictionary that groups collections by their EPSG code.

Parameters:
collection (pystac.Collection): The STAC collection.

Returns:
dict: A dictionary that maps EPSG codes to STAC collections.
"""
items_by_epsg = _create_item_by_epsg_dict(collection)
collections_by_epsg = {}
for epsg, items in items_by_epsg.items():
new_collection = _create_new_epsg_collection(epsg, items, collection)
collections_by_epsg[epsg] = new_collection

return collections_by_epsg


def _write_collection_dict(collection_dict: dict, output_dir: Union[str, Path]):
"""
Write the collection dictionary to disk.

Parameters:
collection_dict (dict): The dictionary that maps EPSG codes to STAC collections.
output_dir (str): The output directory.
"""
output_dir = Path(output_dir)
os.makedirs(output_dir, exist_ok=True)

for epsg, collection in collection_dict.items():
collection.normalize_hrefs(os.path.join(output_dir, f"collection-{epsg}"))
collection.save()


def split_collection_by_epsg(path: Union[str, Path], output_dir: Union[str, Path]):
"""
Split a STAC collection into multiple STAC collections based on EPSG code.
Expand All @@ -124,10 +91,29 @@ def split_collection_by_epsg(path: Union[str, Path], output_dir: Union[str, Path
path (str): The path to the STAC collection.
output_dir (str): The output directory.
"""

path = Path(path)
output_dir = Path(output_dir)
os.makedirs(output_dir, exist_ok=True)

try:
collection = pystac.read_file(path)
except pystac.STACError:
print("Please provide a path to a valid STAC collection.")
collection_dict = _create_collection_by_epsg_dict(collection)
_write_collection_dict(collection_dict, output_dir)
return

collections_by_epsg = {}

for epsg, item in _get_items_by_epsg(collection):
if epsg not in collections_by_epsg:
collections_by_epsg[epsg] = _create_collection_skeleton(collection, epsg)

# Add item to the corresponding collection
collections_by_epsg[epsg].add_item(item)

# Write each collection to disk
for epsg, new_collection in collections_by_epsg.items():
new_collection.update_extent_from_items() # Update extent based on added items
collection_path = output_dir / f"collection-{epsg}"
new_collection.normalize_hrefs(str(collection_path))
new_collection.save()
Loading