diff --git a/src/depiction/image/feature_selection.py b/src/depiction/image/feature_selection.py index 50d39b7..661190a 100644 --- a/src/depiction/image/feature_selection.py +++ b/src/depiction/image/feature_selection.py @@ -17,16 +17,21 @@ class FeatureSelectionIQR(BaseModel): FeatureSelection = Annotated[Union[FeatureSelectionCV, FeatureSelectionIQR], Field(discriminator="method")] -def select_features(feature_selection: FeatureSelection, image: MultiChannelImage) -> MultiChannelImage: - """Returns a new ``MultiChannelImage`` that is a copy of ``image`` with only the selected features remaining.""" +def select_features(feature_selection: FeatureSelection, image: MultiChannelImage) -> list[str]: + """Returns the selected features based on the provided feature selection method.""" match feature_selection: case FeatureSelectionCV(n_features=n_features): - features = _select_features_cv(image=image, n_features=n_features) + return _select_features_cv(image=image, n_features=n_features) case FeatureSelectionIQR(n_features=n_features): - features = _select_features_iqr(image=image, n_features=n_features) + return _select_features_iqr(image=image, n_features=n_features) case _: raise ValueError("Invalid feature selection method.") - return image.retain_channels(coords=features) + + +def retain_features(feature_selection: FeatureSelection, image: MultiChannelImage) -> MultiChannelImage: + """Returns a new ``MultiChannelImage`` that is a copy of ``image`` with only the selected features remaining.""" + selected_features = select_features(feature_selection=feature_selection, image=image) + return image.retain_channels(coords=selected_features) def _select_features_cv(image: MultiChannelImage, n_features: int) -> list[str]: diff --git a/src/depiction_cluster_sandbox/workflow/proc/compute_image_umap_coefs.py b/src/depiction_cluster_sandbox/workflow/proc/compute_image_umap_coefs.py index 9ff32e5..2df9c11 100644 --- a/src/depiction_cluster_sandbox/workflow/proc/compute_image_umap_coefs.py +++ b/src/depiction_cluster_sandbox/workflow/proc/compute_image_umap_coefs.py @@ -4,7 +4,7 @@ from loguru import logger from umap import UMAP -from depiction.image.feature_selection import FeatureSelectionIQR, select_features +from depiction.image.feature_selection import FeatureSelectionIQR, retain_features from depiction.image.multi_channel_image import MultiChannelImage from depiction.image.multi_channel_image_concatenation import MultiChannelImageConcatenation @@ -31,7 +31,7 @@ def compute_image_umap_coefs( input_image = input_image_conc.get_combined_image() if enable_feature_selection: logger.info(f"Feature selection requested: {feature_selection}") - input_image = select_features(feature_selection=feature_selection, image=input_image) + input_image = retain_features(feature_selection=feature_selection, image=input_image) # compute the umap transformation into 2D logger.info(f"Computing UMAP for input image with shape {input_image.dimensions}")