Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Issue237 nodata mm #305

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 16 additions & 80 deletions src/geowombat/backends/xarray_.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,97 +455,33 @@ def mosaic(
for fidx, fn in enumerate(warped_objects[1:]):
with open_rasterio(
fn, nodata=ref_kwargs['nodata'], **kwargs
) as darrayb:
) as darray_b:
with open_rasterio(
filenames[fidx + 1], nodata=ref_kwargs['nodata'], **kwargs
) as src_:
geometries.append(src_.gw.geometry)
src_ = None

if overlap == 'min':
if isinstance(ref_kwargs['nodata'], float) or isinstance(
ref_kwargs['nodata'], int
):
darray = xr.where(
(darray.mean(dim='band') == ref_kwargs['nodata'])
& (
darrayb.mean(dim='band')
!= ref_kwargs['nodata']
),
darrayb,
xr.where(
(
darray.mean(dim='band')
!= ref_kwargs['nodata']
)
& (
darrayb.mean(dim='band')
== ref_kwargs['nodata']
),
darray,
np.minimum(darray, darrayb),
),
)
# Stack the bands
nodataval = darray.gw.nodataval
stack = xr.concat((darray, darray_b), dim='band')
# Ensure 'no data' values are nans and ignored
stack = stack.gw.mask_nodata()

else:
darray = np.minimum(darray, darrayb)
if overlap == 'min':
darray = stack.min(dim='band', skipna=True, keepdims=True)

elif overlap == 'max':
if isinstance(ref_kwargs['nodata'], float) or isinstance(
ref_kwargs['nodata'], int
):
darray = xr.where(
(darray.mean(dim='band') == ref_kwargs['nodata'])
& (
darrayb.mean(dim='band')
!= ref_kwargs['nodata']
),
darrayb,
xr.where(
(
darray.mean(dim='band')
!= ref_kwargs['nodata']
)
& (
darrayb.mean(dim='band')
== ref_kwargs['nodata']
),
darray,
np.maximum(darray, darrayb),
),
)

else:
darray = np.maximum(darray, darrayb)
darray = stack.max(dim='band', skipna=True, keepdims=True)

elif overlap == 'mean':
if isinstance(ref_kwargs['nodata'], float) or isinstance(
ref_kwargs['nodata'], int
):

darray = xr.where(
(darray.mean(dim='band') == ref_kwargs['nodata'])
& (
darrayb.mean(dim='band')
!= ref_kwargs['nodata']
),
darrayb,
xr.where(
(
darray.mean(dim='band')
!= ref_kwargs['nodata']
)
& (
darrayb.mean(dim='band')
== ref_kwargs['nodata']
),
darray,
(darray + darrayb) / 2.0,
),
)
darray = stack.mean(dim='band', skipna=True, keepdims=True)

else:
darray = (darray + darrayb) / 2.0
# Reset the 'no data' values
darray = darray.gw.set_nodata(
src_nodata=np.nan,
dst_nodata=nodataval,
)

darray = darray.assign_attrs(**attrs)

Expand Down Expand Up @@ -600,7 +536,7 @@ def mosaic(
attrs.update(tags)
darray = darray.assign_attrs(**attrs)

if dtype:
if dtype is not None:
attrs = darray.attrs.copy()

return darray.astype(dtype).assign_attrs(**attrs)
Expand Down
57 changes: 55 additions & 2 deletions tests/test_open.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import unittest
from pathlib import Path

import tempfile
import dask
import numpy as np
import rasterio as rio
import xarray as xr
from pyproj import CRS

import os
import geowombat as gw
from geowombat.data import (
l3b_s2b_00390821jxn0l2a_20210319_20220730_c01,
l8_224077_20200518_B2,
l8_224078_20200518_B2,
l8_224077_20200518_B2_60m,
l8_224078_20200518,
)
Expand Down Expand Up @@ -113,6 +114,58 @@ def test_open_multiple_same_mean(self):
) as src:
self.assertEqual(src.gw.ntime, 1)

def test_union_values(self):
    """Check a 10-value window of a union mosaic against known values.

    Opens the two single-band scenes as one mosaic with
    ``bounds_by='union'`` and compares a short run of values taken at
    index 0 along the first axis.
    """
    filenames = [l8_224077_20200518_B2, l8_224078_20200518_B2]
    expected = [8678, 8958, 8970, 8966, 8912, 8749, 8131, 7598, 7590, 7606]
    with gw.open(
        filenames,
        band_names=['blue'],
        mosaic=True,
        bounds_by='union',
    ) as src:
        # NOTE(review): both the row index and the column window are
        # derived from ``shape[1]``; the column window would normally use
        # ``shape[2]``. Kept as-is because the expected values were
        # presumably recorded with this indexing -- confirm.
        mid = src.shape[1] // 2
        vals = src.values[0, mid, mid:mid + 10]
    # ``assert_array_equal`` also checks the shape and reports the
    # mismatching elements, unlike ``assertTrue(all(vals == [...]))``.
    np.testing.assert_array_equal(vals, expected)

def test_mosaic_save(self):
    """Save a union mosaic opened with ``nodata=0`` to a temporary GeoTIFF.

    Any exception raised while opening or saving propagates to the test
    runner with its full traceback, so the test still fails but the root
    cause stays visible.
    """
    # The temporary directory (and the written file) is removed on exit.
    with tempfile.TemporaryDirectory() as temp_dir:
        test_file_path = os.path.join(temp_dir, 'test.tif')
        filenames = [l8_224077_20200518_B2, l8_224078_20200518_B2]
        with gw.open(
            filenames,
            band_names=['blue'],
            mosaic=True,
            bounds_by='union',
            nodata=0,
        ) as src:
            src.gw.save(test_file_path, overwrite=True)
        # Verify that saving actually produced the output file.
        self.assertTrue(os.path.isfile(test_file_path))


def test_bounds_union(self):
    """The mosaic opened with ``bounds_by='union'`` reports the expected bounds."""
    expected_bounds = (693990.0, -2832810.0, 778590.0, -2766600.0)
    scene_paths = [l8_224077_20200518_B2, l8_224078_20200518_B2]
    open_options = {
        'band_names': ['blue'],
        'mosaic': True,
        'bounds_by': 'union',
    }
    with gw.open(scene_paths, **open_options) as mosaic:
        mosaic_bounds = mosaic.gw.bounds
    self.assertEqual(mosaic_bounds, expected_bounds)

def test_bounds_intersection(self):
    """The mosaic opened with ``bounds_by='intersection'`` reports the expected bounds."""
    expected_bounds = (717330.0, -2812080.0, 754200.0, -2776980.0)
    scene_paths = [l8_224077_20200518_B2, l8_224078_20200518_B2]
    open_options = {
        'band_names': ['blue'],
        'mosaic': True,
        'bounds_by': 'intersection',
    }
    with gw.open(scene_paths, **open_options) as mosaic:
        mosaic_bounds = mosaic.gw.bounds
    self.assertEqual(mosaic_bounds, expected_bounds)

def test_has_time_dim(self):
with gw.open(
[l8_224078_20200518, l8_224078_20200518], stack_dim='time'
Expand Down
Loading