diff --git a/src/depiction/image/xarray_helper.py b/src/depiction/image/xarray_helper.py index 96c37c1..ee0f01d 100644 --- a/src/depiction/image/xarray_helper.py +++ b/src/depiction/image/xarray_helper.py @@ -1,9 +1,6 @@ # TODO this is not the ideal place, but to avoid code duplication it's better to have a place for now from __future__ import annotations -from typing import Callable - -import xarray from xarray import DataArray @@ -25,49 +22,49 @@ def ensure_dense(cls, values: DataArray, copy: bool = False) -> DataArray: else: return values - @staticmethod - def apply_on_spatial_view(array: DataArray, fn: Callable[[DataArray], DataArray]) -> DataArray: - if set(array.dims) == {"y", "x", "c"}: - array = array.transpose("y", "x", "c") - result = fn(array) - return result.transpose("y", "x", "c") - elif set(array.dims) == {"i", "c"}: - # determine if i was indexing ["y", "x"] or ["x", "y"] - index_order = XarrayHelper.get_index_order(array) + # @staticmethod + # def apply_on_spatial_view(array: DataArray, fn: Callable[[DataArray], DataArray]) -> DataArray: + # if set(array.dims) == {"y", "x", "c"}: + # array = array.transpose("y", "x", "c") + # result = fn(array) + # return result.transpose("y", "x", "c") + # elif set(array.dims) == {"i", "c"}: + # # determine if i was indexing ["y", "x"] or ["x", "y"] + # index_order = XarrayHelper.get_index_order(array) - # reset index - original_coords = array.coords["i"] - array = array.reset_index("i") + # # reset index + # original_coords = array.coords["i"] + # array = array.reset_index("i") - # get 2d view - array_flat = array.set_xindex(index_order) - array_2d = array_flat.unstack("i").transpose("y", "x", "c") + # # get 2d view + # array_flat = array.set_xindex(index_order) + # array_2d = array_flat.unstack("i").transpose("y", "x", "c") - # call the function - result = fn(array_2d) + # # call the function + # result = fn(array_2d) - # we only want to drop the background, i.e. no dropping of the values that were nan before or where fn - # returned an all nan array - is_foreground = xarray.ones_like(array_flat.isel(c=[0])).unstack("i").transpose("y", "x", "c") + # # we only want to drop the background, i.e. no dropping of the values that were nan before or where fn + # # returned an all nan array + # is_foreground = xarray.ones_like(array_flat.isel(c=[0])).unstack("i").transpose("y", "x", "c") - # stack this into the result - if "c" in array_flat.coords: - is_foreground = is_foreground.assign_coords(c=["is_foreground"]) - result = xarray.concat([result, is_foreground], dim="c") + # # stack this into the result + # if "c" in array_flat.coords: + # is_foreground = is_foreground.assign_coords(c=["is_foreground"]) + # result = xarray.concat([result, is_foreground], dim="c") - # make flat again - result_flat = result.stack(i=index_order) - # remove nan - result = result_flat.dropna("i", how="all").drop_isel(c=-1) + # # make flat again + # result_flat = result.stack(i=index_order) + # # remove nan + # result = result_flat.dropna("i", how="all").drop_isel(c=-1) - # TODO assigning the coords will be broken in the future, when "i" is a multi-index, - # however since in general it is not, this will require a case distinction - return result.assign_coords(i=original_coords) - else: - raise ValueError(f"Unsupported dims={set(array.dims)}") + # # TODO assigning the coords will be broken in the future, when "i" is a multi-index, + # # however since in general it is not, this will require a case distinction + # return result.assign_coords(i=original_coords) + # else: + # raise ValueError(f"Unsupported dims={set(array.dims)}") - @staticmethod - def get_index_order(array: DataArray) -> tuple[str, str]: - index_order = tuple(array.coords["i"].coords) - assert index_order in (("i", "y", "x"), ("i", "x", "y")), f"Unexpected index_order={index_order}" - return index_order[1:] + # @staticmethod + # def get_index_order(array: DataArray) -> tuple[str, str]: + # index_order = tuple(array.coords["i"].coords) + # assert index_order in (("i", "y", "x"), ("i", "x", "y")), f"Unexpected index_order={index_order}" + # return index_order[1:] diff --git a/tests/unit/image/test_xarray_helper.py b/tests/unit/image/test_xarray_helper.py index b3c989d..8684dff 100644 --- a/tests/unit/image/test_xarray_helper.py +++ b/tests/unit/image/test_xarray_helper.py @@ -1,7 +1,6 @@ import numpy as np import pytest import sparse -import xarray.testing from xarray import DataArray from depiction.image.xarray_helper import XarrayHelper @@ -81,36 +80,36 @@ def array_flat_transposed(array_spatial) -> DataArray: return array_spatial.stack(i=("y", "x")).dropna("i", how="all") -def test_apply_on_spatial_view_array_spatial(array_spatial) -> None: - result = XarrayHelper.apply_on_spatial_view(array_spatial, dummy_function) - xarray.testing.assert_equal(result, array_spatial * 2) - - -def test_apply_on_spatial_view_array_flat_no_nan(array_flat) -> None: - array_flat = array_flat.fillna(0) - result = XarrayHelper.apply_on_spatial_view(array_flat, dummy_function) - xarray.testing.assert_equal(result, array_flat * 2) - - -def test_apply_on_spatial_view_array_flat_with_nan(array_flat) -> None: - result = XarrayHelper.apply_on_spatial_view(array_flat, dummy_function) - xarray.testing.assert_equal(result, array_flat * 2) - - -def test_apply_on_spatial_view_array_flat_with_nan_before(array_flat_with_nan_before) -> None: - result = XarrayHelper.apply_on_spatial_view(array_flat_with_nan_before, dummy_function) - xarray.testing.assert_equal(result, array_flat_with_nan_before * 2) - - -def test_apply_on_spatial_view_array_flat_no_nan_transposed(array_flat_transposed) -> None: - array_flat_transposed = array_flat_transposed.fillna(0) - result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function) - xarray.testing.assert_equal(result, array_flat_transposed * 2) - - -def test_apply_on_spatial_view_array_flat_with_nan_transposed(array_flat_transposed) -> None: - result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function) - xarray.testing.assert_equal(result, array_flat_transposed * 2) +# def test_apply_on_spatial_view_array_spatial(array_spatial) -> None: +# result = XarrayHelper.apply_on_spatial_view(array_spatial, dummy_function) +# xarray.testing.assert_equal(result, array_spatial * 2) +# +# +# def test_apply_on_spatial_view_array_flat_no_nan(array_flat) -> None: +# array_flat = array_flat.fillna(0) +# result = XarrayHelper.apply_on_spatial_view(array_flat, dummy_function) +# xarray.testing.assert_equal(result, array_flat * 2) +# +# +# def test_apply_on_spatial_view_array_flat_with_nan(array_flat) -> None: +# result = XarrayHelper.apply_on_spatial_view(array_flat, dummy_function) +# xarray.testing.assert_equal(result, array_flat * 2) +# +# +# def test_apply_on_spatial_view_array_flat_with_nan_before(array_flat_with_nan_before) -> None: +# result = XarrayHelper.apply_on_spatial_view(array_flat_with_nan_before, dummy_function) +# xarray.testing.assert_equal(result, array_flat_with_nan_before * 2) +# +# +# def test_apply_on_spatial_view_array_flat_no_nan_transposed(array_flat_transposed) -> None: +# array_flat_transposed = array_flat_transposed.fillna(0) +# result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function) +# xarray.testing.assert_equal(result, array_flat_transposed * 2) +# +# +# def test_apply_on_spatial_view_array_flat_with_nan_transposed(array_flat_transposed) -> None: +# result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function) +# xarray.testing.assert_equal(result, array_flat_transposed * 2) if __name__ == "__main__":