Skip to content

Commit

Permalink
normalize features before clustering
Browse files (browse the repository at this point in the history)
  • Loading branch information
leoschwarz committed Aug 7, 2024
1 parent f83ba1c commit 991b35d
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/depiction/tools/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from depiction.clustering.metrics import cross_correlation
from depiction.clustering.stratified_grid import StratifiedGrid
from depiction.image.feature_selection import FeatureSelectionIQR, retain_features
from depiction.image.image_normalization import ImageNormalization, ImageNormalizationVariant
from depiction.image.multi_channel_image import MultiChannelImage


Expand Down Expand Up @@ -62,6 +63,9 @@ def clustering(
image_full_combined = MultiChannelImage.read_hdf5(path=input_hdf5)
assert "cluster" not in image_full_combined.channel_names
image_full_features = image_full_combined.drop_channels(coords=["image_index"], allow_missing=True)
image_full_features = ImageNormalization().normalize_image(
image=image_full_features, variant=ImageNormalizationVariant.VEC_NORM
)
image_full_image_index = image_full_combined.retain_channels(coords=["image_index"])

# retain the most relevant features
Expand Down Expand Up @@ -100,18 +104,11 @@ def clustering(

# write the result of the operation
output_image = MultiChannelImage(
- xarray.concat([label_image.data_spatial, image_full_features.data_spatial], dim="c")
+ xarray.concat([label_image.data_spatial, image_full_combined.data_spatial], dim="c")
)
output_image.write_hdf5(output_hdf5)


# def select_features_cv(image: MultiChannelImage, n_keep: int = 30) -> DataArray:
# image_data = image.data_flat
# cv_score = image_data.std("i") / image_data.mean("i")
# return image.retain_channels(coords=cv_score.sortby("c")[-n_keep:].c.values)
#


def compute_labels(features: NDArray[float], method: MethodEnum, method_params: str) -> NDArray[int]:
if method == MethodEnum.KMEANS:
clu = KMeans(n_clusters=10).fit(features)
Expand Down

0 comments on commit 991b35d

Please sign in to comment.