Skip to content

Commit

Permalink
fix more indexing problems
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jun 20, 2024
1 parent fd86ee0 commit b1ef78d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
25 changes: 10 additions & 15 deletions src/depiction/image/xarray_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from typing import Callable

import numpy as np
import xarray
from xarray import DataArray

Expand Down Expand Up @@ -48,26 +47,22 @@ def apply_on_spatial_view(array: DataArray, fn: Callable[[DataArray], DataArray]
# call the function
result = fn(array_2d)

# keep only the elements that were present before
array_present_before = (
array_flat.notnull()
.reduce(np.all, "c", keepdims=True)
.unstack("i")
.transpose("y", "x", "c")
.astype(float)
.where(lambda x: x)
)
# we only want to drop the background, i.e. no dropping of the values that were nan before or where fn
# returned an all nan array
is_foreground = xarray.ones_like(array_flat.isel(c=[0])).unstack("i").transpose("y", "x", "c")

# stack this into the result
if "c" in array_flat.coords:
# TODO test case distinction
array_present_before = array_present_before.assign_coords(c=["nan_before"])
result = xarray.concat([result, array_present_before], dim="c")
is_foreground = is_foreground.assign_coords(c=["is_foreground"])
result = xarray.concat([result, is_foreground], dim="c")

# make flat again
result_flat = result.stack(i=index_order)
# remove nan
result = result_flat.drop_isel(c=-1).dropna("i", how="all")
result = result_flat.dropna("i", how="all").drop_isel(c=-1)

# TODO assigning the coords will be broken in the future, when "i" is a multi-index, however since in general
# it is not, this will require a case distinciton
# it is not, this will require a case distinction
return result.assign_coords(i=original_coords)
else:
raise ValueError(f"Unsupported dims={set(array.dims)}")
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/image/test_xarray_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def array_flat(array_spatial) -> DataArray:
return array_spatial.stack(i=("x", "y")).dropna("i", how="all")


@pytest.fixture
def array_flat_with_nan_before(array_spatial) -> DataArray:
return array_spatial.stack(i=("x", "y"))


@pytest.fixture
def array_flat_transposed(array_spatial) -> DataArray:
return array_spatial.stack(i=("y", "x")).dropna("i", how="all")
Expand All @@ -92,6 +97,11 @@ def test_apply_on_spatial_view_array_flat_with_nan(array_flat) -> None:
xarray.testing.assert_equal(result, array_flat * 2)


def test_apply_on_spatial_view_array_flat_with_nan_before(array_flat_with_nan_before) -> None:
result = XarrayHelper.apply_on_spatial_view(array_flat_with_nan_before, dummy_function)
xarray.testing.assert_equal(result, array_flat_with_nan_before * 2)


def test_apply_on_spatial_view_array_flat_no_nan_transposed(array_flat_transposed) -> None:
array_flat_transposed = array_flat_transposed.fillna(0)
result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function)
Expand Down

0 comments on commit b1ef78d

Please sign in to comment.