Skip to content

Commit

Permalink
Refactor and test more
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed Oct 24, 2024
1 parent c174f88 commit 4dd2948
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 29 deletions.
36 changes: 30 additions & 6 deletions pyresample/gradient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,28 @@ def GradientSearchResampler(source_geo_def, target_geo_def):

def create_gradient_search_resampler(source_geo_def, target_geo_def):
"""Create a gradient search resampler."""
if ((isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition)) or
(isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition))):
if (is_area_to_area(source_geo_def, target_geo_def) or
is_swath_to_area(source_geo_def, target_geo_def) or
is_area_to_swath(source_geo_def, target_geo_def)):
return ResampleBlocksGradientSearchResampler(source_geo_def, target_geo_def)
raise NotImplementedError


def is_area_to_area(source_geo_def, target_geo_def):
"""Check if source is area and target is area."""
return isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition)


def is_swath_to_area(source_geo_def, target_geo_def):
"""Check if source is swath and target is area."""
return isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition)


def is_area_to_swath(source_geo_def, target_geo_def):
"""Check if source is area and targed is swath."""
return isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, SwathDefinition)


def _gradient_resample_data(src_data, src_x, src_y,
src_gradient_xl, src_gradient_xp,
src_gradient_yl, src_gradient_yp,
Expand Down Expand Up @@ -323,13 +339,18 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar
def _get_coordinates_in_same_projection(source_area, target_area):
try:
src_x, src_y = source_area.get_proj_coords()
except AttributeError as err:
work_crs = source_area.crs
except AttributeError:
# source is a swath definition, use target crs instead
lons, lats = source_area.get_lonlats()
src_x, src_y = da.compute(lons, lats)
transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True)
trans = pyproj.Transformer.from_crs(source_area.crs, target_area.crs, always_xy=True)
src_x, src_y = trans.transform(src_x, src_y)
work_crs = target_area.crs
transformer = pyproj.Transformer.from_crs(target_area.crs, work_crs, always_xy=True)
try:
dst_x, dst_y = transformer.transform(*target_area.get_proj_coords())
except AttributeError as err:
except AttributeError:
# target is a swath definition
lons, lats = target_area.get_lonlats()
dst_x, dst_y = transformer.transform(*da.compute(lons, lats))
Expand All @@ -345,6 +366,9 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info=
weight_l, l_start = np.modf(y_indices.clip(0, data.shape[-2] - 1))
weight_p, p_start = np.modf(x_indices.clip(0, data.shape[-1] - 1))

weight_l = weight_l.astype(data.dtype)
weight_p = weight_p.astype(data.dtype)

l_start = l_start.astype(int)
p_start = p_start.astype(int)
l_end = np.clip(l_start + 1, 1, data.shape[-2] - 1)
Expand All @@ -353,7 +377,7 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info=
res = ((1 - weight_l) * (1 - weight_p) * data[..., l_start, p_start] +
(1 - weight_l) * weight_p * data[..., l_start, p_end] +
weight_l * (1 - weight_p) * data[..., l_end, p_start] +
weight_l * weight_p * data[..., l_end, p_end]).astype(data.dtype)
weight_l * weight_p * data[..., l_end, p_end])
res = np.where(mask, fill_value, res)
return res

Expand Down
40 changes: 29 additions & 11 deletions pyresample/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ class Slicer(ABC):
"""

def __init__(self, area_to_crop, area_to_contain):
def __init__(self, area_to_crop, area_to_contain, work_crs):
"""Set up the Slicer."""
self.area_to_crop = area_to_crop
self.area_to_contain = area_to_contain
self._transformer = Transformer.from_crs(self.area_to_contain.crs, self.area_to_crop.crs, always_xy=True)

self._source_transformer = Transformer.from_crs(self.area_to_contain.crs, work_crs, always_xy=True)
self._target_transformer = Transformer.from_crs(self.area_to_crop.crs, work_crs, always_xy=True)

def get_slices(self):
"""Get the slices to crop *area_to_crop* enclosing *area_to_contain*."""
Expand All @@ -92,17 +94,23 @@ def get_slices_from_polygon(self, poly):
class SwathSlicer(Slicer):
"""A Slicer for cropping SwathDefinitions."""

def __init__(self, area_to_crop, area_to_contain, work_crs=None):
"""Set up the Slicer."""
if work_crs is None:
work_crs = area_to_contain.crs
super().__init__(area_to_crop, area_to_contain, work_crs)

def get_polygon_to_contain(self):
"""Get the shapely Polygon corresponding to *area_to_contain* in lon/lat coordinates."""
from shapely.geometry import Polygon
x, y = self.area_to_contain.get_edge_bbox_in_projection_coordinates(10)
poly = Polygon(zip(*self._transformer.transform(x, y)))
poly = Polygon(zip(*self._source_transformer.transform(x, y)))
return poly

def get_slices_from_polygon(self, poly):
"""Get the slices based on the polygon."""
intersecting_chunk_slices = []
for smaller_poly, slices in _get_chunk_polygons_for_swath_to_crop(self.area_to_crop):
for smaller_poly, slices in self._get_chunk_polygons_for_swath_to_crop(self.area_to_crop):
if smaller_poly.intersects(poly):
intersecting_chunk_slices.append(slices)
if not intersecting_chunk_slices:
Expand All @@ -118,12 +126,18 @@ def _assemble_slices(chunk_slices):
slices = col_slice, line_slice
return slices

def _get_chunk_polygons_for_swath_to_crop(self, swath_to_crop):
"""Get the polygons for each chunk of the area_to_crop."""
from shapely.geometry import Polygon
for ((lons, lats), (line_slice, col_slice)) in _get_chunk_bboxes_for_swath_to_crop(swath_to_crop):
smaller_poly = Polygon(zip(*self._target_transformer.transform(lons, lats)))
yield (smaller_poly, (line_slice, col_slice))


@lru_cache(maxsize=10)
def _get_chunk_polygons_for_swath_to_crop(swath_to_crop):
"""Get the polygons for each chunk of the area_to_crop."""
def _get_chunk_bboxes_for_swath_to_crop(swath_to_crop):
"""Get the lon/lat bouding boxes for each chunk of the area_to_crop."""
res = []
from shapely.geometry import Polygon
src_chunks = swath_to_crop.lons.chunks
for _position, (line_slice, col_slice) in _enumerate_chunk_slices(src_chunks):
line_slice = expand_slice(line_slice)
Expand All @@ -132,8 +146,7 @@ def _get_chunk_polygons_for_swath_to_crop(swath_to_crop):
lons, lats = smaller_swath.get_edge_lonlats(10)
lons = np.hstack(lons)
lats = np.hstack(lats)
smaller_poly = Polygon(zip(lons, lats))
res.append((smaller_poly, (line_slice, col_slice)))
res.append(((lons, lats), (line_slice, col_slice)))
return res


Expand All @@ -145,6 +158,11 @@ def expand_slice(small_slice):
class AreaSlicer(Slicer):
"""A Slicer for cropping AreaDefinitions."""

def __init__(self, area_to_crop, area_to_contain):
"""Set up the Slicer."""
work_crs = area_to_crop.crs
super().__init__(area_to_crop, area_to_contain, work_crs)

def get_polygon_to_contain(self):
"""Get the shapely Polygon corresponding to *area_to_contain* in projection coordinates of *area_to_crop*."""
from shapely.geometry import Polygon
Expand All @@ -154,15 +172,15 @@ def get_polygon_to_contain(self):
x, y = self.area_to_contain.get_edge_lonlats(vertices_per_side=10)
if self.area_to_crop.is_geostationary:
x_geos, y_geos = get_geostationary_bounding_box_in_proj_coords(self.area_to_crop, 360)
x_geos, y_geos = self._transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE)
x_geos, y_geos = self._source_transformer.transform(x_geos, y_geos, direction=TransformDirection.INVERSE)
geos_poly = Polygon(zip(x_geos, y_geos))
poly = Polygon(zip(x, y))
poly = poly.intersection(geos_poly)
if poly.is_empty:
raise IncompatibleAreas('No slice on area.')
x, y = zip(*poly.exterior.coords)

return Polygon(zip(*self._transformer.transform(x, y)))
return Polygon(zip(*self._source_transformer.transform(x, y)))

def get_slices_from_polygon(self, poly_to_contain):
"""Get the slices based on the polygon."""
Expand Down
4 changes: 2 additions & 2 deletions pyresample/test/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from pyresample.area_config import create_area_def
from pyresample.geometry import AreaDefinition, SwathDefinition
from pyresample.gradient import ResampleBlocksGradientSearchResampler
from pyresample.gradient import ResampleBlocksGradientSearchResampler, create_gradient_search_resampler


class TestRBGradientSearchResamplerArea2Area:
Expand Down Expand Up @@ -343,7 +343,7 @@ def setup_method(self):
@pytest.mark.parametrize("input_dtype", (np.float32, np.float64))
def test_resample_area_to_swath_2d(self, input_dtype):
"""Resample swath to area, 2d."""
swath_resampler = ResampleBlocksGradientSearchResampler(self.src_area, self.dst_swath_dask)
swath_resampler = create_gradient_search_resampler(self.src_area, self.dst_swath_dask)

data = xr.DataArray(da.ones(self.src_area.shape, dtype=input_dtype),
dims=['y', 'x'])
Expand Down
69 changes: 59 additions & 10 deletions pyresample/test/test_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,12 @@ def setUp(self):
(-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457)
)

self.lons, self.lats = self.src_area.get_lonlats(chunks=chunks)
xrlons = xr.DataArray(self.lons.persist())
xrlats = xr.DataArray(self.lats.persist())
self.src_swath = SwathDefinition(xrlons, xrlats)
self.src_swath = swath_from_area(self.src_area, chunks)

def test_slicer_init(self):
"""Test slicer initialization."""
slicer = create_slicer(self.src_swath, self.dst_area)
assert slicer.area_to_crop == self.src_area
assert slicer.area_to_crop == self.src_swath
assert slicer.area_to_contain == self.dst_area

def test_source_swath_slicing_does_not_return_full_dataset(self):
Expand All @@ -246,17 +243,61 @@ def test_source_swath_slicing_does_not_return_full_dataset(self):

def test_source_area_slicing_does_not_return_full_dataset(self):
"""Test source area covers dest area."""
slicer = create_slicer(self.src_area, self.dst_area)
slicer = create_slicer(self.src_swath, self.dst_area)
x_slice, y_slice = slicer.get_slices()
assert x_slice.start == 0
assert x_slice.stop == 35
assert y_slice.start == 16
assert y_slice.stop == 94
assert x_slice.stop == 41
assert y_slice.start == 9
assert y_slice.stop == 91

def test_source_area_slicing_over_date_line(self):
src_area = AreaDefinition(
'omerc_otf',
'On-the-fly omerc area',
None,
{'alpha': '8.99811271718795',
'ellps': 'sphere',
'gamma': '0',
'k': '1',
'lat_0': '0',
'lonc': '179.8096029486222',
'proj': 'omerc',
'units': 'm'},
50, 100,
(-1461111.3603, 3440088.0459, 1534864.0322, 9598335.0457)
)
chunks = 10
src_swath = swath_from_area(src_area, chunks)

dst_area = AreaDefinition('somewhere in the pacific', 'somewhere', None,
{'proj': 'stere', 'lon_0': 180.0,
'lat_0': 90.0, 'lat_ts': 60.0,
'ellps': 'bessel'},
102, 102,
(-2717181.7304994687, -5571048.14031214,
1378818.2695005313, -1475048.1403121399))

slicer = create_slicer(src_swath, dst_area)
x_slice, y_slice = slicer.get_slices()
assert x_slice.start == 0
assert x_slice.stop == 41
assert y_slice.start == 9
assert y_slice.stop == 91

def test_source_area_slicing_with_custom_work_crs(self):
"""Test source area covers dest area."""
from pyresample.slicer import SwathSlicer
slicer = SwathSlicer(self.src_swath, self.dst_area, work_crs=self.src_area.crs)
x_slice, y_slice = slicer.get_slices()
assert x_slice.start == 0
assert x_slice.stop == 41
assert y_slice.start == 9
assert y_slice.stop == 91

def test_area_get_polygon_returns_a_polygon(self):
"""Test getting a polygon returns a polygon."""
from shapely.geometry import Polygon
slicer = create_slicer(self.src_area, self.dst_area)
slicer = create_slicer(self.src_swath, self.dst_area)
poly = slicer.get_polygon_to_contain()
assert isinstance(poly, Polygon)

Expand All @@ -271,3 +312,11 @@ def test_cannot_slice_a_string(self):
"""Test that we cannot slice a string."""
with pytest.raises(NotImplementedError):
create_slicer("my_funky_area", self.dst_area)


def swath_from_area(src_area, chunks):
"""Create a SwathDefinition from an AreaDefinition."""
lons, lats = src_area.get_lonlats(chunks=chunks)
xrlons = xr.DataArray(lons.persist())
xrlats = xr.DataArray(lats.persist())
return SwathDefinition(xrlons, xrlats)

0 comments on commit 4dd2948

Please sign in to comment.