Skip to content

Commit

Permalink
Add default conversion to NaN and nodata definition in to_xarray(), add convenience arguments to raster_equal()
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Mar 18, 2024
1 parent 9f7e346 commit 7385f94
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
48 changes: 37 additions & 11 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,9 +807,9 @@ def to_rio_dataset(self) -> rio.io.DatasetReader:
driver="GTiff",
) as ds:
if self.count == 1:
ds.write(self.data[np.newaxis, :, :])
ds.write(self.data.filled(self.nodata)[np.newaxis, :, :])
else:
ds.write(self.data)
ds.write(self.data.filled(self.nodata))

# Then open as a DatasetReader
return mfh.open()
Expand Down Expand Up @@ -969,13 +969,17 @@ def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Numb
self._data[:, ind] = assign # type: ignore
return None

def raster_equal(self, other: RasterType) -> bool:
def raster_equal(self, other: RasterType, strict_masked: bool = True, warn_failure_reason: bool = False) -> bool:
"""
Check if two rasters are equal.
This means that are equal:
- The raster's masked array's data (including masked values), mask, fill_value and dtype,
- The raster's transform, crs and nodata values.
:param other: Other raster.
:param strict_masked: Whether to check if masked pixels (in .data.mask) have the same value (in .data.data).
:param warn_failure_reason: Whether to warn for the reason of failure if the check does not pass.
"""

# If the mask is just "False", it is equivalent to being equal to an array of False
Expand All @@ -991,17 +995,36 @@ def raster_equal(self, other: RasterType) -> bool:

if not isinstance(other, Raster): # TODO: Possibly add equals to SatelliteImage?
raise NotImplementedError("Equality with other object than Raster not supported by raster_equal.")
return all(
[
np.array_equal(self.data.data, other.data.data, equal_nan=True),
np.array_equal(self_mask, other_mask),

if strict_masked:
names = ["data.data", "data.mask", "data.fill_value", "dtype", "transform", "crs", "nodata"]
equalities = [
np.array_equal(self.data.data, other.data.data, equal_nan=True),
np.array_equal(self_mask, other_mask),
self.data.fill_value == other.data.fill_value,
self.data.dtype == other.data.dtype,
self.transform == other.transform,
self.crs == other.crs,
self.nodata == other.nodata,
]
else:
names = ["data", "data.fill_value", "dtype", "transform", "crs", "nodata"]
equalities = [
np.ma.allequal(self.data, other.data),
self.data.fill_value == other.data.fill_value,
self.data.dtype == other.data.dtype,
self.transform == other.transform,
self.crs == other.crs,
self.nodata == other.nodata,
]
)
self.nodata == other.nodata
]

complete_equality = all(equalities)

if not complete_equality and warn_failure_reason:
where_fail = np.nonzero(~np.array(equalities))[0]
warnings.warn(category=UserWarning, message=f"Equality failed for: {', '.join([names[w] for w in where_fail])}.")

return complete_equality

def _overloading_check(
self: RasterType, other: RasterType | NDArrayNum | Number
Expand Down Expand Up @@ -2681,7 +2704,10 @@ def to_xarray(self, name: str | None = None) -> xr.DataArray:
else:
updated_raster = self

ds = rioxarray.open_rasterio(updated_raster.to_rio_dataset())
ds = rioxarray.open_rasterio(updated_raster.to_rio_dataset(), masked=True)
# When reading as masked, the nodata is not written to the dataset so we do it manually
ds.rio.set_nodata(self.nodata)

if name is not None:
ds.name = name

Expand Down
14 changes: 7 additions & 7 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,32 +391,32 @@ def test_to_xarray(self, example: str):

# Check that the arrays are equal in NaN type
if rst.count > 1:
assert np.array_equal(rst.data.data, ds.data)
assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True)
else:
assert np.array_equal(rst.data.data, ds.data.squeeze())
assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True)

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore
def test_from_xarray(self, example: str):
"""Test raster creation from a xarray dataset, not fully reversible with to_xarray due to float conversion"""

warnings.filterwarnings("ignore")

# Open raster and export to xarray, then import to xarray dataset
rst = gu.Raster(example)
ds = rst.to_xarray()
rst2 = gu.Raster.from_xarray(ds=ds)

# Exporting to a Xarray dataset results in loss of information to float32
# Check that the output equals the input converted to float32 (not fully reversible)
assert rst.astype("float32", convert_nodata=False).raster_equal(rst2)
assert rst.astype("float32", convert_nodata=False).raster_equal(rst2, strict_masked=False)

# Test with the dtype argument to convert back to original raster even if integer-type
if np.issubdtype(rst.dtypes[0], np.integer):
# Set an existing nodata value, because all of our integer-type example datasets currently have "None"
rst.set_nodata(new_nodata=255)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="New nodata value cells already exist.*")
rst.set_nodata(new_nodata=255)
ds = rst.to_xarray()
rst3 = gu.Raster.from_xarray(ds=ds, dtype=rst.dtypes[0])
assert rst3.raster_equal(rst)
assert rst3.raster_equal(rst, strict_masked=False)

@pytest.mark.parametrize("nodata_init", [None, "type_default"]) # type: ignore
@pytest.mark.parametrize(
Expand Down

0 comments on commit 7385f94

Please sign in to comment.