Skip to content

Commit

Permalink
Merge pull request pytroll#2696 from yukaribbba/add_features_to_backg…
Browse files Browse the repository at this point in the history
…round_compositor

Add double alpha channel support and improve metadata behaviours for BackgroundCompositor
  • Loading branch information
mraspaud authored Apr 22, 2024
2 parents b73b5a7 + 05fdcbf commit 211cda2
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 43 deletions.
133 changes: 109 additions & 24 deletions satpy/composites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,25 +1670,76 @@ def __call__(self, *args, **kwargs):


class BackgroundCompositor(GenericCompositor):
"""A compositor that overlays one composite on top of another."""
"""A compositor that overlays one composite on top of another.
The output image mode will be determined by both foreground and background. Generally, when the background has
an alpha band, the output image will also have one.
============ ============ ========
Foreground Background Result
============ ============ ========
L L L
------------ ------------ --------
L LA LA
------------ ------------ --------
L RGB RGB
------------ ------------ --------
L RGBA RGBA
------------ ------------ --------
LA L L
------------ ------------ --------
LA LA LA
------------ ------------ --------
LA RGB RGB
------------ ------------ --------
LA RGBA RGBA
------------ ------------ --------
RGB L RGB
------------ ------------ --------
RGB LA RGBA
------------ ------------ --------
RGB RGB RGB
------------ ------------ --------
RGB RGBA RGBA
------------ ------------ --------
RGBA L RGB
------------ ------------ --------
RGBA LA RGBA
------------ ------------ --------
RGBA RGB RGB
------------ ------------ --------
RGBA RGBA RGBA
============ ============ ========
"""

def __call__(self, projectables, *args, **kwargs):
"""Call the compositor."""
projectables = self.match_data_arrays(projectables)

# Get enhanced datasets
foreground = enhance2dataset(projectables[0], convert_p=True)
background = enhance2dataset(projectables[1], convert_p=True)
# Adjust bands so that they match
# L/RGB -> RGB/RGB
# LA/RGB -> RGBA/RGBA
# RGB/RGBA -> RGBA/RGBA
before_bg_mode = background.attrs["mode"]

# Adjust bands so that they have the same mode
foreground = add_bands(foreground, background["bands"])
background = add_bands(background, foreground["bands"])

# It's important whether the alpha band of background is initially generated, e.g. by CloudCompositor
# The result will be used to determine the output image mode
initial_bg_alpha = "A" in before_bg_mode

attrs = self._combine_metadata_with_mode_and_sensor(foreground, background)
data = self._get_merged_image_data(foreground, background)
if "A" not in foreground.attrs["mode"] and "A" not in background.attrs["mode"]:
data = self._simple_overlay(foreground, background)
else:
data = self._get_merged_image_data(foreground, background, initial_bg_alpha=initial_bg_alpha)
for data_arr in data:
data_arr.attrs = attrs
res = super(BackgroundCompositor, self).__call__(data, **kwargs)
res.attrs.update(attrs)
attrs.update(res.attrs)
res.attrs = attrs
return res

def _combine_metadata_with_mode_and_sensor(self,
Expand All @@ -1707,27 +1758,61 @@ def _combine_metadata_with_mode_and_sensor(self,

@staticmethod
def _get_merged_image_data(foreground: xr.DataArray,
background: xr.DataArray
background: xr.DataArray,
initial_bg_alpha: bool,
) -> list[xr.DataArray]:
if "A" in foreground.attrs["mode"]:
# Use alpha channel as weight and blend the two composites
alpha = foreground.sel(bands="A")
data = []
# NOTE: there's no alpha band in the output image, it will
# be added by the data writer
for band in foreground.mode[:-1]:
fg_band = foreground.sel(bands=band)
bg_band = background.sel(bands=band)
chan = (fg_band * alpha + bg_band * (1 - alpha))
chan = xr.where(chan.isnull(), bg_band, chan)
data.append(chan)
else:
data_arr = xr.where(foreground.isnull(), background, foreground)
# Split to separate bands so the mode is correct
data = [data_arr.sel(bands=b) for b in data_arr["bands"]]
# For more info about alpha compositing please review https://en.wikipedia.org/wiki/Alpha_compositing
alpha_fore = _get_alpha(foreground)
alpha_back = _get_alpha(background)
new_alpha = alpha_fore + alpha_back * (1 - alpha_fore)

data = []

# Pass the image data (alpha band will be dropped temporally) to the writer
output_mode = background.attrs["mode"].replace("A", "")

for band in output_mode:
fg_band = foreground.sel(bands=band)
bg_band = background.sel(bands=band)
# Do the alpha compositing
chan = (fg_band * alpha_fore + bg_band * alpha_back * (1 - alpha_fore)) / new_alpha
# Fill the NaN area with background
chan = xr.where(chan.isnull(), bg_band * alpha_back, chan)
chan["bands"] = band
data.append(chan)

# If background has an initial alpha band, it will also be passed to the writer
if initial_bg_alpha:
new_alpha["bands"] = "A"
data.append(new_alpha)

return data

@staticmethod
def _simple_overlay(foreground: xr.DataArray,
background: xr.DataArray,) -> list[xr.DataArray]:
# This is for the case when no alpha bands are involved
# Just simply lay the foreground upon background
data_arr = xr.where(foreground.isnull(), background, foreground)
# Split to separate bands so the mode is correct
data = [data_arr.sel(bands=b) for b in data_arr["bands"]]

return data


def _get_alpha(dataset: xr.DataArray):
# 1. This function is only used by _get_merged_image_data
# 2. Both foreground and background have been through add_bands, so they have the same mode
# 3. If none of them has alpha band, they will be passed to _simple_overlay not _get_merged_image_data
# So any dataset(whether foreground or background) passed to this function has an alpha band for certain
# We will use it directly
alpha = dataset.sel(bands="A")
# There could be NaNs in the alpha
# Replace them with 0 to prevent cases like 1 + nan = nan, so they won't affect new_alpha
alpha = xr.where(alpha.isnull(), 0, alpha)

return alpha


class MaskingCompositor(GenericCompositor):
"""A compositor that masks e.g. IR 10.8 channel data using cloud products from NWC SAF."""
Expand Down
16 changes: 8 additions & 8 deletions satpy/tests/reader_tests/test_olci_nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,10 @@ def test_bitflags_with_dataarray_without_flags(self):
"CLOUD_MARGIN", "CLOUD_AMBIGUOUS", "LOWRW", "LAND"]

mask = reduce(np.logical_or, [bflags[item] for item in items])
expected = np.array([True, False, True, True, True, True, False,
False, True, True, False, False, False, False,
False, False, False, True, False, True, False,
False, False, True, True, False, False, True,
expected = np.array([True, False, True, True, True, True, False,
False, True, True, False, False, False, False,
False, False, False, True, False, True, False,
False, False, True, True, False, False, True,
False])
assert all(mask == expected)

Expand Down Expand Up @@ -367,9 +367,9 @@ def test_bitflags_with_custom_flag_list(self):
"CLOUD_MARGIN", "CLOUD_AMBIGUOUS", "LOWRW", "LAND"]

mask = reduce(np.logical_or, [bflags[item] for item in items])
expected = np.array([True, False, True, True, True, True, False,
False, True, True, False, False, False, False,
False, False, False, True, False, True, False,
False, False, True, True, False, False, True,
expected = np.array([True, False, True, True, True, True, False,
False, True, True, False, False, False, False,
False, False, False, True, False, True, False,
False, False, True, True, False, False, True,
False])
assert all(mask == expected)
46 changes: 35 additions & 11 deletions satpy/tests/test_composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,31 +1484,52 @@ def setup_class(cls):
[[1., 0.5], [0., np.nan]],
[[1., 0.5], [0., np.nan]]]),
"RGBA": np.array([
[[1.0, 0.5], [0.0, np.nan]],
[[1.0, 0.5], [0.0, np.nan]],
[[1.0, 0.5], [0.0, np.nan]],
[[0.5, 0.5], [0.5, 0.5]]]),
[[1., 0.5], [0., np.nan]],
[[1., 0.5], [0., np.nan]],
[[1., 0.5], [0., np.nan]],
[[0.5, 0.5], [0., 0.5]]]),
}
cls.foreground_data = foreground_data

@mock.patch("satpy.composites.enhance2dataset", _enhance2dataset)
@pytest.mark.parametrize(
("foreground_bands", "background_bands", "exp_bands", "exp_result"),
[
("L", "L", "L", np.array([[1.0, 0.5], [0.0, 1.0]])),
("LA", "LA", "L", np.array([[1.0, 0.75], [0.5, 1.0]])),
("RGB", "RGB", "RGB", np.array([
("L", "L", "L", np.array([[1., 0.5], [0., 1.]])),
("L", "RGB", "RGB", np.array([
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]]])),
("RGBA", "RGBA", "RGB", np.array([
("LA", "LA", "LA", np.array([
[[1., 0.75], [0.5, 1.]],
[[1., 0.75], [0.5, 1.]],
[[1., 0.75], [0.5, 1.]]])),
("RGBA", "RGB", "RGB", np.array([
[[1., 1.], [1., 1.]]])),
("LA", "RGB", "RGB", np.array([
[[1., 0.75], [0.5, 1.]],
[[1., 0.75], [0.5, 1.]],
[[1., 0.75], [0.5, 1.]]])),
("RGB", "RGB", "RGB", np.array([
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]]])),
("RGB", "LA", "RGBA", np.array([
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]],
[[1., 1.], [1., 1.]]])),
("RGB", "RGBA", "RGBA", np.array([
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]],
[[1., 0.5], [0., 1.]],
[[1., 1.], [1., 1.]]])),
("RGBA", "RGBA", "RGBA", np.array([
[[1., 0.75], [1., 1.]],
[[1., 0.75], [1., 1.]],
[[1., 0.75], [1., 1.]],
[[1., 1.], [1., 1.]]])),
("RGBA", "RGB", "RGB", np.array([
[[1., 0.75], [1., 1.]],
[[1., 0.75], [1., 1.]],
[[1., 0.75], [1., 1.]]])),
]
)
def test_call(self, foreground_bands, background_bands, exp_bands, exp_result):
Expand All @@ -1518,6 +1539,7 @@ def test_call(self, foreground_bands, background_bands, exp_bands, exp_result):

# L mode images
foreground_data = self.foreground_data[foreground_bands]

attrs = {"mode": foreground_bands, "area": "foo"}
foreground = xr.DataArray(da.from_array(foreground_data),
dims=("bands", "y", "x"),
Expand All @@ -1527,7 +1549,9 @@ def test_call(self, foreground_bands, background_bands, exp_bands, exp_result):
background = xr.DataArray(da.ones((len(background_bands), 2, 2)), dims=("bands", "y", "x"),
coords={"bands": [c for c in attrs["mode"]]},
attrs=attrs)

res = comp([foreground, background])

assert res.attrs["area"] == "foo"
np.testing.assert_allclose(res, exp_result)
assert res.attrs["mode"] == exp_bands
Expand Down

0 comments on commit 211cda2

Please sign in to comment.