diff --git a/pyresample/gradient/__init__.py b/pyresample/gradient/__init__.py index c8750108..e60be439 100644 --- a/pyresample/gradient/__init__.py +++ b/pyresample/gradient/__init__.py @@ -186,7 +186,10 @@ def _concatenate_chunks(chunks): def _fill_in_coords(target_geo_def, data_coords, data_dims): - x_coord, y_coord = target_geo_def.get_proj_vectors() + try: + x_coord, y_coord = target_geo_def.get_proj_vectors() + except AttributeError: + return None coords = [] for key in data_dims: if key == 'x': @@ -219,8 +222,6 @@ class ResampleBlocksGradientSearchResampler(BaseResampler): def __init__(self, source_geo_def, target_geo_def): """Init GradientResampler.""" - if isinstance(target_geo_def, SwathDefinition): - raise NotImplementedError("Cannot resample to a SwathDefinition.") if isinstance(source_geo_def, SwathDefinition): source_geo_def.lons = source_geo_def.lons.persist() source_geo_def.lats = source_geo_def.lats.persist() @@ -325,11 +326,13 @@ def _get_coordinates_in_same_projection(source_area, target_area): except AttributeError as err: lons, lats = source_area.get_lonlats() src_x, src_y = da.compute(lons, lats) + transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True) try: - transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True) dst_x, dst_y = transformer.transform(*target_area.get_proj_coords()) except AttributeError as err: - raise NotImplementedError("Cannot resample to Swath for now.") from err + # target is a swath definition + lons, lats = target_area.get_lonlats() + dst_x, dst_y = transformer.transform(*da.compute(lons, lats)) src_gradient_xl, src_gradient_xp = np.gradient(src_x, axis=[0, 1]) src_gradient_yl, src_gradient_yp = np.gradient(src_y, axis=[0, 1]) return (dst_x, dst_y), (src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp), (src_x, src_y) diff --git a/pyresample/slicer.py b/pyresample/slicer.py index 579fd58a..dceb6976 100644 --- a/pyresample/slicer.py +++ b/pyresample/slicer.py @@ -148,7 +148,10 @@ class AreaSlicer(Slicer): def get_polygon_to_contain(self): """Get the shapely Polygon corresponding to *area_to_contain* in projection coordinates of *area_to_crop*.""" from shapely.geometry import Polygon - x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(frequency=10) + try: + x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(frequency=10) + except AttributeError: + x, y = self.area_to_contain.get_edge_lonlats(vertices_per_side=10) if self.area_to_crop.is_geostationary: x_geos, y_geos = get_geostationary_bounding_box_in_proj_coords(self.area_to_crop, 360) x_geos, y_geos = self._transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE) diff --git a/pyresample/test/test_gradient.py b/pyresample/test/test_gradient.py index c911f297..9706d713 100644 --- a/pyresample/test/test_gradient.py +++ b/pyresample/test/test_gradient.py @@ -325,7 +325,12 @@ class TestRBGradientSearchResamplerArea2Swath: def setup_method(self): """Set up the test case.""" - chunks = 20 + lons, lats = np.meshgrid(np.linspace(0, 20, 100), np.linspace(45, 66, 100)) + self.dst_swath = SwathDefinition(lons, lats, crs="WGS84") + lons, lats = self.dst_swath.get_lonlats(chunks=10) + lons = xr.DataArray(lons, dims=["y", "x"]) + lats = xr.DataArray(lats, dims=["y", "x"]) + self.dst_swath_dask = SwathDefinition(lons, lats) self.src_area = AreaDefinition('euro40', 'euro40', None, {'proj': 'stere', 'lon_0': 14.0, @@ -335,34 +340,49 @@ def setup_method(self): (-2717181.7304994687, -5571048.14031214, 1378818.2695005313, -1475048.1403121399)) - self.dst_area = AreaDefinition( - 'omerc_otf', - 'On-the-fly omerc area', - None, - {'alpha': '8.99811271718795', - 'ellps': 'sphere', - 'gamma': '0', - 'k': '1', - 'lat_0': '0', - 'lonc': '13.8096029486222', - 'proj': 'omerc', - 'units': 'm'}, - 50, 100, - (-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457) - ) - - self.lons, self.lats = self.dst_area.get_lonlats(chunks=chunks) - xrlons = xr.DataArray(self.lons.persist()) - xrlats = xr.DataArray(self.lats.persist()) - self.dst_swath = SwathDefinition(xrlons, xrlats) - - def test_resampling_to_swath_is_not_implemented(self): - """Test that resampling to swath is not working yet.""" - from pyresample.gradient import ResampleBlocksGradientSearchResampler - - with pytest.raises(NotImplementedError): - ResampleBlocksGradientSearchResampler(self.src_area, - self.dst_swath) + @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) + def test_resample_area_to_swath_2d(self, input_dtype): + """Resample swath to area, 2d.""" + swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask) + + data = xr.DataArray(da.ones(self.src_area.shape, dtype=input_dtype), + dims=['y', 'x']) + with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings + swath_resampler.precompute() + res_xr = swath_resampler.compute(data, method='bilinear') + res_np = res_xr.compute(scheduler='single-threaded') + + assert res_xr.dtype == data.dtype + assert res_np.dtype == data.dtype + assert res_xr.shape == self.dst_swath.shape + assert res_np.shape == self.dst_swath.shape + assert type(res_xr) is type(data) + assert type(res_xr.data) is type(data.data) + assert not np.all(np.isnan(res_np)) + + @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) + def test_resample_area_to_swath_3d(self, input_dtype): + """Resample area to area, 3d.""" + swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask) + + data = xr.DataArray(da.ones((3, ) + self.src_area.shape, + dtype=input_dtype) * + np.array([1, 2, 3])[:, np.newaxis, np.newaxis], + dims=['bands', 'y', 'x']) + with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings + swath_resampler.precompute() + res_xr = swath_resampler.compute(data, method='bilinear') + res_np = res_xr.compute(scheduler='single-threaded') + + assert res_xr.dtype == data.dtype + assert res_np.dtype == data.dtype + assert res_xr.shape == (3, ) + self.dst_swath.shape + assert res_np.shape == (3, ) + self.dst_swath.shape + assert type(res_xr) is type(data) + assert type(res_xr.data) is type(data.data) + for i in range(res_np.shape[0]): + arr = np.ravel(res_np[i, :, :]) + assert np.allclose(arr[np.isfinite(arr)], float(i + 1)) class TestEnsureDataArray(unittest.TestCase):