From ac76c7fc6a82b102d6482183812240ec38e10c64 Mon Sep 17 00:00:00 2001 From: Scott Staniewicz Date: Fri, 6 Dec 2024 11:41:16 -0500 Subject: [PATCH] Refactor to get filtering working with/without geopandas --- src/opera_utils/burst_frame_db.py | 58 +++++++++++++++++++------------ 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/src/opera_utils/burst_frame_db.py b/src/opera_utils/burst_frame_db.py index 61c78a1..0e4355c 100644 --- a/src/opera_utils/burst_frame_db.py +++ b/src/opera_utils/burst_frame_db.py @@ -56,49 +56,65 @@ def get_frame_to_burst_mapping( def get_frame_geojson( + frame_ids: Optional[Sequence[int | str]] = None, as_geodataframe: bool = False, - columns: Optional[Sequence[str]] = None, - frame_ids: Optional[Sequence[str]] = None, ) -> dict: """Get the GeoJSON for the frame geometries.""" - where = _form_where_in_query(frame_ids, "frame_id") if frame_ids else None - return _get_geojson( + data = _get_geojson( datasets.fetch_frame_geometries_simple(), as_geodataframe=as_geodataframe, - columns=columns, - where=where, + fids=frame_ids, index_name="frame_id", ) + if as_geodataframe or not frame_ids: + # `as_geodataframe` means it's already filtered + return data + + # Manually filter for the case of no geopandas + return { + **data, + "features": [ + f for f in data["features"] if f.get("id") in set(map(int, frame_ids)) + ], + } def get_burst_id_geojson( - as_geodataframe: bool = False, - columns: Optional[Sequence[str]] = None, burst_ids: Optional[Sequence[str]] = None, + as_geodataframe: bool = False, ) -> dict: """Get the GeoJSON for the burst_id geometries.""" - where = _form_where_in_query(burst_ids, "burst_id_jpl") if burst_ids else None - return _get_geojson( + data = _get_geojson( datasets.fetch_burst_id_geometries_simple(), as_geodataframe=as_geodataframe, - columns=columns, - where=where, + fids=burst_ids, index_name="burst_id_jpl", ) + if not burst_ids: + return data + if isinstance(burst_ids, str): + burst_ids = [burst_ids] + if
as_geodataframe: + import geopandas -def _form_where_in_query(values: Sequence[str], column_name): - # Example: - # "burst_id_jpl in ('t005_009471_iw2','t007_013706_iw2','t008_015794_iw1')" - where_in_str = ",".join(f"'{b}'" for b in values) - return f"{column_name} IN ({where_in_str})" + assert isinstance(data, geopandas.GeoDataFrame) + return data[data.burst_id_jpl.isin(tuple(burst_ids))] + # Manually filter for the case of no geopandas + return { + **data, + "features": [ + f + for f in data["features"] + if f["properties"]["burst_id_jpl"] in set(burst_ids) + ], + } def _get_geojson( f, as_geodataframe: bool = False, - columns: Optional[Sequence[str]] = None, - where: Optional[str] = None, + fids: Sequence[str | int] | None = None, index_name: Optional[str] = None, ) -> dict: # https://gdal.org/user/ogr_sql_dialect.html#where @@ -106,9 +122,7 @@ def _get_geojson( if as_geodataframe: from pyogrio import read_dataframe - # import geopandas as gpd - # return gpd.read_file(f) - gdf = read_dataframe(f, columns=columns, where=where, fid_as_index=True) + gdf = read_dataframe(f, layer=None, fid_as_index=True, fids=fids) if index_name: gdf.index.name = index_name return gdf