Skip to content

Commit

Permalink
Remove combine_times kwarg from multiscene.stack and default to its default behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
lahtinep committed Feb 1, 2024
1 parent 942be9c commit db2bf31
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 30 deletions.
28 changes: 8 additions & 20 deletions satpy/multiscene/_blend_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
def stack(
data_arrays: Sequence[xr.DataArray],
weights: Optional[Sequence[xr.DataArray]] = None,
combine_times: bool = True,
blend_type: str = "select_with_weights"
) -> xr.DataArray:
"""Combine a series of datasets in different ways.
Expand All @@ -39,19 +38,18 @@ def stack(
"""
if weights:
return _stack_with_weights(data_arrays, weights, combine_times, blend_type)
return _stack_no_weights(data_arrays, combine_times)
return _stack_with_weights(data_arrays, weights, blend_type)
return _stack_no_weights(data_arrays)


def _stack_with_weights(
datasets: Sequence[xr.DataArray],
weights: Sequence[xr.DataArray],
combine_times: bool,
blend_type: str
) -> xr.DataArray:
blend_func = _get_weighted_blending_func(blend_type)
filled_weights = list(_fill_weights_for_invalid_dataset_pixels(datasets, weights))
return blend_func(datasets, filled_weights, combine_times)
return blend_func(datasets, filled_weights)


def _get_weighted_blending_func(blend_type: str) -> Callable:
Expand Down Expand Up @@ -84,10 +82,9 @@ def _fill_weights_for_invalid_dataset_pixels(
def _stack_blend_by_weights(
datasets: Sequence[xr.DataArray],
weights: Sequence[xr.DataArray],
combine_times: bool
) -> xr.DataArray:
"""Stack datasets blending overlap using weights."""
attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets], combine_times)
attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets])

overlays = []
for weight, overlay in zip(weights, datasets):
Expand All @@ -109,14 +106,13 @@ def _stack_blend_by_weights(
def _stack_select_by_weights(
datasets: Sequence[xr.DataArray],
weights: Sequence[xr.DataArray],
combine_times: bool
) -> xr.DataArray:
"""Stack datasets selecting pixels using weights."""
indices = da.argmax(da.dstack(weights), axis=-1)
if "bands" in datasets[0].dims:
indices = [indices] * datasets[0].sizes["bands"]

attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets], combine_times)
attrs = _combine_stacked_attrs([data_arr.attrs for data_arr in datasets])
dims = datasets[0].dims
coords = datasets[0].coords
selected_array = xr.DataArray(da.choose(indices, datasets), dims=dims, coords=coords, attrs=attrs)
Expand All @@ -125,7 +121,6 @@ def _stack_select_by_weights(

def _stack_no_weights(
datasets: Sequence[xr.DataArray],
combine_times: bool
) -> xr.DataArray:
base = datasets[0].copy()
collected_attrs = [base.attrs]
Expand All @@ -136,20 +131,13 @@ def _stack_no_weights(
except KeyError:
base = base.where(data_arr.isnull(), data_arr)

attrs = _combine_stacked_attrs(collected_attrs, combine_times)
attrs = _combine_stacked_attrs(collected_attrs)
base.attrs = attrs
return base


def _combine_stacked_attrs(collected_attrs: Sequence[Mapping], combine_times: bool) -> dict:
attrs = combine_metadata(*collected_attrs)
if combine_times and ("start_time" in attrs or "end_time" in attrs):
new_start, new_end = _get_combined_start_end_times(collected_attrs)
if new_start:
attrs["start_time"] = new_start
if new_end:
attrs["end_time"] = new_end
return attrs
def _combine_stacked_attrs(collected_attrs: Sequence[Mapping]) -> dict:
    """Merge the attrs dicts gathered from the stacked data arrays into one dict.

    Delegates entirely to ``combine_metadata``, which (per this commit) now
    also handles start_time/end_time combination that was previously done here.
    """
    return combine_metadata(*collected_attrs)


def _get_combined_start_end_times(metadata_objects: Iterable[Mapping]) -> tuple[datetime | None, datetime | None]:
Expand Down
15 changes: 5 additions & 10 deletions satpy/tests/multiscene_tests/test_blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,9 @@ def test_blend_two_scenes_bad_blend_type(self, multi_scene_and_weights, groups):
("select_with_weights", _get_expected_stack_select),
("blend_with_weights", _get_expected_stack_blend),
])
@pytest.mark.parametrize("combine_times", [False, True])
def test_blend_two_scenes_using_stack_weighted(self, multi_scene_and_weights, groups,
scene1_with_weights, scene2_with_weights,
combine_times, blend_func, exp_result_func):
blend_func, exp_result_func):
"""Test stacking two scenes using weights.
Here we test that the start and end times can be combined so that they
Expand All @@ -266,7 +265,7 @@ def test_blend_two_scenes_using_stack_weighted(self, multi_scene_and_weights, gr
multi_scene.group(simple_groups)

weights = [weights[0][0], weights[1][0]]
stack_func = partial(stack, weights=weights, blend_type=blend_func, combine_times=combine_times)
stack_func = partial(stack, weights=weights, blend_type=blend_func)
weighted_blend = multi_scene.blend(blend_function=stack_func)

expected = exp_result_func(scene1, scene2)
Expand All @@ -275,12 +274,8 @@ def test_blend_two_scenes_using_stack_weighted(self, multi_scene_and_weights, gr
np.testing.assert_allclose(result.data, expected.data)

_check_stacked_metadata(result, "CloudType")
if combine_times:
assert result.attrs["start_time"] == datetime(2023, 1, 16, 11, 9, 17)
assert result.attrs["end_time"] == datetime(2023, 1, 16, 11, 28, 1, 900000)
else:
assert result.attrs["start_time"] == datetime(2023, 1, 16, 11, 11, 7, 250000)
assert result.attrs["end_time"] == datetime(2023, 1, 16, 11, 20, 11, 950000)
assert result.attrs["start_time"] == datetime(2023, 1, 16, 11, 9, 17)
assert result.attrs["end_time"] == datetime(2023, 1, 16, 11, 28, 1, 900000)

@pytest.fixture()
def datasets_and_weights(self):
Expand Down Expand Up @@ -329,7 +324,7 @@ def test_blend_function_stack_weighted(self, datasets_and_weights, line, column)
input_data["weights"][1][line, :] = 2
input_data["weights"][2][:, column] = 2

stack_with_weights = partial(stack, weights=input_data["weights"], combine_times=False)
stack_with_weights = partial(stack, weights=input_data["weights"])
blend_result = stack_with_weights(input_data["datasets"][0:3])

ds1 = input_data["datasets"][0]
Expand Down

0 comments on commit db2bf31

Please sign in to comment.