Skip to content

Commit

Permalink
Merge pull request #915 from UXARRAY/philipc2/grid-chunking
Browse files Browse the repository at this point in the history
Dask Array Support & Chunking for `Grid`
  • Loading branch information
rajeeja authored Sep 6, 2024
2 parents 82e46b8 + 4461601 commit e50c879
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/user_api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ Methods
Grid.get_kd_tree
Grid.copy
Grid.isel
Grid.chunk


Dimensions
Expand Down
51 changes: 51 additions & 0 deletions test/test_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import uxarray as ux
import numpy as np
import dask.array as da

import pytest
import os
from pathlib import Path



# Directory containing this test module; fixture files are located relative to it.
current_path = Path(os.path.realpath(__file__)).parent

# Grid file stored under the meshfiles/mpas/QU fixture folder.
mpas_grid = current_path.joinpath('meshfiles', 'mpas', 'QU', 'oQU480.231010.nc')

# Grid file and data file that live side by side in meshfiles/ugrid/outCSne30.
csne30_grid = current_path.joinpath('meshfiles', 'ugrid', 'outCSne30', 'outCSne30.ug')
csne30_data = current_path.joinpath('meshfiles', 'ugrid', 'outCSne30', 'outCSne30_var2.nc')


def test_grid_chunking():
    """Verify that chunking a whole grid swaps NumPy data for Dask arrays."""
    grid = ux.open_grid(mpas_grid)

    # Right after opening, every variable in the grid's dataset is expected
    # to be backed by an in-memory NumPy array.
    for name in grid._ds:
        backing = grid._ds[name].data
        assert isinstance(backing, np.ndarray)

    # Request chunk sizes for n_node, n_face and n_edge. The return value is
    # not used; the same grid object is inspected afterwards.
    grid.chunk(n_node=1, n_face=2, n_edge=4)

    # After chunking, every variable should be backed by a Dask array.
    for name in grid._ds:
        backing = grid._ds[name].data
        assert isinstance(backing, da.Array)

def test_individual_var_chunking():
    """Verify that a single grid variable can be chunked on its own."""
    grid = ux.open_grid(mpas_grid)

    # Before chunking, face_node_connectivity holds a NumPy array.
    assert isinstance(grid.face_node_connectivity.data, np.ndarray)

    # Chunk only this variable along "n_face" and assign it back onto the grid.
    chunked_connectivity = grid.face_node_connectivity.chunk(chunks={"n_face": 16})
    grid.face_node_connectivity = chunked_connectivity

    # Reading the variable back from the grid should now yield a Dask array.
    assert isinstance(grid.face_node_connectivity.data, da.Array)


def test_uxds_chunking():
    """Tests that opening a dataset with ``chunks`` yields Dask-backed data.

    The chunk size is keyed by ``n_face`` rather than by whatever dimension
    name the files on disk use, so this exercises the dimension-name
    handling that ``open_dataset`` applies to ``chunks``.
    """
    uxds = ux.open_dataset(csne30_grid, csne30_data, chunks={"n_face": 4})

    # Guard against a vacuous pass when nothing was loaded.
    assert len(uxds.data_vars) > 0

    for var in uxds.data_vars:
        # Passing ``chunks`` should load each data variable as a Dask array.
        assert isinstance(uxds[var].data, da.Array)
14 changes: 14 additions & 0 deletions uxarray/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ def open_dataset(
grid_filename_or_obj, latlon=latlon, use_dual=use_dual, **grid_kwargs
)

# xarray also accepts an int, "auto" or None for ``chunks``; only a dict
# carries per-dimension keys that may need to be translated.
if isinstance(kwargs.get("chunks"), dict):
    # Work on a copy so the dict supplied by the caller is not mutated.
    chunks = dict(kwargs["chunks"])

    # The mapping is iterated as (original dimension name -> standardized
    # ugrid dimension name). Each chunk size given under a ugrid name is
    # mirrored under the original name before the file is opened.
    for original_grid_dim, ugrid_grid_dim in uxgrid._source_dims_dict.items():
        if ugrid_grid_dim in chunks:
            chunks[original_grid_dim] = chunks[ugrid_grid_dim]

    kwargs["chunks"] = chunks

# UxDataset
ds = xr.open_dataset(filename_or_obj, **kwargs) # type: ignore

Expand Down Expand Up @@ -254,6 +261,13 @@ def open_mfdataset(
grid_filename_or_obj, latlon=latlon, use_dual=use_dual, **grid_kwargs
)

# xarray also accepts an int, "auto" or None for ``chunks``; only a dict
# carries per-dimension keys that may need to be translated.
if isinstance(kwargs.get("chunks"), dict):
    # Work on a copy so the dict supplied by the caller is not mutated.
    chunks = dict(kwargs["chunks"])

    # The mapping is iterated as (original dimension name -> standardized
    # ugrid dimension name). Each chunk size given under a ugrid name is
    # mirrored under the original name before the files are opened.
    for original_grid_dim, ugrid_grid_dim in uxgrid._source_dims_dict.items():
        if ugrid_grid_dim in chunks:
            chunks[original_grid_dim] = chunks[ugrid_grid_dim]

    kwargs["chunks"] = chunks

# UxDataset
ds = xr.open_mfdataset(paths, **kwargs) # type: ignore

Expand Down
Loading

0 comments on commit e50c879

Please sign in to comment.