diff --git a/src/depiction/clustering/maxmin_sampling.py b/src/depiction/clustering/maxmin_sampling.py index e4d2227..9180265 100644 --- a/src/depiction/clustering/maxmin_sampling.py +++ b/src/depiction/clustering/maxmin_sampling.py @@ -4,7 +4,7 @@ from numpy.random import Generator -def maxmin_sampling(vectors: np.ndarray, k: int, rng: Generator) -> np.ndarray: +def maxmin_sampling(vectors: np.ndarray, k: int, rng: Generator, metric="euclidean") -> np.ndarray: """ Sample k diverse vectors from the given set of vectors using the MaxMin algorithm. @@ -35,7 +35,17 @@ def maxmin_sampling(vectors: np.ndarray, k: int, rng: Generator) -> np.ndarray: selected = [rng.integers(n)] # Compute distances to the selected point - distances = np.sum((vectors - vectors[selected[0]]) ** 2, axis=1) + # TODO this can be done more nicely with a scipy method later + if metric == "euclidean": + distances = np.sum((vectors - vectors[selected[0]]) ** 2, axis=1) + elif metric == "cosine": + distances = 1 - np.dot(vectors, vectors[selected[0]]) / ( + np.linalg.norm(vectors, axis=1) * np.linalg.norm(vectors[selected[0]]) + ) + elif metric == "correlation": + distances = 1 - np.dot(vectors, vectors[selected[0]]) / n + else: + raise ValueError(f"Unsupported metric: {metric}") for _ in range(1, k): # Select the point with maximum distance to the already selected points diff --git a/src/depiction/tools/clustering.py b/src/depiction/tools/clustering.py index 8be4291..93cdbad 100644 --- a/src/depiction/tools/clustering.py +++ b/src/depiction/tools/clustering.py @@ -26,7 +26,11 @@ class MethodEnum(Enum): def get_landmark_indices( - image_features: MultiChannelImage, image_index: MultiChannelImage, n_landmarks: int, rng: np.random.Generator + image_features: MultiChannelImage, + image_index: MultiChannelImage, + n_landmarks: int, + rng: np.random.Generator, + metric: str, ) -> NDArray[int]: image_joined = image_features.append_channels(image_index) image_joined_flat = image_joined.data_flat @@ -41,7 +45,7 @@ def get_landmark_indices( # determine the landmark indices features_image = image_joined_flat.drop_sel(c="image_index").isel(i=indices_image) - indices_image_landmarks = maxmin_sampling(features_image.values.T, k=n_samples, rng=rng) + indices_image_landmarks = maxmin_sampling(features_image.values.T, k=n_samples, rng=rng, metric=metric) # revert these indices into the original space indices.extend(indices_image[indices_image_landmarks]) @@ -57,6 +61,7 @@ def clustering( n_best_features: int = 30, n_samples_cluster: int = 10000, n_landmarks: int = 200, + landmark_metric: str = "correlation", ) -> None: rng = np.random.default_rng(42) @@ -65,7 +70,7 @@ def clustering( assert "cluster" not in image_full_combined.channel_names image_full_features = image_full_combined.drop_channels(coords=["image_index"], allow_missing=True) image_full_features = ImageNormalization().normalize_image( - image=image_full_features, variant=ImageNormalizationVariant.VEC_NORM + image=image_full_features, variant=ImageNormalizationVariant.STD ) image_full_image_index = image_full_combined.retain_channels(coords=["image_index"]) @@ -77,7 +82,11 @@ def clustering( # sample a number of landmark features which will be used for correlation-based clustering # since we might have more than one image, we want to make sure that we sample a bit of each landmark_indices = get_landmark_indices( - image_features=image_features, image_index=image_full_image_index, n_landmarks=n_landmarks, rng=rng + image_features=image_features, + image_index=image_full_image_index, + n_landmarks=n_landmarks, + rng=rng, + metric=landmark_metric, ) landmark_features = image_features.data_flat.values.T[landmark_indices]