Skip to content

Commit

Permalink
deactivate some tricky functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Oct 22, 2024
1 parent a5f4af8 commit 2b295fd
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 72 deletions.
79 changes: 38 additions & 41 deletions src/depiction/image/xarray_helper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# TODO this is not the ideal place, but to avoid code duplication it's better to have a place for now
from __future__ import annotations

from typing import Callable

import xarray
from xarray import DataArray


Expand All @@ -25,49 +22,49 @@ def ensure_dense(cls, values: DataArray, copy: bool = False) -> DataArray:
else:
return values

@staticmethod
def apply_on_spatial_view(array: DataArray, fn: Callable[[DataArray], DataArray]) -> DataArray:
if set(array.dims) == {"y", "x", "c"}:
array = array.transpose("y", "x", "c")
result = fn(array)
return result.transpose("y", "x", "c")
elif set(array.dims) == {"i", "c"}:
# determine if i was indexing ["y", "x"] or ["x", "y"]
index_order = XarrayHelper.get_index_order(array)
# @staticmethod
# def apply_on_spatial_view(array: DataArray, fn: Callable[[DataArray], DataArray]) -> DataArray:
# if set(array.dims) == {"y", "x", "c"}:
# array = array.transpose("y", "x", "c")
# result = fn(array)
# return result.transpose("y", "x", "c")
# elif set(array.dims) == {"i", "c"}:
# # determine if i was indexing ["y", "x"] or ["x", "y"]
# index_order = XarrayHelper.get_index_order(array)

# reset index
original_coords = array.coords["i"]
array = array.reset_index("i")
# # reset index
# original_coords = array.coords["i"]
# array = array.reset_index("i")

# get 2d view
array_flat = array.set_xindex(index_order)
array_2d = array_flat.unstack("i").transpose("y", "x", "c")
# # get 2d view
# array_flat = array.set_xindex(index_order)
# array_2d = array_flat.unstack("i").transpose("y", "x", "c")

# call the function
result = fn(array_2d)
# # call the function
# result = fn(array_2d)

# we only want to drop the background, i.e. no dropping of the values that were nan before or where fn
# returned an all nan array
is_foreground = xarray.ones_like(array_flat.isel(c=[0])).unstack("i").transpose("y", "x", "c")
# # we only want to drop the background, i.e. no dropping of the values that were nan before or where fn
# # returned an all nan array
# is_foreground = xarray.ones_like(array_flat.isel(c=[0])).unstack("i").transpose("y", "x", "c")

# stack this into the result
if "c" in array_flat.coords:
is_foreground = is_foreground.assign_coords(c=["is_foreground"])
result = xarray.concat([result, is_foreground], dim="c")
# # stack this into the result
# if "c" in array_flat.coords:
# is_foreground = is_foreground.assign_coords(c=["is_foreground"])
# result = xarray.concat([result, is_foreground], dim="c")

# make flat again
result_flat = result.stack(i=index_order)
# remove nan
result = result_flat.dropna("i", how="all").drop_isel(c=-1)
# # make flat again
# result_flat = result.stack(i=index_order)
# # remove nan
# result = result_flat.dropna("i", how="all").drop_isel(c=-1)

# TODO assigning the coords will be broken in the future, when "i" is a multi-index,
# however since in general it is not, this will require a case distinction
return result.assign_coords(i=original_coords)
else:
raise ValueError(f"Unsupported dims={set(array.dims)}")
# # TODO assigning the coords will be broken in the future, when "i" is a multi-index,
# # however since in general it is not, this will require a case distinction
# return result.assign_coords(i=original_coords)
# else:
# raise ValueError(f"Unsupported dims={set(array.dims)}")

@staticmethod
def get_index_order(array: DataArray) -> tuple[str, str]:
index_order = tuple(array.coords["i"].coords)
assert index_order in (("i", "y", "x"), ("i", "x", "y")), f"Unexpected index_order={index_order}"
return index_order[1:]
# @staticmethod
# def get_index_order(array: DataArray) -> tuple[str, str]:
# index_order = tuple(array.coords["i"].coords)
# assert index_order in (("i", "y", "x"), ("i", "x", "y")), f"Unexpected index_order={index_order}"
# return index_order[1:]
61 changes: 30 additions & 31 deletions tests/unit/image/test_xarray_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import pytest
import sparse
import xarray.testing
from xarray import DataArray

from depiction.image.xarray_helper import XarrayHelper
Expand Down Expand Up @@ -81,36 +80,36 @@ def array_flat_transposed(array_spatial) -> DataArray:
return array_spatial.stack(i=("y", "x")).dropna("i", how="all")


def test_apply_on_spatial_view_array_spatial(array_spatial) -> None:
result = XarrayHelper.apply_on_spatial_view(array_spatial, dummy_function)
xarray.testing.assert_equal(result, array_spatial * 2)


def test_apply_on_spatial_view_array_flat_no_nan(array_flat) -> None:
array_flat = array_flat.fillna(0)
result = XarrayHelper.apply_on_spatial_view(array_flat, dummy_function)
xarray.testing.assert_equal(result, array_flat * 2)


def test_apply_on_spatial_view_array_flat_with_nan(array_flat) -> None:
result = XarrayHelper.apply_on_spatial_view(array_flat, dummy_function)
xarray.testing.assert_equal(result, array_flat * 2)


def test_apply_on_spatial_view_array_flat_with_nan_before(array_flat_with_nan_before) -> None:
result = XarrayHelper.apply_on_spatial_view(array_flat_with_nan_before, dummy_function)
xarray.testing.assert_equal(result, array_flat_with_nan_before * 2)


def test_apply_on_spatial_view_array_flat_no_nan_transposed(array_flat_transposed) -> None:
array_flat_transposed = array_flat_transposed.fillna(0)
result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function)
xarray.testing.assert_equal(result, array_flat_transposed * 2)


def test_apply_on_spatial_view_array_flat_with_nan_transposed(array_flat_transposed) -> None:
result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function)
xarray.testing.assert_equal(result, array_flat_transposed * 2)
# def test_apply_on_spatial_view_array_spatial(array_spatial) -> None:
# result = XarrayHelper.apply_on_spatial_view(array_spatial, dummy_function)
# xarray.testing.assert_equal(result, array_spatial * 2)
#
#
# def test_apply_on_spatial_view_array_flat_no_nan(array_flat) -> None:
# array_flat = array_flat.fillna(0)
# result = XarrayHelper.apply_on_spatial_view(array_flat, dummy_function)
# xarray.testing.assert_equal(result, array_flat * 2)
#
#
# def test_apply_on_spatial_view_array_flat_with_nan(array_flat) -> None:
# result = XarrayHelper.apply_on_spatial_view(array_flat, dummy_function)
# xarray.testing.assert_equal(result, array_flat * 2)
#
#
# def test_apply_on_spatial_view_array_flat_with_nan_before(array_flat_with_nan_before) -> None:
# result = XarrayHelper.apply_on_spatial_view(array_flat_with_nan_before, dummy_function)
# xarray.testing.assert_equal(result, array_flat_with_nan_before * 2)
#
#
# def test_apply_on_spatial_view_array_flat_no_nan_transposed(array_flat_transposed) -> None:
# array_flat_transposed = array_flat_transposed.fillna(0)
# result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function)
# xarray.testing.assert_equal(result, array_flat_transposed * 2)
#
#
# def test_apply_on_spatial_view_array_flat_with_nan_transposed(array_flat_transposed) -> None:
# result = XarrayHelper.apply_on_spatial_view(array_flat_transposed, dummy_function)
# xarray.testing.assert_equal(result, array_flat_transposed * 2)


if __name__ == "__main__":
Expand Down

0 comments on commit 2b295fd

Please sign in to comment.