Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Masking update #1

Merged
merged 13 commits into from
Sep 2, 2020
1 change: 1 addition & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Contents
notebooks/Compare_algorithms
notebooks/Reuse_regridder
notebooks/Using_LocStream
notebooks/Masking

.. toctree::
:maxdepth: 1
Expand Down
670 changes: 670 additions & 0 deletions doc/notebooks/Masking.ipynb

Large diffs are not rendered by default.

61 changes: 53 additions & 8 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def warn_lat_range(lat):
warnings.warn("Latitude is outside of [-90, 90]")


def esmf_grid(lon, lat, periodic=False):
def esmf_grid(lon, lat, periodic=False, mask=None):
'''
Create an ESMF.Grid object, for contrusting ESMF.Field and ESMF.Regrid
Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid.

Parameters
----------
Expand All @@ -70,6 +70,13 @@ def esmf_grid(lon, lat, periodic=False):
Periodic in longitude? Default to False.
Only useful for source grid.

mask : 2D numpy array, optional
Grid mask. According to the ESMF convention, masked cells
are set to 0 and unmasked cells to 1.

Shape should be ``(Nlon, Nlat)`` for rectilinear grid,
or ``(Nx, Ny)`` for general quadrilateral grid.

Returns
-------
grid : ESMF.Grid object
Expand Down Expand Up @@ -111,6 +118,22 @@ def esmf_grid(lon, lat, periodic=False):
lon_pointer[...] = lon
lat_pointer[...] = lat

# Follows SCRIP convention where 1 is unmasked and 0 is masked.
# See https://github.com/NCPP/ocgis/blob/61d88c60e9070215f28c1317221c2e074f8fb145/src/ocgis/regrid/base.py#L391-L404
if mask is not None:
# remove fractional values
mask = np.where(mask == 0, 0, 1)
# convert array type to integer (ESMF compat)
grid_mask = mask.astype(np.int32)
if not (grid_mask.shape == lon.shape):
raise ValueError(
"mask must have the same shape as the latitude/longitude"
"coordinates, got: mask.shape = %s, lon.shape = %s" %
(mask.shape, lon.shape))
grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER,
from_file=False)
grid.mask[0][:] = grid_mask

return grid


Expand Down Expand Up @@ -207,6 +230,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method,

- 'bilinear'
- 'conservative', **need grid corner information**
- 'conservative_normed', **need grid corner information**
- 'patch'
- 'nearest_s2d'
- 'nearest_d2s'
Expand Down Expand Up @@ -241,6 +265,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
# use shorter, clearer names for options in ESMF.RegridMethod
method_dict = {'bilinear': ESMF.RegridMethod.BILINEAR,
'conservative': ESMF.RegridMethod.CONSERVE,
'conservative_normed': ESMF.RegridMethod.CONSERVE,
'patch': ESMF.RegridMethod.PATCH,
'nearest_s2d': ESMF.RegridMethod.NEAREST_STOD,
'nearest_d2s': ESMF.RegridMethod.NEAREST_DTOS
Expand All @@ -252,7 +277,7 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
'{}'.format(list(method_dict.keys())))

# conservative regridding needs cell corner information
if method == 'conservative':
if method in ['conservative', 'conservative_normed']:
if not sourcegrid.has_corners:
raise ValueError('source grid has no corner information. '
'cannot use conservative regridding.')
Expand All @@ -266,20 +291,40 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
sourcefield = ESMF.Field(sourcegrid, ndbounds=extra_dims)
destfield = ESMF.Field(destgrid, ndbounds=extra_dims)

# ESMF bug? when using locstream objects, options src_mask_values
# and dst_mask_values produce runtime errors
allow_masked_values = True
if isinstance(sourcefield.grid, ESMF.api.locstream.LocStream):
allow_masked_values = False
if isinstance(destfield.grid, ESMF.api.locstream.LocStream):
allow_masked_values = False

# ESMPy will throw an incomprehensive error if the weight file
# already exists. Better to catch it here!
if filename is not None:
assert not os.path.exists(filename), (
'Weight file already exists! Please remove it or use a new name.')

# re-normalize conservative regridding results
# https://github.com/JiaweiZhuang/xESMF/issues/17
if method == 'conservative_normed':
norm_type = ESMF.NormType.FRACAREA
else:
norm_type = ESMF.NormType.DSTAREA

# Calculate regridding weights.
# Must set unmapped_action to IGNORE, otherwise the function will fail,
# if the destination grid is larger than the source grid.
regrid = ESMF.Regrid(sourcefield, destfield, filename=filename,
regrid_method=esmf_regrid_method,
unmapped_action=ESMF.UnmappedAction.IGNORE,
ignore_degenerate=ignore_degenerate,
factors=filename is None)
kwargs=dict(filename=filename,
regrid_method=esmf_regrid_method,
unmapped_action=ESMF.UnmappedAction.IGNORE,
ignore_degenerate=ignore_degenerate,
norm_type=norm_type,
factors=filename is None)
if allow_masked_values:
kwargs.update(dict(src_mask_values=[0], dst_mask_values=[0]))

regrid = ESMF.Regrid(sourcefield, destfield, **kwargs)

return regrid

Expand Down
18 changes: 15 additions & 3 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,16 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
lat = np.asarray(ds['lat'])
lon, lat = as_2d_mesh(lon, lat)

if 'mask' in ds:
mask = np.asarray(ds['mask'])
else:
mask = None

# tranpose the arrays so they become Fortran-ordered
grid = esmf_grid(lon.T, lat.T, periodic=periodic)
if mask is not None:
grid = esmf_grid(lon.T, lat.T, periodic=periodic, mask=mask.T)
else:
grid = esmf_grid(lon.T, lat.T, periodic=periodic, mask=None)

if need_bounds:
lon_b = np.asarray(ds['lon_b'])
Expand Down Expand Up @@ -115,17 +123,21 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
ds_in, ds_out : xarray DataSet, or dictionary
Contain input and output grid coordinates. Look for variables
``lon``, ``lat``, and optionally ``lon_b``, ``lat_b`` for
conservative method.
conservative methods.

Shape can be 1D (n_lon,) and (n_lat,) for rectilinear grids,
or 2D (n_y, n_x) for general curvilinear grids.
Shape of bounds should be (n+1,) or (n_y+1, n_x+1).

If either dataset includes a 2d mask variable, that will also be
used to inform the regridding.

method : str
Regridding method. Options are

- 'bilinear'
- 'conservative', **need grid corner information**
- 'conservative_normed', **need grid corner information**
- 'patch'
- 'nearest_s2d'
- 'nearest_d2s'
Expand Down Expand Up @@ -171,7 +183,7 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
"""

# record basic switches
if method == 'conservative':
if method in ['conservative', 'conservative_normed']:
self.need_bounds = True
periodic = False # bound shape will not be N+1 for periodic grid
else:
Expand Down
6 changes: 6 additions & 0 deletions xesmf/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ def test_esmf_locstream():
with pytest.raises(ValueError):
ls = esmf_locstream(lon2d, lat)

grid_in = esmf_grid(lon_in.T, lat_in.T, periodic=True)
regrid = esmf_regrid_build(grid_in, ls, 'bilinear')

regrid = esmf_regrid_build(ls, grid_in, 'nearest_s2d')


def test_read_weights(tmp_path):
fn = tmp_path / "weights.nc"
Expand Down Expand Up @@ -257,3 +262,4 @@ def test_read_weights(tmp_path):
with pytest.raises(ValueError):
ds = xr.open_dataset(fn)
read_weights(ds.drop_vars("col"), lon_in.size, lon_out.size)

16 changes: 16 additions & 0 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,22 @@ def test_regrid_dataset_to_locstream():
ds_result = regridder(ds_in)


def test_build_regridder_with_masks():
ds_in['mask'] = xr.DataArray(
np.random.randint(2, size=ds_in['data'].shape),
dims=('y', 'x'))
print(ds_in)
# 'patch' is too slow to test
for method in ['bilinear', 'conservative', 'conservative_normed',
'nearest_s2d', 'nearest_d2s']:
regridder = xe.Regridder(ds_in, ds_out, method)

# check screen output
assert repr(regridder) == str(regridder)
assert 'xESMF Regridder' in str(regridder)
assert method in str(regridder)


def test_regrid_dataset_from_locstream():
# xarray.Dataset containing in-memory numpy array

Expand Down