From 0e45edaf60843b0a6730dd2660d2c0a0574ef4ea Mon Sep 17 00:00:00 2001 From: Aaron Zedwick <95507181+aaronzedwick@users.noreply.github.com> Date: Thu, 9 Jan 2025 15:09:53 -0600 Subject: [PATCH] Add Inverse Face Indices to Subsetted Grids (#1122) * Initial Work * Addressed review comments * Updated doc string * Added inverse_indices support for data arrays and cross sections * Added ability to choose which inverse_indices to store, stores as ds * updated grid, doc strings, api, test cases * fixed leftover variables * Fixed failing tests * Update grid.py * New naming convention, fixed spelling errors * Updated subsetting notebook * Fixed precommit * Added is_subset property * Update uxarray/grid/grid.py Co-authored-by: Philip Chmielowiec <67855069+philipc2@users.noreply.github.com> * Added doc string, updated test case * Update grid.py --------- Co-authored-by: Philip Chmielowiec <67855069+philipc2@users.noreply.github.com> --- docs/api.rst | 1 + docs/user-guide/subset.ipynb | 91 +++++++++++++++++++- test/test_subset.py | 34 +++++++- uxarray/core/dataarray.py | 14 ++- uxarray/cross_sections/dataarray_accessor.py | 20 +++-- uxarray/cross_sections/grid_accessor.py | 18 +++- uxarray/grid/grid.py | 58 ++++++++++++- uxarray/grid/slice.py | 54 ++++++++++-- uxarray/subset/dataarray_accessor.py | 20 ++++- uxarray/subset/grid_accessor.py | 32 ++++--- 10 files changed, 304 insertions(+), 38 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index f9146e69f..10307a16f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -54,6 +54,7 @@ Indexing :toctree: generated/ Grid.isel + Grid.inverse_indices Dimensions ~~~~~~~~~~ diff --git a/docs/user-guide/subset.ipynb b/docs/user-guide/subset.ipynb index 07df1e82b..c1e7a97df 100644 --- a/docs/user-guide/subset.ipynb +++ b/docs/user-guide/subset.ipynb @@ -553,6 +553,95 @@ "print(\"Bounding Box Mean: \", bbox_subset_nodes.values.mean())\n", "print(\"Bounding Circle Mean: \", bcircle_subset.values.mean())" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Retrieving Orignal Grid Indices" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sometimes having the original grids' indices is useful. These indices can be stored within the subset with the `inverse_indices` variable. This can be used to store the indices of the original face centers, edge centers, and node indices. This variable can be used within the subset as follows:\n", + "\n", + "* Passing in `True`, which will store the face center indices\n", + "* Passing in a list of which indices to store, along with `True`, to indicate what kind of original grid indices to store.\n", + " * Options for which indices to select include: `face`, `node`, and `edge`\n", + "\n", + "This currently only works when the element is `face centers`. Elements `nodes` and `edge centers` will be supported in the future." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "subset_indices = uxds[\"relhum_200hPa\"][0].subset.bounding_circle(\n", + " center_coord,\n", + " r,\n", + " element=\"face centers\",\n", + " inverse_indices=([\"face\", \"node\", \"edge\"], True),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These indices can be retrieve through the grid:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "subset_indices.uxgrid.inverse_indices" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Determining if a Grid is a Subset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To check if a Grid (or dataset using `.uxgrid`) is a subset, we can use `Grid.is_subset`, which will return either `True` or `False`, depending on whether the `Grid` is a subset. Since `subset_indices` is a subset, using this feature we will return `True`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "subset_indices.uxgrid.is_subset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The file we have been using to create these subsets, `uxds`, is not a subset, so using the same call we will return `False:`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "uxds.uxgrid.is_subset" + ] } ], "metadata": { @@ -571,7 +660,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/test/test_subset.py b/test/test_subset.py index a1bfaed7f..f13ae4919 100644 --- a/test/test_subset.py +++ b/test/test_subset.py @@ -104,8 +104,6 @@ def test_grid_bounding_box_subset(): bbox_antimeridian[0], bbox_antimeridian[1], element=element) - - def test_uxda_isel(): uxds = ux.open_dataset(GRID_PATHS[0], DATA_PATHS[0]) @@ -113,6 +111,7 @@ def test_uxda_isel(): assert len(sub) == 3 + def test_uxda_isel_with_coords(): uxds = ux.open_dataset(GRID_PATHS[0], DATA_PATHS[0]) uxds = uxds.assign_coords({"lon_face": uxds.uxgrid.face_lon}) @@ -120,3 +119,34 @@ def test_uxda_isel_with_coords(): assert "lon_face" in sub.coords assert len(sub.coords['lon_face']) == 3 + + +def test_inverse_indices(): + grid = ux.open_grid(GRID_PATHS[0]) + + # Test nearest neighbor subsetting + coord = [0, 0] + subset = grid.subset.nearest_neighbor(coord, k=1, element="face centers", inverse_indices=True) + + assert subset.inverse_indices is not None + + # Test bounding box subsetting + box = [(-10, 10), (-10, 10)] + subset = grid.subset.bounding_box(box[0], box[1], element="face centers", inverse_indices=True) + + assert subset.inverse_indices is not None + + # Test bounding circle subsetting + center_coord = [0, 0] + subset = grid.subset.bounding_circle(center_coord, r=10, element="face centers", inverse_indices=True) + + assert subset.inverse_indices is not None + + # Ensure code raises exceptions when the element is edges or nodes or inverse_indices is incorrect + assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="edge centers", inverse_indices=True) + assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="nodes", inverse_indices=True) + assert pytest.raises(ValueError, grid.subset.bounding_circle, center_coord, r=10, element="face center", inverse_indices=(['not right'], True)) + + # Test isel directly + subset = grid.isel(n_face=[1], inverse_indices=True) + assert subset.inverse_indices.face.values == 1 diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 216b2309c..bfec83280 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -1035,7 +1035,7 @@ def _edge_centered(self) -> bool: "n_edge" dimension)""" return "n_edge" in self.dims - def isel(self, ignore_grid=False, *args, **kwargs): + def isel(self, ignore_grid=False, inverse_indices=False, *args, **kwargs): """Grid-informed implementation of xarray's ``isel`` method, which enables indexing across grid dimensions. @@ -1069,11 +1069,17 @@ def isel(self, ignore_grid=False, *args, **kwargs): raise ValueError("Only one grid dimension can be sliced at a time") if "n_node" in kwargs: - sliced_grid = self.uxgrid.isel(n_node=kwargs["n_node"]) + sliced_grid = self.uxgrid.isel( + n_node=kwargs["n_node"], inverse_indices=inverse_indices + ) elif "n_edge" in kwargs: - sliced_grid = self.uxgrid.isel(n_edge=kwargs["n_edge"]) + sliced_grid = self.uxgrid.isel( + n_edge=kwargs["n_edge"], inverse_indices=inverse_indices + ) else: - sliced_grid = self.uxgrid.isel(n_face=kwargs["n_face"]) + sliced_grid = self.uxgrid.isel( + n_face=kwargs["n_face"], inverse_indices=inverse_indices + ) return self._slice_from_grid(sliced_grid) diff --git a/uxarray/cross_sections/dataarray_accessor.py b/uxarray/cross_sections/dataarray_accessor.py index 6f82a8f2e..52599c7a4 100644 --- a/uxarray/cross_sections/dataarray_accessor.py +++ b/uxarray/cross_sections/dataarray_accessor.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union, List, Set if TYPE_CHECKING: pass @@ -22,7 +22,9 @@ def __repr__(self): return prefix + methods_heading - def constant_latitude(self, lat: float): + def constant_latitude( + self, lat: float, inverse_indices: Union[List[str], Set[str], bool] = False + ): """Extracts a cross-section of the data array by selecting all faces that intersect with a specified line of constant latitude. @@ -31,6 +33,9 @@ def constant_latitude(self, lat: float): lat : float The latitude at which to extract the cross-section, in degrees. Must be between -90.0 and 90.0 + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) Returns ------- @@ -60,9 +65,11 @@ def constant_latitude(self, lat: float): faces = self.uxda.uxgrid.get_faces_at_constant_latitude(lat) - return self.uxda.isel(n_face=faces) + return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices) - def constant_longitude(self, lon: float): + def constant_longitude( + self, lon: float, inverse_indices: Union[List[str], Set[str], bool] = False + ): """Extracts a cross-section of the data array by selecting all faces that intersect with a specified line of constant longitude. @@ -71,6 +78,9 @@ def constant_longitude(self, lon: float): lon : float The latitude at which to extract the cross-section, in degrees. Must be between -180.0 and 180.0 + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) Returns ------- @@ -102,7 +112,7 @@ def constant_longitude(self, lon: float): lon, ) - return self.uxda.isel(n_face=faces) + return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices) def gca(self, *args, **kwargs): raise NotImplementedError diff --git a/uxarray/cross_sections/grid_accessor.py b/uxarray/cross_sections/grid_accessor.py index ee30bd913..76485fbda 100644 --- a/uxarray/cross_sections/grid_accessor.py +++ b/uxarray/cross_sections/grid_accessor.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union, List, Set if TYPE_CHECKING: from uxarray.grid import Grid @@ -25,6 +25,7 @@ def constant_latitude( self, lat: float, return_face_indices: bool = False, + inverse_indices: Union[List[str], Set[str], bool] = False, ): """Extracts a cross-section of the grid by selecting all faces that intersect with a specified line of constant latitude. @@ -36,6 +37,9 @@ def constant_latitude( Must be between -90.0 and 90.0 return_face_indices : bool, optional If True, also returns the indices of the faces that intersect with the line of constant latitude. + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) Returns ------- @@ -66,7 +70,9 @@ def constant_latitude( if len(faces) == 0: raise ValueError(f"No intersections found at lat={lat}.") - grid_at_constant_lat = self.uxgrid.isel(n_face=faces) + grid_at_constant_lat = self.uxgrid.isel( + n_face=faces, inverse_indices=inverse_indices + ) if return_face_indices: return grid_at_constant_lat, faces @@ -77,6 +83,7 @@ def constant_longitude( self, lon: float, return_face_indices: bool = False, + inverse_indices: Union[List[str], Set[str], bool] = False, ): """Extracts a cross-section of the grid by selecting all faces that intersect with a specified line of constant longitude. @@ -88,6 +95,9 @@ def constant_longitude( Must be between -90.0 and 90.0 return_face_indices : bool, optional If True, also returns the indices of the faces that intersect with the line of constant longitude. + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) Returns ------- @@ -117,7 +127,9 @@ def constant_longitude( if len(faces) == 0: raise ValueError(f"No intersections found at lon={lon}") - grid_at_constant_lon = self.uxgrid.isel(n_face=faces) + grid_at_constant_lon = self.uxgrid.isel( + n_face=faces, inverse_indices=inverse_indices + ) if return_face_indices: return grid_at_constant_lon, faces diff --git a/uxarray/grid/grid.py b/uxarray/grid/grid.py index e1bac7d24..c975ebff8 100644 --- a/uxarray/grid/grid.py +++ b/uxarray/grid/grid.py @@ -9,6 +9,8 @@ from typing import ( Optional, Union, + List, + Set, ) # reader and writer imports @@ -137,6 +139,12 @@ class Grid: source_dims_dict : dict, default={} Mapping of dimensions from the source dataset to their UGRID equivalent (i.e. {nCell : n_face}) + is_subset : bool, default=False + Flag to mark if the grid is a subset or not + + inverse_indices: xr.Dataset, default=None + A dataset of indices that correspond to the original grid, if the grid being constructed is a subset + Examples ---------- @@ -160,6 +168,8 @@ def __init__( grid_ds: xr.Dataset, source_grid_spec: Optional[str] = None, source_dims_dict: Optional[dict] = {}, + is_subset: bool = False, + inverse_indices: Optional[xr.Dataset] = None, ): # check if inputted dataset is a minimum representable 2D UGRID unstructured grid if not _validate_minimum_ugrid(grid_ds): @@ -191,6 +201,10 @@ def __init__( # initialize attributes self._antimeridian_face_indices = None self._ds.assign_attrs({"source_grid_spec": self.source_grid_spec}) + self._is_subset = is_subset + + if inverse_indices is not None: + self._inverse_indices = inverse_indices # cached parameters for GeoDataFrame conversions self._gdf_cached_parameters = { @@ -252,6 +266,8 @@ def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs): containing ASCII files represents a FESOM2 grid. use_dual : bool, default=False When reading in MPAS formatted datasets, indicates whether to use the Dual Mesh + is_subset : bool, default=False + Bool flag to indicate whether a grid is a subset """ if isinstance(dataset, xr.Dataset): @@ -301,7 +317,13 @@ def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs): except TypeError: raise ValueError("Unsupported Grid Format") - return cls(grid_ds, source_grid_spec, source_dims_dict) + return cls( + grid_ds, + source_grid_spec, + source_dims_dict, + is_subset=kwargs.get("is_subset", False), + inverse_indices=kwargs.get("inverse_indices"), + ) @classmethod def from_file( @@ -1506,6 +1528,21 @@ def global_sphere_coverage(self): (i.e. contains no holes)""" return not self.partial_sphere_coverage + @property + def inverse_indices(self) -> xr.Dataset: + """Indices for a subset that map each face in the subset back to the original grid""" + if self.is_subset: + return self._inverse_indices + else: + raise Exception( + "Grid is not a subset, therefore no inverse face indices exist" + ) + + @property + def is_subset(self): + """Returns `True` if the Grid is a subset, 'False' otherwise.""" + return self._is_subset + def chunk(self, n_node="auto", n_edge="auto", n_face="auto"): """Converts all arrays to dask arrays with given chunks across grid dimensions in-place. @@ -2201,7 +2238,9 @@ def get_dual(self): return dual - def isel(self, **dim_kwargs): + def isel( + self, inverse_indices: Union[List[str], Set[str], bool] = False, **dim_kwargs + ): """Indexes an unstructured grid along a given dimension (``n_node``, ``n_edge``, or ``n_face``) and returns a new grid. @@ -2211,6 +2250,9 @@ def isel(self, **dim_kwargs): exclusive and clipped indexing is in the works. Parameters + inverse_indices : Union[List[str], Set[str], bool], default=False + Indicates whether to store the original grids indices. Passing `True` stores the original face indices, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) **dims_kwargs: kwargs Dimension to index, one of ['n_node', 'n_edge', 'n_face'] @@ -2226,13 +2268,23 @@ def isel(self, **dim_kwargs): raise ValueError("Indexing must be along a single dimension.") if "n_node" in dim_kwargs: + if inverse_indices: + raise Exception( + "Inverse indices are not yet supported for node selection, please use face centers" + ) return _slice_node_indices(self, dim_kwargs["n_node"]) elif "n_edge" in dim_kwargs: + if inverse_indices: + raise Exception( + "Inverse indices are not yet supported for edge selection, please use face centers" + ) return _slice_edge_indices(self, dim_kwargs["n_edge"]) elif "n_face" in dim_kwargs: - return _slice_face_indices(self, dim_kwargs["n_face"]) + return _slice_face_indices( + self, dim_kwargs["n_face"], inverse_indices=inverse_indices + ) else: raise ValueError( diff --git a/uxarray/grid/slice.py b/uxarray/grid/slice.py index bc660332e..94e8e0eb8 100644 --- a/uxarray/grid/slice.py +++ b/uxarray/grid/slice.py @@ -4,13 +4,17 @@ import xarray as xr from uxarray.constants import INT_FILL_VALUE, INT_DTYPE -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union, List, Set if TYPE_CHECKING: pass -def _slice_node_indices(grid, indices, inclusive=True): +def _slice_node_indices( + grid, + indices, + inclusive=True, +): """Slices (indexes) an unstructured grid given a list/array of node indices, returning a new Grid composed of elements that contain the nodes specified in the indices. @@ -36,7 +40,11 @@ def _slice_node_indices(grid, indices, inclusive=True): return _slice_face_indices(grid, face_indices) -def _slice_edge_indices(grid, indices, inclusive=True): +def _slice_edge_indices( + grid, + indices, + inclusive=True, +): """Slices (indexes) an unstructured grid given a list/array of edge indices, returning a new Grid composed of elements that contain the edges specified in the indices. @@ -62,7 +70,12 @@ def _slice_edge_indices(grid, indices, inclusive=True): return _slice_face_indices(grid, face_indices) -def _slice_face_indices(grid, indices, inclusive=True): +def _slice_face_indices( + grid, + indices, + inclusive=True, + inverse_indices: Union[List[str], Set[str], bool] = False, +): """Slices (indexes) an unstructured grid given a list/array of face indices, returning a new Grid composed of elements that contain the faces specified in the indices. @@ -76,8 +89,10 @@ def _slice_face_indices(grid, indices, inclusive=True): inclusive: bool Whether to perform inclusive (i.e. elements must contain at least one desired feature from a slice) as opposed to exclusive (i.e elements be made up all desired features from a slice) + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) """ - if inclusive is False: raise ValueError("Exclusive slicing is not yet supported.") @@ -132,4 +147,31 @@ def _slice_face_indices(grid, indices, inclusive=True): # drop any conn that would require re-computation ds = ds.drop_vars(conn_name) - return Grid.from_dataset(ds, source_grid_spec=grid.source_grid_spec) + if inverse_indices: + inverse_indices_ds = xr.Dataset() + + index_types = { + "face": face_indices, + "edge": edge_indices, + "node": node_indices, + } + if isinstance(inverse_indices, bool): + inverse_indices_ds["face"] = face_indices + else: + for index_type in inverse_indices[0]: + if index_type in index_types: + inverse_indices_ds[index_type] = index_types[index_type] + else: + raise ValueError( + "Incorrect type of index for `inverse_indices`. Try passing one of the following " + "instead: 'face', 'edge', 'node'" + ) + + return Grid.from_dataset( + ds, + source_grid_spec=grid.source_grid_spec, + is_subset=True, + inverse_indices=inverse_indices_ds, + ) + + return Grid.from_dataset(ds, source_grid_spec=grid.source_grid_spec, is_subset=True) diff --git a/uxarray/subset/dataarray_accessor.py b/uxarray/subset/dataarray_accessor.py index 3624e1e3e..2d966c587 100644 --- a/uxarray/subset/dataarray_accessor.py +++ b/uxarray/subset/dataarray_accessor.py @@ -2,7 +2,7 @@ import numpy as np -from typing import TYPE_CHECKING, Union, Tuple, List, Optional +from typing import TYPE_CHECKING, Union, Tuple, List, Optional, Set if TYPE_CHECKING: pass @@ -33,6 +33,7 @@ def bounding_box( lat_bounds: Union[Tuple, List, np.ndarray], element: Optional[str] = "nodes", method: Optional[str] = "coords", + inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid between two latitude and longitude @@ -53,9 +54,12 @@ def bounding_box( face centers, or edge centers lie within the bounds. element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) """ grid = self.uxda.uxgrid.subset.bounding_box( - lon_bounds, lat_bounds, element, method + lon_bounds, lat_bounds, element, method, inverse_indices=inverse_indices ) return self.uxda._slice_from_grid(grid) @@ -65,6 +69,7 @@ def bounding_circle( center_coord: Union[Tuple, List, np.ndarray], r: Union[float, int], element: Optional[str] = "nodes", + inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid by returning all elements within some @@ -78,9 +83,12 @@ def bounding_circle( Radius of bounding circle (in degrees) element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) """ grid = self.uxda.uxgrid.subset.bounding_circle( - center_coord, r, element, **kwargs + center_coord, r, element, inverse_indices=inverse_indices, **kwargs ) return self.uxda._slice_from_grid(grid) @@ -89,6 +97,7 @@ def nearest_neighbor( center_coord: Union[Tuple, List, np.ndarray], k: int, element: Optional[str] = "nodes", + inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid by returning the ``k`` closest @@ -102,10 +111,13 @@ def nearest_neighbor( Number of neighbors to query element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) """ grid = self.uxda.uxgrid.subset.nearest_neighbor( - center_coord, k, element, **kwargs + center_coord, k, element, inverse_indices=inverse_indices, **kwargs ) return self.uxda._slice_from_grid(grid) diff --git a/uxarray/subset/grid_accessor.py b/uxarray/subset/grid_accessor.py index 60dc8c800..a504179f1 100644 --- a/uxarray/subset/grid_accessor.py +++ b/uxarray/subset/grid_accessor.py @@ -2,7 +2,7 @@ import numpy as np -from typing import TYPE_CHECKING, Union, Tuple, List, Optional +from typing import TYPE_CHECKING, Union, Tuple, List, Optional, Set if TYPE_CHECKING: from uxarray.grid import Grid @@ -33,6 +33,7 @@ def bounding_box( lat_bounds: Union[Tuple, List, np.ndarray], element: Optional[str] = "nodes", method: Optional[str] = "coords", + inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid between two latitude and longitude @@ -53,6 +54,9 @@ def bounding_box( face centers, or edge centers lie within the bounds. element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) """ if method == "coords": @@ -101,11 +105,11 @@ def bounding_box( ) if element == "nodes": - return self.uxgrid.isel(n_node=indices) + return self.uxgrid.isel(inverse_indices, n_node=indices) elif element == "face centers": - return self.uxgrid.isel(n_face=indices) + return self.uxgrid.isel(inverse_indices, n_face=indices) elif element == "edge centers": - return self.uxgrid.isel(n_edge=indices) + return self.uxgrid.isel(inverse_indices, n_edge=indices) else: raise ValueError(f"Method '{method}' not supported.") @@ -115,6 +119,7 @@ def bounding_circle( center_coord: Union[Tuple, List, np.ndarray], r: Union[float, int], element: Optional[str] = "nodes", + inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid by returning all elements within some @@ -128,6 +133,9 @@ def bounding_circle( Radius of bounding circle (in degrees) element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) """ coords = np.asarray(center_coord) @@ -141,13 +149,14 @@ def bounding_circle( f"No elements founding within the bounding circle with radius {r} when querying {element}" ) - return self._index_grid(ind, element) + return self._index_grid(ind, element, inverse_indices) def nearest_neighbor( self, center_coord: Union[Tuple, List, np.ndarray], k: int, element: Optional[str] = "nodes", + inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid by returning the ``k`` closest @@ -161,6 +170,9 @@ def nearest_neighbor( Number of neighbors to query element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` + inverse_indices : Union[List[str], Set[str], bool], optional + Indicates whether to store the original grids indices. Passing `True` stores the original face centers, + other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) """ coords = np.asarray(center_coord) @@ -169,7 +181,7 @@ def nearest_neighbor( _, ind = tree.query(coords, k) - return self._index_grid(ind, element) + return self._index_grid(ind, element, inverse_indices=inverse_indices) def _get_tree(self, coords, tree_type): """Internal helper for obtaining the desired KDTree or BallTree.""" @@ -187,12 +199,12 @@ def _get_tree(self, coords, tree_type): return tree - def _index_grid(self, ind, tree_type): + def _index_grid(self, ind, tree_type, inverse_indices=False): """Internal helper for indexing a grid with indices based off the provided tree type.""" if tree_type == "nodes": - return self.uxgrid.isel(n_node=ind) + return self.uxgrid.isel(inverse_indices, n_node=ind) elif tree_type == "edge centers": - return self.uxgrid.isel(n_edge=ind) + return self.uxgrid.isel(inverse_indices, n_edge=ind) else: - return self.uxgrid.isel(n_face=ind) + return self.uxgrid.isel(inverse_indices, n_face=ind)