Skip to content

Commit

Permalink
code style: image module
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jul 1, 2024
1 parent 86c1b82 commit 1755073
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
6 changes: 4 additions & 2 deletions src/depiction/image/image_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ def normalize_xarray(self, image: xarray.DataArray, variant: ImageNormalizationV
else:
raise NotImplementedError("Multiple index columns are not supported yet.")

def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant):
def _normalize_single_xarray(self, image: xarray.DataArray, variant: ImageNormalizationVariant) -> xarray.DataArray:
if variant == ImageNormalizationVariant.VEC_NORM:
return image / (((image**2).sum(["c"])) ** 0.5)
elif variant == ImageNormalizationVariant.STD:
return (image - image.mean("c")) / image.std("c")
else:
raise NotImplementedError(f"Unknown variant: {variant}")

def _normalize_multiple_xarray(self, image: xarray.DataArray, index_dim: str, variant: ImageNormalizationVariant):
def _normalize_multiple_xarray(
self, image: xarray.DataArray, index_dim: str, variant: ImageNormalizationVariant
) -> xarray.DataArray:
dataset = image.to_dataset(dim=index_dim)
normalized = dataset.map(
lambda x: self._normalize_single_xarray(x, variant=variant),
Expand Down
7 changes: 5 additions & 2 deletions src/depiction/image/multi_channel_image.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations

from functools import cached_property
from pathlib import Path
from collections.abc import Sequence

import numpy as np
import xarray
from xarray import DataArray

from depiction.image.sparse_representation import SparseRepresentation
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path


class MultiChannelImage:
Expand Down
1 change: 1 addition & 0 deletions src/depiction/image/sparse_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def sparse_to_dense(cls, sparse_values: DataArray, coordinates: DataArray, bg_va
for i_channel in range(n_channels):
values_grid[tuple(coordinates_shifted.T) + (i_channel,)] = sparse_values[:, i_channel]

# TODO coordinates might come in the wrong order FIXME
return DataArray(values_grid, dims=("y", "x", "c"))

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions src/depiction/image/xarray_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def apply_on_spatial_view(array: DataArray, fn: Callable[[DataArray], DataArray]
# remove nan
result = result_flat.dropna("i", how="all").drop_isel(c=-1)

# TODO assigning the coords will be broken in the future, when "i" is a multi-index, however since in general
# it is not, this will require a case distinction
# TODO assigning the coords will be broken in the future, when "i" is a multi-index,
# however since in general it is not, this will require a case distinction
return result.assign_coords(i=original_coords)
else:
raise ValueError(f"Unsupported dims={set(array.dims)}")
Expand Down

0 comments on commit 1755073

Please sign in to comment.