Skip to content

Commit

Permalink
ENH: Add support for passing in gcps to rio.reproject
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Jul 6, 2021
1 parent 9de23ef commit 50ad9ad
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 23 deletions.
5 changes: 3 additions & 2 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ History

Latest
------
- ENH: Allow passing in kwargs to reproject & overide nodata. Provide
default nodata based on dtype. (issue ##369)
- ENH: Allow passing in kwargs to `rio.reproject` (issue #369)
- ENH: Allow nodata override and provide default nodata based on dtype in `rio.reproject`.
- ENH: Add support for passing in gcps to rio.reproject (issue #339)
- BUG: Remove duplicate acquire in open_rasterio (pull #364)

0.4.3
Expand Down
29 changes: 14 additions & 15 deletions rioxarray/raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,33 +106,32 @@ def _add_attrs_proj(new_data_array, src_data_array):


def _make_dst_affine(
src_data_array, src_crs, dst_crs, dst_resolution=None, dst_shape=None
src_data_array, src_crs, dst_crs, dst_resolution=None, dst_shape=None, **kwargs
):
"""Determine the affine of the new projected `xarray.DataArray`"""
src_bounds = src_data_array.rio.bounds()
src_bounds = () if "gcps" in kwargs else src_data_array.rio.bounds()
src_height, src_width = src_data_array.rio.shape
dst_height, dst_width = dst_shape if dst_shape is not None else (None, None)
# pylint: disable=isinstance-second-argument-not-valid-type
if isinstance(dst_resolution, Iterable):
dst_resolution = tuple(abs(res_val) for res_val in dst_resolution)
elif dst_resolution is not None:
dst_resolution = abs(dst_resolution)
resolution_or_width_height = {
k: v
for k, v in [
("resolution", dst_resolution),
("dst_height", dst_height),
("dst_width", dst_width),
]
if v is not None
}

for key, value in (
("resolution", dst_resolution),
("dst_height", dst_height),
("dst_width", dst_width),
):
if key is not None:
kwargs[key] = value
dst_affine, dst_width, dst_height = rasterio.warp.calculate_default_transform(
src_crs,
dst_crs,
src_width,
src_height,
*src_bounds,
**resolution_or_width_height,
**kwargs,
)
return dst_affine, dst_width, dst_height

Expand Down Expand Up @@ -337,8 +336,8 @@ def reproject(
resolution=None,
shape=None,
transform=None,
nodata=None,
resampling=Resampling.nearest,
nodata=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -394,10 +393,10 @@ def reproject(
"CRS not found. Please set the CRS with 'rio.write_crs()'."
f"{_get_data_var_message(self._obj)}"
)
src_affine = self.transform(recalc=True)
src_affine = None if "gcps" in kwargs else self.transform(recalc=True)
if transform is None:
dst_affine, dst_width, dst_height = _make_dst_affine(
self._obj, self.crs, dst_crs, resolution, shape
self._obj, self.crs, dst_crs, resolution, shape, **kwargs
)
else:
dst_affine = transform
Expand Down
17 changes: 16 additions & 1 deletion rioxarray/raster_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def reproject(
shape=None,
transform=None,
resampling=Resampling.nearest,
nodata=None,
**kwargs,
):
"""
Reproject :class:`xarray.Dataset` objects
Expand All @@ -69,6 +71,7 @@ def reproject(
.. versionadded:: 0.0.27 shape
.. versionadded:: 0.0.28 transform
.. versionadded:: 0.5.0 nodata, kwargs
Parameters
----------
Expand All @@ -84,7 +87,17 @@ def reproject(
The destination transform.
resampling: rasterio.enums.Resampling, optional
See :func:`rasterio.warp.reproject` for more details.
nodata: float, optional
The nodata value used to initialize the destination;
it will remain in all areas not covered by the reprojected source.
Defaults to the nodata value of the source image if none provided
and exists or attempts to find an appropriate value by dtype.
**kwargs: dict
Additional keyword arguments to pass into :func:`rasterio.warp.reproject`.
To override:
- src_transform: `rio.write_transform`
- src_crs: `rio.write_crs`
- src_nodata: `rio.write_nodata`
Returns
--------
Expand All @@ -102,6 +115,8 @@ def reproject(
shape=shape,
transform=transform,
resampling=resampling,
nodata=nodata,
**kwargs,
)
)
return resampled_dataset
Expand Down
54 changes: 49 additions & 5 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dask.delayed import Delayed
from numpy.testing import assert_almost_equal, assert_array_equal
from pyproj import CRS as pCRS
from rasterio.control import GroundControlPoint
from rasterio.crs import CRS
from rasterio.windows import Window

Expand Down Expand Up @@ -699,7 +700,8 @@ def test_reproject__no_transform(modis_reproject):
_assert_xarrays_equal(mds_repr, mdc)


def test_reproject__no_nodata(modis_reproject):
@pytest.mark.parametrize("nodata", [None, -9999])
def test_reproject__no_nodata(nodata, modis_reproject):
mask_args = (
dict(masked=False, mask_and_scale=False)
if "rasterio" in str(modis_reproject["open"])
Expand All @@ -712,19 +714,20 @@ def test_reproject__no_nodata(modis_reproject):
_del_attr(mda, "_FillValue")
_del_attr(mda, "nodata")
# reproject
mds_repr = mda.rio.reproject(modis_reproject["to_proj"])
mds_repr = mda.rio.reproject(modis_reproject["to_proj"], nodata=nodata)

# overwrite test dataset
# if isinstance(modis_reproject['open'], xarray.DataArray):
# mds_repr.to_netcdf(modis_reproject['compare'])

# replace -9999 with original _FillValue for testing
fill_nodata = -32768 if nodata is None else nodata
if hasattr(mds_repr, "variables"):
for var in mds_repr.rio.vars:
mds_repr[var].values[mds_repr[var].values == -32768] = orig_fill
mds_repr[var].values[mds_repr[var].values == fill_nodata] = orig_fill
else:
mds_repr.values[mds_repr.values == -32768] = orig_fill
_mod_attr(mdc, "_FillValue", val=-32768)
mds_repr.values[mds_repr.values == fill_nodata] = orig_fill
_mod_attr(mdc, "_FillValue", val=fill_nodata)
# test
_assert_xarrays_equal(mds_repr, mdc)

Expand All @@ -750,6 +753,47 @@ def test_reproject__no_nodata_masked(modis_reproject):
_assert_xarrays_equal(mds_repr, mdc)


def test_reproject__gcps_kwargs(tmp_path):
tiffname = tmp_path / "test.tif"
src_gcps = [
GroundControlPoint(row=0, col=0, x=156113, y=2818720, z=0),
GroundControlPoint(row=0, col=800, x=338353, y=2785790, z=0),
GroundControlPoint(row=800, col=800, x=297939, y=2618518, z=0),
GroundControlPoint(row=800, col=0, x=115698, y=2651448, z=0),
]
crs = CRS.from_epsg(32618)
with rasterio.open(
tiffname,
mode="w",
height=800,
width=800,
count=3,
dtype=numpy.uint8,
driver="GTiff",
) as source:
source.gcps = (src_gcps, crs)

rds = rioxarray.open_rasterio(tiffname)
rds.rio.write_crs(crs, inplace=True)
rds = rds.rio.reproject(
crs,
gcps=src_gcps,
)
assert rds.rio.height == 923
assert rds.rio.width == 1027
assert rds.rio.crs == crs
assert rds.rio.transform().almost_equals(
Affine(
216.8587081056465,
0.0,
115698.25,
0.0,
-216.8587081056465,
2818720.0,
)
)


def test_reproject_match(modis_reproject_match):
mask_args = (
dict(masked=False, mask_and_scale=False)
Expand Down

0 comments on commit 50ad9ad

Please sign in to comment.