Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: kwargs & gcps support for rio.reproject #370

Merged
merged 3 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
- name: Install Env
shell: bash
run: |
conda create -n test python=${{ matrix.python-version }} rasterio=${{ matrix.rasterio-version }} xarray=${{ matrix.xarray-version }} scipy pyproj netcdf4 dask pandoc
conda create -n test python=${{ matrix.python-version }} rasterio=${{ matrix.rasterio-version }} xarray=${{ matrix.xarray-version }} 'libgdal<3.3' scipy pyproj netcdf4 dask pandoc
source activate test
python -m pip install -e .[all]

Expand Down
5 changes: 4 additions & 1 deletion docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ History

Latest
------
- BUG: Remove duplicate acquire in open_rasterio
- ENH: Allow passing in kwargs to `rio.reproject` (issue #369; pull #370)
- ENH: Allow nodata override and provide default nodata based on dtype in `rio.reproject` (pull #370)
- ENH: Add support for passing in gcps to rio.reproject (issue #339; pull #370)
- BUG: Remove duplicate acquire in open_rasterio (pull #364)

0.4.3
------
Expand Down
76 changes: 55 additions & 21 deletions rioxarray/raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import rasterio.mask
import rasterio.warp
import xarray
from rasterio.dtypes import dtype_rev
from rasterio.enums import Resampling
from rasterio.features import geometry_mask
from scipy.interpolate import griddata
Expand All @@ -38,6 +39,24 @@
)
from rioxarray.rioxarray import XRasterBase, _get_data_var_message, _make_coords

# DTYPE TO NODATA MAP
# Based on: https://github.com/OSGeo/gdal/blob/
# cde27dc7641964a872efdc6bbcf5e3d3f7ab9cfd/gdal/
# swig/python/gdal-utils/osgeo_utils/gdal_calc.py#L62
_NODATA_DTYPE_MAP = {
1: 255, # GDT_Byte
2: 65535, # GDT_UInt16
3: -32768, # GDT_Int16
4: 4294967293, # GDT_UInt32
5: -2147483647, # GDT_Int32
6: 3.402823466e38, # GDT_Float32
7: 1.7976931348623158e308, # GDT_Float64
8: -32768, # GDT_CInt16
9: -2147483647, # GDT_CInt32
10: 3.402823466e38, # GDT_CFloat32
11: 1.7976931348623158e308, # GDT_CFloat64
}


def _generate_attrs(src_data_array, dst_nodata):
# add original attributes
Expand Down Expand Up @@ -87,33 +106,32 @@ def _add_attrs_proj(new_data_array, src_data_array):


def _make_dst_affine(
src_data_array, src_crs, dst_crs, dst_resolution=None, dst_shape=None
src_data_array, src_crs, dst_crs, dst_resolution=None, dst_shape=None, **kwargs
):
"""Determine the affine of the new projected `xarray.DataArray`"""
src_bounds = src_data_array.rio.bounds()
src_bounds = () if "gcps" in kwargs else src_data_array.rio.bounds()
src_height, src_width = src_data_array.rio.shape
dst_height, dst_width = dst_shape if dst_shape is not None else (None, None)
# pylint: disable=isinstance-second-argument-not-valid-type
if isinstance(dst_resolution, Iterable):
dst_resolution = tuple(abs(res_val) for res_val in dst_resolution)
elif dst_resolution is not None:
dst_resolution = abs(dst_resolution)
resolution_or_width_height = {
k: v
for k, v in [
("resolution", dst_resolution),
("dst_height", dst_height),
("dst_width", dst_width),
]
if v is not None
}

for key, value in (
("resolution", dst_resolution),
("dst_height", dst_height),
("dst_width", dst_width),
):
if value is not None:
kwargs[key] = value
dst_affine, dst_width, dst_height = rasterio.warp.calculate_default_transform(
src_crs,
dst_crs,
src_width,
src_height,
*src_bounds,
**resolution_or_width_height,
**kwargs,
)
return dst_affine, dst_width, dst_height

Expand Down Expand Up @@ -319,6 +337,8 @@ def reproject(
shape=None,
transform=None,
resampling=Resampling.nearest,
nodata=None,
**kwargs,
):
"""
Reproject :obj:`xarray.DataArray` objects
Expand All @@ -332,6 +352,7 @@ def reproject(

.. versionadded:: 0.0.27 shape
.. versionadded:: 0.0.28 transform
.. versionadded:: 0.5.0 nodata, kwargs

Parameters
----------
Expand All @@ -343,10 +364,21 @@ def reproject(
shape: tuple(int, int), optional
Shape of the destination in pixels (dst_height, dst_width). Cannot be used
together with resolution.
transform: optional
transform: Affine, optional
The destination transform.
resampling: rasterio.enums.Resampling, optional
See :func:`rasterio.warp.reproject` for more details.
nodata: float, optional
The nodata value used to initialize the destination;
it will remain in all areas not covered by the reprojected source.
Defaults to the nodata value of the source image if none provided
and exists or attempts to find an appropriate value by dtype.
**kwargs: dict
Additional keyword arguments to pass into :func:`rasterio.warp.reproject`.
To override:
- src_transform: `rio.write_transform`
- src_crs: `rio.write_crs`
- src_nodata: `rio.write_nodata`


Returns
Expand All @@ -361,10 +393,10 @@ def reproject(
"CRS not found. Please set the CRS with 'rio.write_crs()'."
f"{_get_data_var_message(self._obj)}"
)
src_affine = self.transform(recalc=True)
src_affine = None if "gcps" in kwargs else self.transform(recalc=True)
if transform is None:
dst_affine, dst_width, dst_height = _make_dst_affine(
self._obj, self.crs, dst_crs, resolution, shape
self._obj, self.crs, dst_crs, resolution, shape, **kwargs
)
else:
dst_affine = transform
Expand All @@ -382,22 +414,24 @@ def reproject(
else:
dst_data = np.zeros((dst_height, dst_width), dtype=self._obj.dtype.type)

dst_nodata = self._obj.dtype.type(
self.nodata if self.nodata is not None else -9999
)
src_nodata = self._obj.dtype.type(
self.nodata if self.nodata is not None else dst_nodata
default_nodata = (
_NODATA_DTYPE_MAP[dtype_rev[self._obj.dtype.name]]
if self.nodata is None
else self.nodata
)
dst_nodata = default_nodata if nodata is None else nodata

rasterio.warp.reproject(
source=self._obj.values,
destination=dst_data,
src_transform=src_affine,
src_crs=self.crs,
src_nodata=src_nodata,
src_nodata=self.nodata,
dst_transform=dst_affine,
dst_crs=dst_crs,
dst_nodata=dst_nodata,
resampling=resampling,
**kwargs,
)
# add necessary attributes
new_attrs = _generate_attrs(self._obj, dst_nodata)
Expand Down
17 changes: 16 additions & 1 deletion rioxarray/raster_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def reproject(
shape=None,
transform=None,
resampling=Resampling.nearest,
nodata=None,
**kwargs,
):
"""
Reproject :class:`xarray.Dataset` objects
Expand All @@ -69,6 +71,7 @@ def reproject(

.. versionadded:: 0.0.27 shape
.. versionadded:: 0.0.28 transform
.. versionadded:: 0.5.0 nodata, kwargs

Parameters
----------
Expand All @@ -84,7 +87,17 @@ def reproject(
The destination transform.
resampling: rasterio.enums.Resampling, optional
See :func:`rasterio.warp.reproject` for more details.

nodata: float, optional
The nodata value used to initialize the destination;
it will remain in all areas not covered by the reprojected source.
Defaults to the nodata value of the source image if none provided
and exists or attempts to find an appropriate value by dtype.
**kwargs: dict
Additional keyword arguments to pass into :func:`rasterio.warp.reproject`.
To override:
- src_transform: `rio.write_transform`
- src_crs: `rio.write_crs`
- src_nodata: `rio.write_nodata`

Returns
--------
Expand All @@ -102,6 +115,8 @@ def reproject(
shape=shape,
transform=transform,
resampling=resampling,
nodata=nodata,
**kwargs,
)
)
return resampled_dataset
Expand Down
55 changes: 50 additions & 5 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dask.delayed import Delayed
from numpy.testing import assert_almost_equal, assert_array_equal
from pyproj import CRS as pCRS
from rasterio.control import GroundControlPoint
from rasterio.crs import CRS
from rasterio.windows import Window

Expand Down Expand Up @@ -699,7 +700,8 @@ def test_reproject__no_transform(modis_reproject):
_assert_xarrays_equal(mds_repr, mdc)


def test_reproject__no_nodata(modis_reproject):
@pytest.mark.parametrize("nodata", [None, -9999])
def test_reproject__no_nodata(nodata, modis_reproject):
mask_args = (
dict(masked=False, mask_and_scale=False)
if "rasterio" in str(modis_reproject["open"])
Expand All @@ -712,19 +714,20 @@ def test_reproject__no_nodata(modis_reproject):
_del_attr(mda, "_FillValue")
_del_attr(mda, "nodata")
# reproject
mds_repr = mda.rio.reproject(modis_reproject["to_proj"])
mds_repr = mda.rio.reproject(modis_reproject["to_proj"], nodata=nodata)

# overwrite test dataset
# if isinstance(modis_reproject['open'], xarray.DataArray):
# mds_repr.to_netcdf(modis_reproject['compare'])

# replace -9999 with original _FillValue for testing
fill_nodata = -32768 if nodata is None else nodata
if hasattr(mds_repr, "variables"):
for var in mds_repr.rio.vars:
mds_repr[var].values[mds_repr[var].values == -9999] = orig_fill
mds_repr[var].values[mds_repr[var].values == fill_nodata] = orig_fill
else:
mds_repr.values[mds_repr.values == -9999] = orig_fill
_mod_attr(mdc, "_FillValue", val=-9999)
mds_repr.values[mds_repr.values == fill_nodata] = orig_fill
_mod_attr(mdc, "_FillValue", val=fill_nodata)
# test
_assert_xarrays_equal(mds_repr, mdc)

Expand All @@ -750,6 +753,47 @@ def test_reproject__no_nodata_masked(modis_reproject):
_assert_xarrays_equal(mds_repr, mdc)


def test_reproject__gcps_kwargs(tmp_path):
tiffname = tmp_path / "test.tif"
src_gcps = [
GroundControlPoint(row=0, col=0, x=156113, y=2818720, z=0),
GroundControlPoint(row=0, col=800, x=338353, y=2785790, z=0),
GroundControlPoint(row=800, col=800, x=297939, y=2618518, z=0),
GroundControlPoint(row=800, col=0, x=115698, y=2651448, z=0),
]
crs = CRS.from_epsg(32618)
with rasterio.open(
tiffname,
mode="w",
height=800,
width=800,
count=3,
dtype=numpy.uint8,
driver="GTiff",
) as source:
source.gcps = (src_gcps, crs)

rds = rioxarray.open_rasterio(tiffname)
rds.rio.write_crs(crs, inplace=True)
rds = rds.rio.reproject(
crs,
gcps=src_gcps,
)
assert rds.rio.height == 923
assert rds.rio.width == 1027
assert rds.rio.crs == crs
assert rds.rio.transform().almost_equals(
Affine(
216.8587081056465,
0.0,
115698.25,
0.0,
-216.8587081056465,
2818720.0,
)
)


def test_reproject_match(modis_reproject_match):
mask_args = (
dict(masked=False, mask_and_scale=False)
Expand Down Expand Up @@ -1069,6 +1113,7 @@ def test_geographic_reproject__missing_nodata():
mds_repr = mda.rio.reproject("epsg:32721")
# mds_repr.to_netcdf(sentinel_2_utm)
# test
_mod_attr(mdc, "_FillValue", val=65535)
_assert_xarrays_equal(mds_repr, mdc, precision=4)


Expand Down