Skip to content

Commit

Permalink
Merge pull request pangeo-data#1 from raphaeldussin/fix-75-v2
Browse files Browse the repository at this point in the history
re-enable legacy args
  • Loading branch information
huard authored Aug 17, 2020
2 parents 0a07d6e + 21b0c7b commit b5f3d07
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
36 changes: 34 additions & 2 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def ds_to_ESMFlocstream(ds):

class Regridder(object):
def __init__(self, ds_in, ds_out, method, periodic=False,
filename=None, reuse_weights=False,
weights=None, ignore_degenerate=None,
locstream_in=False, locstream_out=False):
"""
Expand Down Expand Up @@ -134,6 +135,17 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
Only useful for global grids with non-conservative regridding.
Will be forced to False for conservative regridding.
filename : str, optional
Name for the weight file. The default naming scheme is::
{method}_{Ny_in}x{Nx_in}_{Ny_out}x{Nx_out}.nc
e.g. bilinear_400x600_300x400.nc
reuse_weights : bool, optional
Whether to read existing weight file to save computing time.
False by default (i.e. re-compute, not reuse).
weights : None, coo_matrix, dict, str, Dataset, Path,
Regridding weights, stored as
- a scipy.sparse COO matrix,
Expand Down Expand Up @@ -167,6 +179,7 @@ def __init__(self, ds_in, ds_out, method, periodic=False,

self.method = method
self.periodic = periodic
self.reuse_weights = reuse_weights
self.ignore_degenerate = ignore_degenerate
self.locstream_in = locstream_in
self.locstream_out = locstream_out
Expand Down Expand Up @@ -222,12 +235,27 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
self.n_in = shape_in[0] * shape_in[1]
self.n_out = shape_out[0] * shape_out[1]

if weights is None:
# some logic about reusing weights with either filename or weights args
if reuse_weights and (filename is None) and (weights is None):
raise ValueError("to reuse weights, you need to provide either filename or weights")

if not reuse_weights and weights is None:
weights = self._compute_weights() # Dictionary of weights
else:
weights = filename if filename is not None else weights

assert weights is not None

# Convert weights, whatever their format, to a sparse coo matrix
self.weights = read_weights(weights, self.n_in, self.n_out)

# follows legacy logic of writing weights if filename is provided
if filename is not None and not reuse_weights:
self.to_netcdf(filename=filename)

# set default weights filename if none given
self.filename = self._get_default_filename() if filename is None else filename

@property
def A(self):
message = (
Expand Down Expand Up @@ -265,11 +293,15 @@ def _compute_weights(self):
def __repr__(self):
info = ('xESMF Regridder \n'
'Regridding algorithm: {} \n'
'Weight filename: {} \n'
'Reuse pre-computed weights? {} \n'
'Input grid shape: {} \n'
'Output grid shape: {} \n'
'Output grid dimension name: {} \n'
'Periodic in longitude? {}'
.format(self.method,
self.filename,
self.reuse_weights,
self.shape_in,
self.shape_out,
self.out_horiz_dims,
Expand Down Expand Up @@ -479,7 +511,7 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
def to_netcdf(self, filename=None):
'''Save weights to disk as a netCDF file.'''
if filename is None:
filename = self._get_default_filename()
filename = self.filename
w = self.weights
dim = "n_s"
ds = xr.Dataset({"S": (dim, w.data), "col": (dim, w.col + 1), "row": (dim, w.row + 1)})
Expand Down
23 changes: 23 additions & 0 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,32 @@ def test_existing_weights():
weights=fn)
assert regridder_reuse.A.shape == regridder.A.shape

# this should also work with reuse_weights=True
regridder_reuse = xe.Regridder(ds_in, ds_out, method,
reuse_weights=True, weights=fn)
assert regridder_reuse.A.shape == regridder.A.shape

# or can also overwrite it
xe.Regridder(ds_in, ds_out, method)

# check legacy args still work
regridder = xe.Regridder(ds_in, ds_out, method, filename='wgts.nc')
regridder_reuse = xe.Regridder(ds_in, ds_out, method,
reuse_weights=True,
filename='wgts.nc')
assert regridder_reuse.A.shape == regridder.A.shape

# check fails on non-existent file
with pytest.raises(OSError):
regridder_reuse = xe.Regridder(ds_in, ds_out, method,
reuse_weights=True,
filename='fakewgts.nc')

# check fails if no weights are provided
with pytest.raises(ValueError):
regridder_reuse = xe.Regridder(ds_in, ds_out, method,
reuse_weights=True)


def test_to_netcdf(tmp_path):
from xesmf.backend import esmf_grid, esmf_regrid_build
Expand Down

0 comments on commit b5f3d07

Please sign in to comment.