Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dtype for swath -> area resampling with gradient search #618

Merged
merged 13 commits into from
Sep 18, 2024
2 changes: 1 addition & 1 deletion pyresample/gradient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def parallel_gradient_search(data, src_x, src_y, dst_x, dst_y,
dst_x[i], dst_y[i],
method=method)
res = da.from_delayed(res, (num_bands, ) + dst_x[i].shape,
dtype=np.float64)
dtype=np.float64).astype(arr.dtype)
Copy link
Member

@mraspaud mraspaud Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think we can do better than this. What I see here (and a bit above) is that we convert the input data to float64, do the resampling, then convert back to float32.
Using cython fused types, we should be able to template the cython function to accept both float64, float32, or even ints. This cython doc section should help. I can give you a hand if you don't feel comfortable with this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm not going to learn Cython and its typing that deeply at this point 😅

if dst_mosaic_locations[i] in chunks:
if not is_pad:
chunks[dst_mosaic_locations[i]].append(res)
Expand Down
14 changes: 9 additions & 5 deletions pyresample/test/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,25 +249,29 @@ def test_resample_area_to_area_3d_single_channel(self):
assert res.shape == (1, ) + self.dst_area.shape
assert np.allclose(res[0, :, :], 1.0)

def test_resample_swath_to_area_2d(self):
@pytest.mark.parametrize("input_dtype", (np.float32, np.float64))
def test_resample_swath_to_area_2d(self, input_dtype):
"""Resample swath to area, 2d."""
data = xr.DataArray(da.ones(self.src_swath.shape, dtype=np.float64),
data = xr.DataArray(da.ones(self.src_swath.shape, dtype=input_dtype),
dims=['y', 'x'])
with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings
res = self.swath_resampler.compute(
data, method='bil').compute(scheduler='single-threaded')
assert res.dtype == data.dtype
pnuu marked this conversation as resolved.
Show resolved Hide resolved
assert res.shape == self.dst_area.shape
assert not np.all(np.isnan(res))

def test_resample_swath_to_area_3d(self):
@pytest.mark.parametrize("input_dtype", (np.float32, np.float64))
def test_resample_swath_to_area_3d(self, input_dtype):
"""Resample area to area, 3d."""
data = xr.DataArray(da.ones((3, ) + self.src_swath.shape,
dtype=np.float64) *
dtype=input_dtype) *
np.array([1, 2, 3])[:, np.newaxis, np.newaxis],
dims=['bands', 'y', 'x'])
with np.errstate(invalid="ignore"): # 'inf' space pixels cause runtime warnings
res = self.swath_resampler.compute(
data, method='bil').compute(scheduler='single-threaded')
assert res.dtype == data.dtype
pnuu marked this conversation as resolved.
Show resolved Hide resolved
assert res.shape == (3, ) + self.dst_area.shape
for i in range(res.shape[0]):
arr = np.ravel(res[i, :, :])
Expand Down Expand Up @@ -496,7 +500,7 @@ def test_resample_area_to_area_nn(self):


class TestRBGradientSearchResamplerArea2Swath:
"""Test RBGradientSearchResampler for the Swath to Area case."""
"""Test RBGradientSearchResampler for the Area to Swath case."""

def setup_method(self):
"""Set up the test case."""
Expand Down
Loading