diff --git a/pyresample/gradient/__init__.py b/pyresample/gradient/__init__.py index e60be439..28448b0d 100644 --- a/pyresample/gradient/__init__.py +++ b/pyresample/gradient/__init__.py @@ -53,12 +53,28 @@ def GradientSearchResampler(source_geo_def, target_geo_def): def create_gradient_search_resampler(source_geo_def, target_geo_def): """Create a gradient search resampler.""" - if ((isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition)) or - (isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition))): + if (is_area_to_area(source_geo_def, target_geo_def) or + is_swath_to_area(source_geo_def, target_geo_def) or + is_area_to_swath(source_geo_def, target_geo_def)): return ResampleBlocksGradientSearchResampler(source_geo_def, target_geo_def) raise NotImplementedError +def is_area_to_area(source_geo_def, target_geo_def): + """Check if source is area and target is area.""" + return isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition) + + +def is_swath_to_area(source_geo_def, target_geo_def): + """Check if source is swath and target is area.""" + return isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition) + + +def is_area_to_swath(source_geo_def, target_geo_def): + """Check if source is area and targed is swath.""" + return isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, SwathDefinition) + + def _gradient_resample_data(src_data, src_x, src_y, src_gradient_xl, src_gradient_xp, src_gradient_yl, src_gradient_yp, @@ -323,13 +339,18 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar def _get_coordinates_in_same_projection(source_area, target_area): try: src_x, src_y = source_area.get_proj_coords() - except AttributeError as err: + work_crs = source_area.crs + except AttributeError: + # source is a swath definition, use target crs instead lons, lats = source_area.get_lonlats() src_x, src_y = da.compute(lons, lats) - transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True) + trans = pyproj.Transformer.from_crs(source_area.crs, target_area.crs, always_xy=True) + src_x, src_y = trans.transform(src_x, src_y) + work_crs = target_area.crs + transformer = pyproj.Transformer.from_crs(target_area.crs, work_crs, always_xy=True) try: dst_x, dst_y = transformer.transform(*target_area.get_proj_coords()) - except AttributeError as err: + except AttributeError: # target is a swath definition lons, lats = target_area.get_lonlats() dst_x, dst_y = transformer.transform(*da.compute(lons, lats)) @@ -345,6 +366,9 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info= weight_l, l_start = np.modf(y_indices.clip(0, data.shape[-2] - 1)) weight_p, p_start = np.modf(x_indices.clip(0, data.shape[-1] - 1)) + weight_l = weight_l.astype(data.dtype) + weight_p = weight_p.astype(data.dtype) + l_start = l_start.astype(int) p_start = p_start.astype(int) l_end = np.clip(l_start + 1, 1, data.shape[-2] - 1) @@ -353,7 +377,7 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info= res = ((1 - weight_l) * (1 - weight_p) * data[..., l_start, p_start] + (1 - weight_l) * weight_p * data[..., l_start, p_end] + weight_l * (1 - weight_p) * data[..., l_end, p_start] + - weight_l * weight_p * data[..., l_end, p_end]).astype(data.dtype) + weight_l * weight_p * data[..., l_end, p_end]) res = np.where(mask, fill_value, res) return res diff --git a/pyresample/slicer.py b/pyresample/slicer.py index dceb6976..0f6bcc52 100644 --- a/pyresample/slicer.py +++ b/pyresample/slicer.py @@ -67,11 +67,13 @@ class Slicer(ABC): """ - def __init__(self, area_to_crop, area_to_contain): + def __init__(self, area_to_crop, area_to_contain, work_crs): """Set up the Slicer.""" self.area_to_crop = area_to_crop self.area_to_contain = area_to_contain - self._transformer = Transformer.from_crs(self.area_to_contain.crs, self.area_to_crop.crs, always_xy=True) + + self._source_transformer = Transformer.from_crs(self.area_to_contain.crs, work_crs, always_xy=True) + self._target_transformer = Transformer.from_crs(self.area_to_crop.crs, work_crs, always_xy=True) def get_slices(self): """Get the slices to crop *area_to_crop* enclosing *area_to_contain*.""" @@ -92,17 +94,23 @@ def get_slices_from_polygon(self, poly): class SwathSlicer(Slicer): """A Slicer for cropping SwathDefinitions.""" + def __init__(self, area_to_crop, area_to_contain, work_crs=None): + """Set up the Slicer.""" + if work_crs is None: + work_crs = area_to_contain.crs + super().__init__(area_to_crop, area_to_contain, work_crs) + def get_polygon_to_contain(self): """Get the shapely Polygon corresponding to *area_to_contain* in lon/lat coordinates.""" from shapely.geometry import Polygon x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(10) - poly = Polygon(zip(*self._transformer.transform(x, y))) + poly = Polygon(zip(*self._source_transformer.transform(x, y))) return poly def get_slices_from_polygon(self, poly): """Get the slices based on the polygon.""" intersecting_chunk_slices = [] - for smaller_poly, slices in _get_chunk_polygons_for_swath_to_crop(self.area_to_crop): + for smaller_poly, slices in self._get_chunk_polygons_for_swath_to_crop(self.area_to_crop): if smaller_poly.intersects(poly): intersecting_chunk_slices.append(slices) if not intersecting_chunk_slices: @@ -118,12 +126,18 @@ def _assemble_slices(chunk_slices): slices = col_slice, line_slice return slices + def _get_chunk_polygons_for_swath_to_crop(self, swath_to_crop): + """Get the polygons for each chunk of the area_to_crop.""" + from shapely.geometry import Polygon + for ((lons, lats), (line_slice, col_slice)) in _get_chunk_bboxes_for_swath_to_crop(swath_to_crop): + smaller_poly = Polygon(zip(*self._target_transformer.transform(lons, lats))) + yield (smaller_poly, (line_slice, col_slice)) + @lru_cache(maxsize=10) -def _get_chunk_polygons_for_swath_to_crop(swath_to_crop): - """Get the polygons for each chunk of the area_to_crop.""" +def _get_chunk_bboxes_for_swath_to_crop(swath_to_crop): + """Get the lon/lat bouding boxes for each chunk of the area_to_crop.""" res = [] - from shapely.geometry import Polygon src_chunks = swath_to_crop.lons.chunks for _position, (line_slice, col_slice) in _enumerate_chunk_slices(src_chunks): line_slice = expand_slice(line_slice) @@ -132,8 +146,7 @@ def _get_chunk_polygons_for_swath_to_crop(swath_to_crop): lons, lats = smaller_swath.get_edge_lonlats(10) lons = np.hstack(lons) lats = np.hstack(lats) - smaller_poly = Polygon(zip(lons, lats)) - res.append((smaller_poly, (line_slice, col_slice))) + res.append(((lons, lats), (line_slice, col_slice))) return res @@ -145,6 +158,11 @@ def expand_slice(small_slice): class AreaSlicer(Slicer): """A Slicer for cropping AreaDefinitions.""" + def __init__(self, area_to_crop, area_to_contain): + """Set up the Slicer.""" + work_crs = area_to_crop.crs + super().__init__(area_to_crop, area_to_contain, work_crs) + def get_polygon_to_contain(self): """Get the shapely Polygon corresponding to *area_to_contain* in projection coordinates of *area_to_crop*.""" from shapely.geometry import Polygon @@ -154,7 +172,7 @@ def get_polygon_to_contain(self): x, y = self.area_to_contain.get_edge_lonlats(vertices_per_side=10) if self.area_to_crop.is_geostationary: x_geos, y_geos = get_geostationary_bounding_box_in_proj_coords(self.area_to_crop, 360) - x_geos, y_geos = self._transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE) + x_geos, y_geos = self._source_transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE) geos_poly = Polygon(zip(x_geos, y_geos)) poly = Polygon(zip(x, y)) poly = poly.intersection(geos_poly) @@ -162,7 +180,7 @@ def get_polygon_to_contain(self): raise IncompatibleAreas('No slice on area.') x, y = zip(*poly.exterior.coords) - return Polygon(zip(*self._transformer.transform(x, y))) + return Polygon(zip(*self._source_transformer.transform(x, y))) def get_slices_from_polygon(self, poly_to_contain): """Get the slices based on the polygon.""" diff --git a/pyresample/test/test_gradient.py b/pyresample/test/test_gradient.py index 9706d713..5d93b17b 100644 --- a/pyresample/test/test_gradient.py +++ b/pyresample/test/test_gradient.py @@ -32,7 +32,7 @@ from pyresample.area_config import create_area_def from pyresample.geometry import AreaDefinition, SwathDefinition -from pyresample.gradient import ResampleBlocksGradientSearchResampler +from pyresample.gradient import ResampleBlocksGradientSearchResampler, create_gradient_search_resampler class TestRBGradientSearchResamplerArea2Area: @@ -343,7 +343,7 @@ def setup_method(self): @pytest.mark.parametrize("input_dtype", (np.float32, np.float64)) def test_resample_area_to_swath_2d(self, input_dtype): """Resample swath to area, 2d.""" - swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask) + swath_resampler = create_gradient_search_resampler(self.src_area, self.dst_swath_dask) data = xr.DataArray(da.ones(self.src_area.shape, dtype=input_dtype), dims=['y', 'x']) diff --git a/pyresample/test/test_slicer.py b/pyresample/test/test_slicer.py index af082f89..5868c363 100644 --- a/pyresample/test/test_slicer.py +++ b/pyresample/test/test_slicer.py @@ -223,15 +223,12 @@ def setUp(self): (-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457) ) - self.lons, self.lats = self.src_area.get_lonlats(chunks=chunks) - xrlons = xr.DataArray(self.lons.persist()) - xrlats = xr.DataArray(self.lats.persist()) - self.src_swath = SwathDefinition(xrlons, xrlats) + self.src_swath = swath_from_area(self.src_area, chunks) def test_slicer_init(self): """Test slicer initialization.""" slicer = create_slicer(self.src_swath, self.dst_area) - assert slicer.area_to_crop == self.src_area + assert slicer.area_to_crop == self.src_swath assert slicer.area_to_contain == self.dst_area def test_source_swath_slicing_does_not_return_full_dataset(self): @@ -246,17 +243,61 @@ def test_source_swath_slicing_does_not_return_full_dataset(self): def test_source_area_slicing_does_not_return_full_dataset(self): """Test source area covers dest area.""" - slicer = create_slicer(self.src_area, self.dst_area) + slicer = create_slicer(self.src_swath, self.dst_area) x_slice, y_slice = slicer.get_slices() assert x_slice.start == 0 - assert x_slice.stop == 35 - assert y_slice.start == 16 - assert y_slice.stop == 94 + assert x_slice.stop == 41 + assert y_slice.start == 9 + assert y_slice.stop == 91 + + def test_source_area_slicing_over_date_line(self): + src_area = AreaDefinition( + 'omerc_otf', + 'On-the-fly omerc area', + None, + {'alpha': '8.99811271718795', + 'ellps': 'sphere', + 'gamma': '0', + 'k': '1', + 'lat_0': '0', + 'lonc': '179.8096029486222', + 'proj': 'omerc', + 'units': 'm'}, + 50, 100, + (-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457) + ) + chunks = 10 + src_swath = swath_from_area(src_area, chunks) + + dst_area = AreaDefinition('somewhere in the pacific', 'somewhere', None, + {'proj': 'stere', 'lon_0': 180.0, + 'lat_0': 90.0, 'lat_ts': 60.0, + 'ellps': 'bessel'}, + 102, 102, + (-2717181.7304994687, -5571048.14031214, + 1378818.2695005313, -1475048.1403121399)) + + slicer = create_slicer(src_swath, dst_area) + x_slice, y_slice = slicer.get_slices() + assert x_slice.start == 0 + assert x_slice.stop == 41 + assert y_slice.start == 9 + assert y_slice.stop == 91 + + def test_source_area_slicing_with_custom_work_crs(self): + """Test source area covers dest area.""" + from pyresample.slicer import SwathSlicer + slicer = SwathSlicer(self.src_swath, self.dst_area, work_crs=self.src_area.crs) + x_slice, y_slice = slicer.get_slices() + assert x_slice.start == 0 + assert x_slice.stop == 41 + assert y_slice.start == 9 + assert y_slice.stop == 91 def test_area_get_polygon_returns_a_polygon(self): """Test getting a polygon returns a polygon.""" from shapely.geometry import Polygon - slicer = create_slicer(self.src_area, self.dst_area) + slicer = create_slicer(self.src_swath, self.dst_area) poly = slicer.get_polygon_to_contain() assert isinstance(poly, Polygon) @@ -271,3 +312,11 @@ def test_cannot_slice_a_string(self): """Test that we cannot slice a string.""" with pytest.raises(NotImplementedError): create_slicer("my_funky_area", self.dst_area) + + +def swath_from_area(src_area, chunks): + """Create a SwathDefinition from an AreaDefinition.""" + lons, lats = src_area.get_lonlats(chunks=chunks) + xrlons = xr.DataArray(lons.persist()) + xrlats = xr.DataArray(lats.persist()) + return SwathDefinition(xrlons, xrlats)