Skip to content

Commit

Permalink
Merge pull request JiaweiZhuang#60 from malmans2/consistent_naming
Browse files Browse the repository at this point in the history
Return objects consistent with ds_out coordinates
  • Loading branch information
raphaeldussin authored Jan 11, 2021
2 parents 9ffbefb + 7b12e00 commit 2ea833b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
20 changes: 16 additions & 4 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import scipy.sparse as sps
import xarray as xr
from xarray import DataArray

from .backend import Grid, LocStream, Mesh, add_corner, esmf_regrid_build, esmf_regrid_finalize
from .smm import _combine_weight_multipoly, add_nans_to_weights, apply_weights, read_weights
Expand Down Expand Up @@ -664,6 +665,12 @@ def __init__(
# record output grid and metadata
lon_out, lat_out = _get_lon_lat(ds_out)
self._lon_out, self._lat_out = np.asarray(lon_out), np.asarray(lat_out)
self._coord_names = dict(
lon=lon_out.name if isinstance(lon_out, DataArray) else 'lon',
lat=lat_out.name if isinstance(lat_out, DataArray) else 'lat',
)
self._lon_out_attrs = lon_out.attrs if isinstance(lon_out, DataArray) else {}
self._lat_out_attrs = lat_out.attrs if isinstance(lat_out, DataArray) else {}

if self._lon_out.ndim == 2:
try:
Expand Down Expand Up @@ -692,18 +699,23 @@ def _format_xroutput(self, out, new_dims=None):

# append output horizontal coordinate values
# extra coordinates are automatically tracked by apply_ufunc
lon_args = dict(data=self._lon_out, attrs=self._lon_out_attrs)
lat_args = dict(data=self._lat_out, attrs=self._lat_out_attrs)
if self.sequence_out:
out.coords['lon'] = xr.DataArray(self._lon_out, dims=('locations',))
out.coords['lat'] = xr.DataArray(self._lat_out, dims=('locations',))
out.coords['lon'] = xr.DataArray(**lon_args, dims=('locations',))
out.coords['lat'] = xr.DataArray(**lat_args, dims=('locations',))
else:
out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)
out.coords['lon'] = xr.DataArray(**lon_args, dims=self.lon_dim)
out.coords['lat'] = xr.DataArray(**lat_args, dims=self.lat_dim)

out.attrs['regrid_method'] = self.method

if self.sequence_out:
out = out.squeeze(dim='dummy')

# Use ds_out coordinates
out = out.rename(self._coord_names)

return out


Expand Down
12 changes: 8 additions & 4 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,12 @@ def test_regrid_dataarray(use_cfxr):
# xarray.DataArray containing in-memory numpy array
if use_cfxr:
ds_in2 = ds_in.rename(lat='Latitude', lon='Longitude')
ds_out2 = ds_out.rename(lat='Latitude', lon='Longitude')
else:
ds_in2 = ds_in
ds_out2 = ds_out

regridder = xe.Regridder(ds_in2, ds_out, 'conservative')
regridder = xe.Regridder(ds_in2, ds_out2, 'conservative')

outdata = regridder(ds_in2['data'].values) # pure numpy array
dr_out = regridder(ds_in2['data']) # xarray DataArray
Expand All @@ -280,12 +282,14 @@ def test_regrid_dataarray(use_cfxr):
assert_equal(outdata, dr_out.values)

# compare with analytical solution
rel_err = (ds_out['data_ref'] - dr_out) / ds_out['data_ref']
rel_err = (ds_out2['data_ref'] - dr_out) / ds_out2['data_ref']
assert np.max(np.abs(rel_err)) < 0.05

# check metadata
assert_equal(dr_out['lat'].values, ds_out['lat'].values)
assert_equal(dr_out['lon'].values, ds_out['lon'].values)
lat_name = 'Latitude' if use_cfxr else 'lat'
lon_name = 'Longitude' if use_cfxr else 'lon'
xr.testing.assert_identical(dr_out[lat_name], ds_out2[lat_name])
xr.testing.assert_identical(dr_out[lon_name], ds_out2[lon_name])

# test broadcasting
dr_out_4D = regridder(ds_in2['data4D'])
Expand Down

0 comments on commit 2ea833b

Please sign in to comment.