-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for masked data (closes #20)
- Loading branch information
1 parent
4477af0
commit 77f2119
Showing
7 changed files
with
247 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Integration tests for esmf_regrid.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Integration tests for :mod:`esmf_regrid.esmf_regridder`.""" |
67 changes: 67 additions & 0 deletions
67
esmf_regrid/tests/integration/esmf_regridder/test_Regridder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""Unit tests for :class:`esmf_regrid.esmf_regridder.Regridder`.""" | ||
|
||
import ESMF | ||
import numpy as np | ||
from numpy import ma | ||
|
||
from esmf_regrid.esmf_regridder import GridInfo, Regridder | ||
from esmf_regrid.tests import make_grid_args | ||
|
||
|
||
def test_esmpy_normalisation(): | ||
""" | ||
Integration test for :meth:`~esmf_regrid.esmf_regridder.Regridder`. | ||
Checks against ESMF to ensure results are consistent. | ||
""" | ||
src_data = np.array( | ||
[ | ||
[1.0, 1.0, 1.0], | ||
[1.0, 0.0, 0.0], | ||
], | ||
) | ||
src_mask = np.array( | ||
[ | ||
[True, False, False], | ||
[False, False, False], | ||
] | ||
) | ||
src_array = ma.array(src_data, mask=src_mask) | ||
|
||
lon, lat, lon_bounds, lat_bounds = make_grid_args(2, 3) | ||
src_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) | ||
src_esmpy_grid = src_grid._make_esmf_grid() | ||
src_esmpy_grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER) | ||
src_esmpy_grid.mask[0][...] = src_mask.T | ||
src_field = ESMF.Field(src_esmpy_grid) | ||
src_field.data[...] = src_data.T | ||
|
||
lon, lat, lon_bounds, lat_bounds = make_grid_args(3, 2) | ||
tgt_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) | ||
tgt_field = tgt_grid.make_esmf_field() | ||
|
||
regridder = Regridder(src_grid, tgt_grid) | ||
|
||
regridding_kwargs = { | ||
"ignore_degenerate": True, | ||
"regrid_method": ESMF.RegridMethod.CONSERVE, | ||
"unmapped_action": ESMF.UnmappedAction.IGNORE, | ||
"factors": True, | ||
"src_mask_values": [1], | ||
} | ||
esmpy_fracarea_regridder = ESMF.Regrid( | ||
src_field, tgt_field, norm_type=ESMF.NormType.FRACAREA, **regridding_kwargs | ||
) | ||
esmpy_dstarea_regridder = ESMF.Regrid( | ||
src_field, tgt_field, norm_type=ESMF.NormType.DSTAREA, **regridding_kwargs | ||
) | ||
|
||
tgt_field_dstarea = esmpy_dstarea_regridder(src_field, tgt_field) | ||
result_esmpy_dstarea = tgt_field_dstarea.data | ||
result_dstarea = regridder.regrid(src_array, norm_type="dstarea").T | ||
assert ma.allclose(result_esmpy_dstarea, result_dstarea) | ||
|
||
tgt_field_fracarea = esmpy_fracarea_regridder(src_field, tgt_field) | ||
result_esmpy_fracarea = tgt_field_fracarea.data | ||
result_fracarea = regridder.regrid(src_array, norm_type="fracarea").T | ||
assert ma.allclose(result_esmpy_fracarea, result_fracarea) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Unit tests for :mod:`esmf_regrid.esmf_regridder`.""" |
120 changes: 120 additions & 0 deletions
120
esmf_regrid/tests/unit/esmf_regridder/test_Regridder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
"""Unit tests for :class:`esmf_regrid.esmf_regridder.Regridder`.""" | ||
|
||
import numpy as np | ||
from numpy import ma | ||
import pytest | ||
import scipy.sparse | ||
|
||
from esmf_regrid.esmf_regridder import GridInfo, Regridder | ||
from esmf_regrid.tests import make_grid_args | ||
|
||
|
||
def _expected_weights(): | ||
weight_list = np.array( | ||
[ | ||
0.6674194025656819, | ||
0.3325805974343169, | ||
0.3351257294386341, | ||
0.6648742705613656, | ||
0.33363933739884066, | ||
0.1663606626011589, | ||
0.333639337398841, | ||
0.1663606626011591, | ||
0.16742273275056854, | ||
0.33250863479149745, | ||
0.16742273275056876, | ||
0.33250863479149767, | ||
0.6674194025656823, | ||
0.3325805974343174, | ||
0.3351257294386344, | ||
0.6648742705613663, | ||
] | ||
) | ||
rows = np.array([0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 5, 5]) | ||
columns = np.array([0, 1, 1, 2, 0, 1, 3, 4, 1, 2, 4, 5, 3, 4, 4, 5]) | ||
|
||
shape = (6, 6) | ||
|
||
weights = scipy.sparse.csr_matrix((weight_list, (rows, columns)), shape=shape) | ||
return weights | ||
|
||
|
||
def test_Regridder_init(): | ||
"""Basic test for :meth:`~esmf_regrid.esmf_regridder.Regridder.__init__`.""" | ||
lon, lat, lon_bounds, lat_bounds = make_grid_args(2, 3) | ||
src_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) | ||
|
||
lon, lat, lon_bounds, lat_bounds = make_grid_args(3, 2) | ||
tgt_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) | ||
|
||
rg = Regridder(src_grid, tgt_grid) | ||
|
||
result = rg.weight_matrix | ||
expected = _expected_weights() | ||
|
||
assert np.allclose(result.toarray(), expected.toarray()) | ||
|
||
|
||
def test_Regridder_regrid(): | ||
"""Basic test for :meth:`~esmf_regrid.esmf_regridder.Regridder.regrid`.""" | ||
lon, lat, lon_bounds, lat_bounds = make_grid_args(2, 3) | ||
src_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) | ||
|
||
lon, lat, lon_bounds, lat_bounds = make_grid_args(3, 2) | ||
tgt_grid = GridInfo(lon, lat, lon_bounds, lat_bounds) | ||
|
||
# Set up the regridder with precomputed weights. | ||
rg = Regridder(src_grid, tgt_grid, precomputed_weights=_expected_weights()) | ||
|
||
src_array = np.array([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]) | ||
src_masked = ma.array(src_array, mask=[[1, 0, 0], [0, 0, 0]]) | ||
|
||
# Regrid with unmasked data. | ||
result_nomask = rg.regrid(src_array) | ||
expected_nomask = ma.array( | ||
[ | ||
[1.0, 1.0], | ||
[0.8336393373988409, 0.4999999999999997], | ||
[0.6674194025656824, 0.0], | ||
] | ||
) | ||
assert ma.allclose(result_nomask, expected_nomask) | ||
|
||
# Regrid with an masked array with no masked points. | ||
result_ma_nomask = rg.regrid(ma.array(src_array)) | ||
assert ma.allclose(result_ma_nomask, expected_nomask) | ||
|
||
# Regrid with a fully masked array. | ||
result_fullmask = rg.regrid(ma.array(src_array, mask=True)) | ||
expected_fulmask = ma.array(np.zeros([3, 2]), mask=True) | ||
assert ma.allclose(result_fullmask, expected_fulmask) | ||
|
||
# Regrid with a masked array containing a masked point. | ||
result_withmask = rg.regrid(src_masked) | ||
expected_withmask = ma.array( | ||
[ | ||
[0.9999999999999999, 1.0], | ||
[0.7503444126612077, 0.4999999999999997], | ||
[0.6674194025656824, 0.0], | ||
] | ||
) | ||
assert ma.allclose(result_withmask, expected_withmask) | ||
|
||
# Regrid while setting mdtol. | ||
result_half_mdtol = rg.regrid(src_masked, mdtol=0.5) | ||
expected_half_mdtol = ma.array(expected_withmask, mask=[[1, 0], [0, 0], [1, 0]]) | ||
assert ma.allclose(result_half_mdtol, expected_half_mdtol) | ||
|
||
# Regrid with norm_type="dstarea". | ||
result_dstarea = rg.regrid(src_masked, norm_type="dstarea") | ||
expected_dstarea = ma.array( | ||
[ | ||
[0.3325805974343169, 0.9999999999999998], | ||
[0.4999999999999999, 0.499931367542066], | ||
[0.6674194025656823, 0.0], | ||
] | ||
) | ||
assert ma.allclose(result_dstarea, expected_dstarea) | ||
|
||
with pytest.raises(ValueError): | ||
_ = rg.regrid(src_masked, norm_type="INVALID") |