Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor filtering to work with/without geopandas #81

Merged
merged 2 commits into from
Jan 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions src/opera_utils/burst_frame_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,59 +56,73 @@ def get_frame_to_burst_mapping(


def get_frame_geojson(
frame_ids: Optional[Sequence[int | str]] = None,
as_geodataframe: bool = False,
columns: Optional[Sequence[str]] = None,
frame_ids: Optional[Sequence[str]] = None,
) -> dict:
"""Get the GeoJSON for the frame geometries."""
where = _form_where_in_query(frame_ids, "frame_id") if frame_ids else None
return _get_geojson(
data = _get_geojson(
datasets.fetch_frame_geometries_simple(),
as_geodataframe=as_geodataframe,
columns=columns,
where=where,
fids=frame_ids,
index_name="frame_id",
)
if as_geodataframe or not frame_ids:
# With `as_geodataframe`, the read in `_get_geojson` already filtered by `fids`
return data

# Manually filter for the case of no geopandas
return {
**data,
"features": [
f for f in data["features"] if f.get("id") in set(map(int, frame_ids))
],
}


def get_burst_id_geojson(
as_geodataframe: bool = False,
columns: Optional[Sequence[str]] = None,
burst_ids: Optional[Sequence[str]] = None,
as_geodataframe: bool = False,
) -> dict:
"""Get the GeoJSON for the burst_id geometries."""
where = _form_where_in_query(burst_ids, "burst_id_jpl") if burst_ids else None
return _get_geojson(
data = _get_geojson(
datasets.fetch_burst_id_geometries_simple(),
as_geodataframe=as_geodataframe,
columns=columns,
where=where,
fids=burst_ids,
index_name="burst_id_jpl",
)
if not burst_ids:
return data

if isinstance(burst_ids, str):
burst_ids = [burst_ids]
if as_geodataframe:
import geopandas

def _form_where_in_query(values: Sequence[str], column_name):
# Example:
# "burst_id_jpl in ('t005_009471_iw2','t007_013706_iw2','t008_015794_iw1')"
where_in_str = ",".join(f"'{b}'" for b in values)
return f"{column_name} IN ({where_in_str})"
assert isinstance(data, geopandas.GeoDataFrame)
return data[data.burst_id_jpl.isin(tuple(burst_ids))]
# Manually filter for the case of no geopandas
return {
**data,
"features": [
f
for f in data["features"]
if f["properties"]["burst_id_jpl"] in set(burst_ids)
],
}


def _get_geojson(
f,
as_geodataframe: bool = False,
columns: Optional[Sequence[str]] = None,
where: Optional[str] = None,
fids: Sequence[str | int] | None = None,
index_name: Optional[str] = None,
) -> dict:
# https://gdal.org/user/ogr_sql_dialect.html#where
# https://pyogrio.readthedocs.io/en/latest/introduction.html#filter-records-by-attribute-value
if as_geodataframe:
from pyogrio import read_dataframe

# import geopandas as gpd
# return gpd.read_file(f)
gdf = read_dataframe(f, columns=columns, where=where, fid_as_index=True)
gdf = read_dataframe(f, layer=None, fid_as_index=True, fids=fids)
if index_name:
gdf.index.name = index_name
return gdf
Expand Down
Loading