Skip to content

Commit

Permalink
Fix bug in _map_dims_to_ugrid, use Polars to improve SCRIP reader p…
Browse files Browse the repository at this point in the history
…erformance (#1109)

* remove check for n_edge which was constructing connectivity

* use polars for unique calls in SCRIP reader
  • Loading branch information
philipc2 authored Dec 12, 2024
1 parent c30f0b0 commit 4d2cc3b
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 25 deletions.
1 change: 1 addition & 0 deletions ci/asv.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- pandas
- pathlib
- pre_commit
- polars
- pyarrow
- pytest
- pytest-cov
Expand Down
1 change: 1 addition & 0 deletions ci/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies:
- pandas
- geocat-datafiles
- spatialpandas
- polars
- geopandas
- pip:
- antimeridian
Expand Down
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
- pandas
- pathlib
- pre_commit
- polars
- pyarrow
- pytest
- pip
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies = [
"geopandas",
"xarray",
"hvplot",
"polars",
]
# minimal dependencies end

Expand Down
13 changes: 9 additions & 4 deletions uxarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,21 @@ def _map_dims_to_ugrid(
# drop dimensions not present in the original dataset
_source_dims_dict.pop(key)

# only check edge dimension if it is present (to avoid overhead of computing connectivity)
if "n_edge" in grid._ds.dims:
n_edge = grid._ds.sizes["n_edge"]
else:
n_edge = None

for dim in set(ds.dims) ^ _source_dims_dict.keys():
# obtain dimensions that were not parsed source_dims_dict and attempt to match to a grid element
if ds.sizes[dim] == grid.n_face:
_source_dims_dict[dim] = "n_face"
elif ds.sizes[dim] == grid.n_node:
_source_dims_dict[dim] = "n_node"
elif ds.sizes[dim] == grid.n_edge:
_source_dims_dict[dim] = "n_edge"

# Possible Issue: https://github.com/UXARRAY/uxarray/issues/610
elif n_edge is not None:
if ds.sizes[dim] == n_edge:
_source_dims_dict[dim] = "n_edge"

# rename dimensions to follow the UGRID conventions
ds = ds.swap_dims(_source_dims_dict)
Expand Down
46 changes: 25 additions & 21 deletions uxarray/io/_scrip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import xarray as xr
import numpy as np

import polars as pl

from uxarray.grid.connectivity import _replace_fill_values
from uxarray.constants import INT_DTYPE, INT_FILL_VALUE

Expand All @@ -11,43 +13,45 @@ def _to_ugrid(in_ds, out_ds):
"""If input dataset (``in_ds``) file is an unstructured SCRIP file,
function will reassign SCRIP variables to UGRID conventions in output file
(``out_ds``).
Parameters
----------
in_ds : xarray.Dataset
Original scrip dataset of interest being used
out_ds : xarray.Variable
file to be returned by ``_populate_scrip_data``, used as an empty placeholder file
to store reassigned SCRIP variables in UGRID conventions
"""

source_dims_dict = {}

if in_ds["grid_area"].all():
# Create node_lon & node_lat variables from grid_corner_lat/lon
# Turn latitude scrip array into 1D instead of 2D
# Turn latitude and longitude scrip arrays into 1D
corner_lat = in_ds["grid_corner_lat"].values.ravel()

# Repeat above steps with longitude data instead
corner_lon = in_ds["grid_corner_lon"].values.ravel()

# Combine flat lat and lon arrays
corner_lon_lat = np.vstack((corner_lon, corner_lat)).T
# Use Polars to find unique coordinate pairs
df = pl.DataFrame({"lon": corner_lon, "lat": corner_lat}).with_row_count(
"original_index"
)

# Get unique rows (first occurrence). This preserves the order in which they appear.
unique_df = df.unique(subset=["lon", "lat"], keep="first")

# unq_ind: The indices of the unique rows in the original array
unq_ind = unique_df["original_index"].to_numpy().astype(INT_DTYPE)

# To get the inverse index (unq_inv): map each original row back to its unique row index.
# Add a unique_id to the unique_df which will serve as the "inverse" mapping.
unique_df = unique_df.with_row_count("unique_id")

# Run numpy unique to determine which rows/values are actually unique
_, unq_ind, unq_inv = np.unique(
corner_lon_lat, return_index=True, return_inverse=True, axis=0
# Join original df with unique_df to find out which unique_id corresponds to each row
df_joined = df.join(
unique_df.drop("original_index"), on=["lon", "lat"], how="left"
)
unq_inv = df_joined["unique_id"].to_numpy().astype(INT_DTYPE)

# Now, calculate unique lon and lat values to account for 'node_lon' and 'node_lat'
unq_lon = corner_lon_lat[unq_ind, :][:, 0]
unq_lat = corner_lon_lat[unq_ind, :][:, 1]
# Extract unique lon and lat values using unq_ind
unq_lon = corner_lon[unq_ind]
unq_lat = corner_lat[unq_ind]

# Reshape face nodes array into original shape for use in 'face_node_connectivity'
unq_inv = np.reshape(unq_inv, (len(in_ds.grid_size), len(in_ds.grid_corners)))

# Create node_lon & node_lat from unsorted, unique grid_corner_lat/lon
# Create node_lon & node_lat
out_ds[ugrid.NODE_COORDINATES[0]] = xr.DataArray(
unq_lon, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_LON_ATTRS
)
Expand Down

0 comments on commit 4d2cc3b

Please sign in to comment.