Skip to content

Commit

Permalink
make more useful
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Aug 7, 2024
1 parent ddd837e commit 2547194
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
15 changes: 10 additions & 5 deletions src/depiction/image/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@ class FeatureSelectionIQR(BaseModel):
FeatureSelection = Annotated[Union[FeatureSelectionCV, FeatureSelectionIQR], Field(discriminator="method")]


def select_features(feature_selection: FeatureSelection, image: MultiChannelImage) -> MultiChannelImage:
"""Returns a new ``MultiChannelImage`` that is a copy of ``image`` with only the selected features remaining."""
def select_features(feature_selection: FeatureSelection, image: MultiChannelImage) -> list[str]:
"""Returns the selected features based on the provided feature selection method."""
match feature_selection:
case FeatureSelectionCV(n_features=n_features):
features = _select_features_cv(image=image, n_features=n_features)
return _select_features_cv(image=image, n_features=n_features)
case FeatureSelectionIQR(n_features=n_features):
features = _select_features_iqr(image=image, n_features=n_features)
return _select_features_iqr(image=image, n_features=n_features)
case _:
raise ValueError("Invalid feature selection method.")
return image.retain_channels(coords=features)


def retain_features(feature_selection: FeatureSelection, image: MultiChannelImage) -> MultiChannelImage:
"""Returns a new ``MultiChannelImage`` that is a copy of ``image`` with only the selected features remaining."""
selected_features = select_features(feature_selection=feature_selection, image=image)
return image.retain_channels(coords=selected_features)


def _select_features_cv(image: MultiChannelImage, n_features: int) -> list[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from loguru import logger
from umap import UMAP

from depiction.image.feature_selection import FeatureSelectionIQR, select_features
from depiction.image.feature_selection import FeatureSelectionIQR, retain_features
from depiction.image.multi_channel_image import MultiChannelImage
from depiction.image.multi_channel_image_concatenation import MultiChannelImageConcatenation

Expand All @@ -31,7 +31,7 @@ def compute_image_umap_coefs(
input_image = input_image_conc.get_combined_image()
if enable_feature_selection:
logger.info(f"Feature selection requested: {feature_selection}")
input_image = select_features(feature_selection=feature_selection, image=input_image)
input_image = retain_features(feature_selection=feature_selection, image=input_image)

# compute the umap transformation into 2D
logger.info(f"Computing UMAP for input image with shape {input_image.dimensions}")
Expand Down

0 comments on commit 2547194

Please sign in to comment.