diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index c5506fe2..f217e415 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -807,9 +807,9 @@ def to_rio_dataset(self) -> rio.io.DatasetReader: driver="GTiff", ) as ds: if self.count == 1: - ds.write(self.data[np.newaxis, :, :]) + ds.write(self.data.filled(self.nodata)[np.newaxis, :, :]) else: - ds.write(self.data) + ds.write(self.data.filled(self.nodata)) # Then open as a DatasetReader return mfh.open() @@ -969,13 +969,17 @@ def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Numb self._data[:, ind] = assign # type: ignore return None - def raster_equal(self, other: RasterType) -> bool: + def raster_equal(self, other: RasterType, strict_masked: bool = True, warn_failure_reason: bool = False) -> bool: """ Check if two rasters are equal. This means that are equal: - The raster's masked array's data (including masked values), mask, fill_value and dtype, - The raster's transform, crs and nodata values. + + :param other: Other raster. + :param strict_masked: Whether to check if masked pixels (in .data.mask) have the same value (in .data.data). + :param warn_failure_reason: Whether to warn for the reason of failure if the check does not pass. """ # If the mask is just "False", it is equivalent to being equal to an array of False @@ -991,17 +995,36 @@ def raster_equal(self, other: RasterType) -> bool: if not isinstance(other, Raster): # TODO: Possibly add equals to SatelliteImage? raise NotImplementedError("Equality with other object than Raster not supported by raster_equal.") - return all( - [ - np.array_equal(self.data.data, other.data.data, equal_nan=True), - np.array_equal(self_mask, other_mask), + + if strict_masked: + names = ["data.data", "data.mask", "data.fill_value", "dtype", "transform", "crs", "nodata"] + equalities = [ + np.array_equal(self.data.data, other.data.data, equal_nan=True), + np.array_equal(self_mask, other_mask), + self.data.fill_value == other.data.fill_value, + self.data.dtype == other.data.dtype, + self.transform == other.transform, + self.crs == other.crs, + self.nodata == other.nodata, + ] + else: + names = ["data", "data.fill_value", "dtype", "transform", "crs", "nodata"] + equalities = [ + np.ma.allequal(self.data, other.data), self.data.fill_value == other.data.fill_value, self.data.dtype == other.data.dtype, self.transform == other.transform, self.crs == other.crs, - self.nodata == other.nodata, - ] - ) + self.nodata == other.nodata + ] + + complete_equality = all(equalities) + + if not complete_equality and warn_failure_reason: + where_fail = np.nonzero(~np.array(equalities))[0] + warnings.warn(category=UserWarning, message=f"Equality failed for: {', '.join([names[w] for w in where_fail])}.") + + return complete_equality def _overloading_check( self: RasterType, other: RasterType | NDArrayNum | Number @@ -2681,7 +2704,10 @@ def to_xarray(self, name: str | None = None) -> xr.DataArray: else: updated_raster = self - ds = rioxarray.open_rasterio(updated_raster.to_rio_dataset()) + ds = rioxarray.open_rasterio(updated_raster.to_rio_dataset(), masked=True) + # When reading as masked, the nodata is not written to the dataset so we do it manually + ds.rio.set_nodata(self.nodata) + if name is not None: ds.name = name diff --git a/tests/test_raster.py b/tests/test_raster.py index ab6ad9aa..6902a9df 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -391,16 +391,14 @@ def test_to_xarray(self, example: str): # Check that the arrays are equal in NaN type if rst.count > 1: - assert np.array_equal(rst.data.data, ds.data) + assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True) else: - assert np.array_equal(rst.data.data, ds.data.squeeze()) + assert np.array_equal(rst.get_nanarray(), ds.data.squeeze(), equal_nan=True) @pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path, landsat_rgb_path]) # type: ignore def test_from_xarray(self, example: str): """Test raster creation from a xarray dataset, not fully reversible with to_xarray due to float conversion""" - warnings.filterwarnings("ignore") - # Open raster and export to xarray, then import to xarray dataset rst = gu.Raster(example) ds = rst.to_xarray() @@ -408,15 +406,17 @@ def test_from_xarray(self, example: str): # Exporting to a Xarray dataset results in loss of information to float32 # Check that the output equals the input converted to float32 (not fully reversible) - assert rst.astype("float32", convert_nodata=False).raster_equal(rst2) + assert rst.astype("float32", convert_nodata=False).raster_equal(rst2, strict_masked=False) # Test with the dtype argument to convert back to original raster even if integer-type if np.issubdtype(rst.dtypes[0], np.integer): # Set an existing nodata value, because all of our integer-type example datasets currently have "None" - rst.set_nodata(new_nodata=255) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="New nodata value cells already exist.*") + rst.set_nodata(new_nodata=255) ds = rst.to_xarray() rst3 = gu.Raster.from_xarray(ds=ds, dtype=rst.dtypes[0]) - assert rst3.raster_equal(rst) + assert rst3.raster_equal(rst, strict_masked=False) @pytest.mark.parametrize("nodata_init", [None, "type_default"]) # type: ignore @pytest.mark.parametrize(