Skip to content

Commit

Permalink
Merge pull request #1 from raphaeldussin/masking_update
Browse files Browse the repository at this point in the history
Masking update
  • Loading branch information
raphaeldussin authored Sep 2, 2020
2 parents 40d8f48 + c72a2fc commit 10d0a8f
Show file tree
Hide file tree
Showing 6 changed files with 761 additions and 11 deletions.
1 change: 1 addition & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Contents
notebooks/Compare_algorithms
notebooks/Reuse_regridder
notebooks/Using_LocStream
notebooks/Masking
large_problems_on_HPC

.. toctree::
Expand Down
670 changes: 670 additions & 0 deletions doc/notebooks/Masking.ipynb

Large diffs are not rendered by default.

61 changes: 53 additions & 8 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def warn_lat_range(lat):
warnings.warn("Latitude is outside of [-90, 90]")


def esmf_grid(lon, lat, periodic=False):
def esmf_grid(lon, lat, periodic=False, mask=None):
'''
Create an ESMF.Grid object, for contrusting ESMF.Field and ESMF.Regrid
Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid.
Parameters
----------
Expand All @@ -70,6 +70,13 @@ def esmf_grid(lon, lat, periodic=False):
Periodic in longitude? Default to False.
Only useful for source grid.
mask : 2D numpy array, optional
Grid mask. According to the ESMF convention, masked cells
are set to 0 and unmasked cells to 1.
Shape should be ``(Nlon, Nlat)`` for rectilinear grid,
or ``(Nx, Ny)`` for general quadrilateral grid.
Returns
-------
grid : ESMF.Grid object
Expand Down Expand Up @@ -111,6 +118,22 @@ def esmf_grid(lon, lat, periodic=False):
lon_pointer[...] = lon
lat_pointer[...] = lat

# Follows SCRIP convention where 1 is unmasked and 0 is masked.
# See https://github.com/NCPP/ocgis/blob/61d88c60e9070215f28c1317221c2e074f8fb145/src/ocgis/regrid/base.py#L391-L404
if mask is not None:
# remove fractional values
mask = np.where(mask == 0, 0, 1)
# convert array type to integer (ESMF compat)
grid_mask = mask.astype(np.int32)
if not (grid_mask.shape == lon.shape):
raise ValueError(
"mask must have the same shape as the latitude/longitude"
"coordinates, got: mask.shape = %s, lon.shape = %s" %
(mask.shape, lon.shape))
grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER,
from_file=False)
grid.mask[0][:] = grid_mask

return grid


Expand Down Expand Up @@ -207,6 +230,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
- 'bilinear'
- 'conservative', **need grid corner information**
- 'conservative_normed', **need grid corner information**
- 'patch'
- 'nearest_s2d'
- 'nearest_d2s'
Expand Down Expand Up @@ -241,6 +265,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
# use shorter, clearer names for options in ESMF.RegridMethod
method_dict = {'bilinear': ESMF.RegridMethod.BILINEAR,
'conservative': ESMF.RegridMethod.CONSERVE,
'conservative_normed': ESMF.RegridMethod.CONSERVE,
'patch': ESMF.RegridMethod.PATCH,
'nearest_s2d': ESMF.RegridMethod.NEAREST_STOD,
'nearest_d2s': ESMF.RegridMethod.NEAREST_DTOS
Expand All @@ -252,7 +277,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
'{}'.format(list(method_dict.keys())))

# conservative regridding needs cell corner information
if method == 'conservative':
if method in ['conservative', 'conservative_normed']:
if not sourcegrid.has_corners:
raise ValueError('source grid has no corner information. '
'cannot use conservative regridding.')
Expand All @@ -266,20 +291,40 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
sourcefield = ESMF.Field(sourcegrid, ndbounds=extra_dims)
destfield = ESMF.Field(destgrid, ndbounds=extra_dims)

# Possible ESMF bug: when either field is built on a LocStream object,
# passing the src_mask_values and dst_mask_values options produces
# runtime errors, so those options are only passed for regular grids.
allow_masked_values = True
if isinstance(sourcefield.grid, ESMF.api.locstream.LocStream):
allow_masked_values = False
if isinstance(destfield.grid, ESMF.api.locstream.LocStream):
allow_masked_values = False

# ESMPy will throw an incomprehensible error if the weight file
# already exists. Better to catch it here!
if filename is not None:
assert not os.path.exists(filename), (
'Weight file already exists! Please remove it or use a new name.')

# re-normalize conservative regridding results
# https://github.com/JiaweiZhuang/xESMF/issues/17
if method == 'conservative_normed':
norm_type = ESMF.NormType.FRACAREA
else:
norm_type = ESMF.NormType.DSTAREA

# Calculate regridding weights.
# Must set unmapped_action to IGNORE, otherwise the function will fail,
# if the destination grid is larger than the source grid.
regrid = ESMF.Regrid(sourcefield, destfield, filename=filename,
regrid_method=esmf_regrid_method,
unmapped_action=ESMF.UnmappedAction.IGNORE,
ignore_degenerate=ignore_degenerate,
factors=filename is None)
kwargs=dict(filename=filename,
regrid_method=esmf_regrid_method,
unmapped_action=ESMF.UnmappedAction.IGNORE,
ignore_degenerate=ignore_degenerate,
norm_type=norm_type,
factors=filename is None)
if allow_masked_values:
kwargs.update(dict(src_mask_values=[0], dst_mask_values=[0]))

regrid = ESMF.Regrid(sourcefield, destfield, **kwargs)

return regrid

Expand Down
18 changes: 15 additions & 3 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,16 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
lat = np.asarray(ds['lat'])
lon, lat = as_2d_mesh(lon, lat)

if 'mask' in ds:
mask = np.asarray(ds['mask'])
else:
mask = None

# transpose the arrays so they become Fortran-ordered
grid = esmf_grid(lon.T, lat.T, periodic=periodic)
if mask is not None:
grid = esmf_grid(lon.T, lat.T, periodic=periodic, mask=mask.T)
else:
grid = esmf_grid(lon.T, lat.T, periodic=periodic, mask=None)

if need_bounds:
lon_b = np.asarray(ds['lon_b'])
Expand Down Expand Up @@ -115,17 +123,21 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
ds_in, ds_out : xarray DataSet, or dictionary
Contain input and output grid coordinates. Look for variables
``lon``, ``lat``, and optionally ``lon_b``, ``lat_b`` for
conservative method.
conservative methods.
Shape can be 1D (n_lon,) and (n_lat,) for rectilinear grids,
or 2D (n_y, n_x) for general curvilinear grids.
Shape of bounds should be (n+1,) or (n_y+1, n_x+1).
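If either dataset includes a 2D variable named ``mask``, it is passed
to ESMF as the grid mask (0 = masked cell, 1 = unmasked cell) and is
taken into account when computing the regridding weights.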
method : str
Regridding method. Options are
- 'bilinear'
- 'conservative', **need grid corner information**
- 'conservative_normed', **need grid corner information**
- 'patch'
- 'nearest_s2d'
- 'nearest_d2s'
Expand Down Expand Up @@ -171,7 +183,7 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
"""

# record basic switches
if method == 'conservative':
if method in ['conservative', 'conservative_normed']:
self.need_bounds = True
periodic = False # bound shape will not be N+1 for periodic grid
else:
Expand Down
6 changes: 6 additions & 0 deletions xesmf/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ def test_esmf_locstream():
with pytest.raises(ValueError):
ls = esmf_locstream(lon2d, lat)

grid_in = esmf_grid(lon_in.T, lat_in.T, periodic=True)
regrid = esmf_regrid_build(grid_in, ls, 'bilinear')

regrid = esmf_regrid_build(ls, grid_in, 'nearest_s2d')


def test_read_weights(tmp_path):
fn = tmp_path / "weights.nc"
Expand Down Expand Up @@ -257,3 +262,4 @@ def test_read_weights(tmp_path):
with pytest.raises(ValueError):
ds = xr.open_dataset(fn)
read_weights(ds.drop_vars("col"), lon_in.size, lon_out.size)

16 changes: 16 additions & 0 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,22 @@ def test_regrid_dataset_to_locstream():
ds_result = regridder(ds_in)


def test_build_regridder_with_masks():
    """Check that a Regridder can be built when the input grid has a mask.

    A random 0/1 ``mask`` variable is attached to a copy of the input
    grid, and a regridder is built for every method except 'patch'.
    Only the construction and the text representation of the regridder
    are checked; the regridded values themselves are not verified.
    """
    # Work on a copy so that the module-level ``ds_in`` shared with the
    # other tests does not gain a ``mask`` variable as a side effect.
    # NOTE(review): assumes ``ds_in`` provides ``.copy()`` (xarray Dataset
    # or dict) -- confirm against the module-level fixture.
    ds_in_masked = ds_in.copy()
    ds_in_masked['mask'] = xr.DataArray(
        np.random.randint(2, size=ds_in['data'].shape),
        dims=('y', 'x'))

    # 'patch' is too slow to test
    for method in ['bilinear', 'conservative', 'conservative_normed',
                   'nearest_s2d', 'nearest_d2s']:
        regridder = xe.Regridder(ds_in_masked, ds_out, method)

        # check screen output
        assert repr(regridder) == str(regridder)
        assert 'xESMF Regridder' in str(regridder)
        assert method in str(regridder)


def test_regrid_dataset_from_locstream():
# xarray.Dataset containing in-memory numpy array

Expand Down

0 comments on commit 10d0a8f

Please sign in to comment.