diff --git a/src/depiction/tools/clustering.py b/src/depiction/tools/clustering.py index d79e010..78e4fe2 100644 --- a/src/depiction/tools/clustering.py +++ b/src/depiction/tools/clustering.py @@ -12,6 +12,7 @@ from depiction.clustering.metrics import cross_correlation from depiction.clustering.stratified_grid import StratifiedGrid from depiction.image.feature_selection import FeatureSelectionIQR, retain_features +from depiction.image.image_normalization import ImageNormalization, ImageNormalizationVariant from depiction.image.multi_channel_image import MultiChannelImage @@ -62,6 +63,9 @@ def clustering( image_full_combined = MultiChannelImage.read_hdf5(path=input_hdf5) assert "cluster" not in image_full_combined.channel_names image_full_features = image_full_combined.drop_channels(coords=["image_index"], allow_missing=True) + image_full_features = ImageNormalization().normalize_image( + image=image_full_features, variant=ImageNormalizationVariant.VEC_NORM + ) image_full_image_index = image_full_combined.retain_channels(coords=["image_index"]) # retain the most relevant features @@ -100,18 +104,11 @@ def clustering( # write the result of the operation output_image = MultiChannelImage( - xarray.concat([label_image.data_spatial, image_full_features.data_spatial], dim="c") + xarray.concat([label_image.data_spatial, image_full_combined.data_spatial], dim="c") ) output_image.write_hdf5(output_hdf5) -# def select_features_cv(image: MultiChannelImage, n_keep: int = 30) -> DataArray: -# image_data = image.data_flat -# cv_score = image_data.std("i") / image_data.mean("i") -# return image.retain_channels(coords=cv_score.sortby("c")[-n_keep:].c.values) -# - - def compute_labels(features: NDArray[float], method: MethodEnum, method_params: str) -> NDArray[int]: if method == MethodEnum.KMEANS: clu = KMeans(n_clusters=10).fit(features)