diff --git a/src/depiction/image/image_normalization.py b/src/depiction/image/image_normalization.py index 77e2f8f..0d66664 100644 --- a/src/depiction/image/image_normalization.py +++ b/src/depiction/image/image_normalization.py @@ -28,7 +28,7 @@ def normalize_xarray(self, image: xarray.DataArray, variant: ImageNormalizationV else: raise NotImplementedError("Multiple index columns are not supported yet.") - def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant): + def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant) -> xarray.DataArray: if variant == ImageNormalizationVariant.VEC_NORM: return image / (((image**2).sum(["c"])) ** 0.5) elif variant == ImageNormalizationVariant.STD: @@ -36,7 +36,9 @@ def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormal else: raise NotImplementedError(f"Unknown variant: {variant}") - def _normalize_multiple_xarray(self, image: xarray.DataArray, index_dim: str, variant: ImageNormalizationVariant): + def _normalize_multiple_xarray( + self, image: xarray.DataArray, index_dim: str, variant: ImageNormalizationVariant + ) -> xarray.DataArray: dataset = image.to_dataset(dim=index_dim) normalized = dataset.map( lambda x: self._normalize_single_xarray(x, variant=variant), diff --git a/src/depiction/image/multi_channel_image.py b/src/depiction/image/multi_channel_image.py index 1e7084f..5299dd2 100644 --- a/src/depiction/image/multi_channel_image.py +++ b/src/depiction/image/multi_channel_image.py @@ -1,14 +1,17 @@ from __future__ import annotations from functools import cached_property -from pathlib import Path -from collections.abc import Sequence import numpy as np import xarray from xarray import DataArray from depiction.image.sparse_representation import SparseRepresentation +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence + from pathlib import Path class MultiChannelImage: diff --git a/src/depiction/image/sparse_representation.py b/src/depiction/image/sparse_representation.py index 6f894eb..3e55c5e 100644 --- a/src/depiction/image/sparse_representation.py +++ b/src/depiction/image/sparse_representation.py @@ -40,6 +40,7 @@ def sparse_to_dense(cls, sparse_values: DataArray, coordinates: DataArray, bg_va for i_channel in range(n_channels): values_grid[tuple(coordinates_shifted.T) + (i_channel,)] = sparse_values[:, i_channel] + # TODO coordinates might come in the wrong order FIXME return DataArray(values_grid, dims=("y", "x", "c")) @classmethod diff --git a/src/depiction/image/xarray_helper.py b/src/depiction/image/xarray_helper.py index 319870a..96c37c1 100644 --- a/src/depiction/image/xarray_helper.py +++ b/src/depiction/image/xarray_helper.py @@ -60,8 +60,8 @@ def apply_on_spatial_view(array: DataArray, fn: Callable[[DataArray], DataArray] # remove nan result = result_flat.dropna("i", how="all").drop_isel(c=-1) - # TODO assigning the coords will be broken in the future, when "i" is a multi-index, however since in general - # it is not, this will require a case distinction + # TODO assigning the coords will be broken in the future, when "i" is a multi-index, + # however since in general it is not, this will require a case distinction return result.assign_coords(i=original_coords) else: raise ValueError(f"Unsupported dims={set(array.dims)}")