new sampling strategy
leoschwarz committed Aug 7, 2024
1 parent 2547194 commit 5d9447f
Showing 1 changed file with 44 additions and 20 deletions.
64 changes: 44 additions & 20 deletions src/depiction/tools/clustering.py
@@ -4,15 +4,15 @@
import cyclopts
import numpy as np
import xarray
from loguru import logger
from numpy.typing import NDArray
from sklearn.cluster import KMeans, BisectingKMeans

from depiction.clustering.extrapolate import extrapolate_labels
from depiction.clustering.maxmin_sampling import maxmin_sampling
from depiction.clustering.metrics import cross_correlation
from depiction.clustering.stratified_grid import StratifiedGrid
from depiction.image.feature_selection import FeatureSelectionIQR, retain_features
from depiction.image.multi_channel_image import MultiChannelImage
from numpy.typing import NDArray
from sklearn.cluster import KMeans, BisectingKMeans


class MethodEnum(Enum):
@@ -23,13 +23,27 @@ class MethodEnum(Enum):
app = cyclopts.App()


def retain_n_best_features(image_full: MultiChannelImage, n_best_features: int) -> MultiChannelImage:
    strongest_channels = image_full.channel_stats.interquartile_range.drop_nulls().sort("iqr").tail(n_best_features)
    logger.info(f"Retaining {n_best_features} best features: {strongest_channels['c'].to_numpy()}")
    return image_full.retain_channels(coords=strongest_channels["c"].to_numpy())
def get_landmark_indices(
    image_features: MultiChannelImage, image_index: MultiChannelImage, n_landmarks: int, rng: np.random.Generator
) -> NDArray[int]:
    image_joined = image_features.append_channels(image_index)
    image_joined_flat = image_joined.data_flat
    n_images = int(image_index.data_flat.values.max() + 1)
    indices = []
    for i_image in range(n_images):
        # determine the indices corresponding to i_image in the flat representation
        indices_image = np.where(image_joined_flat.sel(c="image_index").values == i_image)[0]

        # number of landmarks to retrieve for this particular image
        n_samples = n_landmarks // n_images if i_image != 0 else n_landmarks // n_images + n_landmarks % n_images

        # determine the landmark indices
        features_image = image_joined_flat.drop_sel(c="image_index").isel(i=indices_image)
        indices_image_landmarks = maxmin_sampling(features_image.values.T, k=n_samples, rng=rng)

        # TODO the feature selection should also be studied on a per-image basis to identify potential relevant differences
        # revert these indices into the original space
        indices.extend(indices_image[indices_image_landmarks])
    return np.asarray(indices)
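
As an aside, the allocation rule above splits the landmark budget evenly across images, gives any remainder to the first image, and then picks well-spread spectra within each image by a max-min criterion. Below is a minimal, self-contained sketch of that idea; the greedy farthest-point sampler is a generic stand-in for depiction's maxmin_sampling, whose internals are not shown in this diff, and the helper names are hypothetical:

import numpy as np

def landmarks_per_image(n_landmarks: int, n_images: int) -> list[int]:
    # Even split of the landmark budget; the first image absorbs the remainder.
    base, rest = divmod(n_landmarks, n_images)
    return [base + rest if i == 0 else base for i in range(n_images)]

def farthest_point_sample(features: np.ndarray, k: int, rng: np.random.Generator) -> np.ndarray:
    # Greedy max-min sampling: start from a random spectrum, then repeatedly add the
    # spectrum farthest (in Euclidean distance) from everything chosen so far.
    assert k <= features.shape[0]
    chosen = [int(rng.integers(features.shape[0]))]
    dist = np.linalg.norm(features - features[chosen[0]], axis=1)
    for _ in range(k - 1):
        nxt = int(np.argmax(dist))
        chosen.append(nxt)
        dist = np.minimum(dist, np.linalg.norm(features - features[nxt], axis=1))
    return np.asarray(chosen)

For example, landmarks_per_image(50, 3) yields [18, 16, 16], matching the n_landmarks // n_images (plus remainder for image 0) rule in the function above.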


@app.default()
@@ -42,23 +56,31 @@ def clustering(
    n_samples_cluster: int = 10000,
    n_landmarks: int = 50,
) -> None:
    image_full = MultiChannelImage.read_hdf5(path=input_hdf5)
    image_full = image_full.drop_channels(coords=["image_index", "cluster"], allow_missing=True)
    rng = np.random.default_rng(42)

    # retain only the most informative features for the clustering
    image = retain_n_best_features(image_full, n_best_features=n_best_features)
    # read the input image
    image_full_combined = MultiChannelImage.read_hdf5(path=input_hdf5)
    assert "cluster" not in image_full_combined.channel_names
    image_full_features = image_full_combined.drop_channels(coords=["image_index"], allow_missing=True)
    image_full_image_index = image_full_combined.retain_channels(coords=["image_index"])

    # retain the most relevant features
    image_features = retain_features(
        feature_selection=FeatureSelectionIQR.validate({"n_features": n_best_features}), image=image_full_features
    )

    # sample a number of landmark features which will be used for correlation-based clustering
    # TODO it is possible that the landmarks are all from the same image in the concatenated case,
    # which needs to be addressed somehow...
    landmark_indices = maxmin_sampling(image.data_flat.values.T, k=n_landmarks, rng=rng)
    landmark_features = image.data_flat.values.T[landmark_indices]
    # since we might have more than one image, we want to make sure that we sample a bit of each
    landmark_indices = get_landmark_indices(
        image_features=image_features, image_index=image_full_image_index, n_landmarks=n_landmarks, rng=rng
    )

    landmark_features = image_features.data_flat.values.T[landmark_indices]

    # sample a large number of samples to cluster against the full image
    # TODO this could be improved a bit by making sure that the landmarks are never sampled here again
    grid = StratifiedGrid(cells_x=20, cells_y=20)
    sampled_features = grid.sample_points(array=image.data_flat, n_samples=n_samples_cluster, rng=rng).values.T
    sampled_features = grid.sample_points(array=image_features.data_flat, n_samples=n_samples_cluster, rng=rng).values.T

    # compute pairwise correlation between landmark features and sampled features
    correlation_features = cross_correlation(sampled_features, landmark_features)
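
For orientation, this step can be read as: represent each sampled spectrum by its correlation with the landmark spectra, cluster those correlation vectors, and later extrapolate labels to every pixel. A rough sketch under that reading follows; the helper below is hypothetical and only approximates cross_correlation / extrapolate_labels with plain numpy and scikit-learn, it is not the library's implementation:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors

def landmark_correlation_clustering(
    sampled: np.ndarray,    # (n_samples, n_features) sampled spectra
    landmarks: np.ndarray,  # (n_landmarks, n_features) landmark spectra
    full: np.ndarray,       # (n_pixels, n_features) all spectra
    n_clusters: int,
    seed: int = 42,
) -> np.ndarray:
    # Pearson correlation of every sampled spectrum with every landmark spectrum.
    corr = np.corrcoef(sampled, landmarks)[: len(sampled), len(sampled):]
    labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit_predict(corr)
    # Extrapolate by assigning each pixel the label of its nearest sampled spectrum.
    _, idx = NearestNeighbors(n_neighbors=1).fit(sampled).kneighbors(full)
    return labels[idx[:, 0]]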
@@ -70,14 +92,16 @@ full_labels = extrapolate_labels(
    full_labels = extrapolate_labels(
        sampled_features=sampled_features,
        sampled_labels=sampled_labels,
        full_features=image.data_flat.values.T,
        full_features=image_features.data_flat.values.T,
    )
    label_image = MultiChannelImage.from_sparse(
        values=full_labels[:, np.newaxis], coordinates=image_full.coordinates_flat, channel_names=["cluster"]
        values=full_labels[:, np.newaxis], coordinates=image_full_features.coordinates_flat, channel_names=["cluster"]
    )

    # write the result of the operation
    output_image = MultiChannelImage(xarray.concat([label_image.data_spatial, image_full.data_spatial], dim="c"))
    output_image = MultiChannelImage(
        xarray.concat([label_image.data_spatial, image_full_features.data_spatial], dim="c")
    )
    output_image.write_hdf5(output_hdf5)
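
The 20x20 StratifiedGrid used earlier is what keeps the n_samples_cluster pixels spatially spread out: roughly, the image plane is divided into grid cells and points are drawn per cell rather than uniformly at random. A conceptual, numpy-only sketch of such stratified sampling (an assumption about the general technique, not StratifiedGrid's actual code):

import numpy as np

def stratified_grid_indices(coords: np.ndarray, n_samples: int, cells: int, rng: np.random.Generator) -> np.ndarray:
    # Bin (x, y) pixel coordinates into a cells x cells grid and draw a roughly equal
    # number of pixels from every occupied cell, so no single region dominates the sample.
    x_edges = np.linspace(coords[:, 0].min(), coords[:, 0].max(), cells + 1)[1:-1]
    y_edges = np.linspace(coords[:, 1].min(), coords[:, 1].max(), cells + 1)[1:-1]
    cell_id = np.digitize(coords[:, 0], x_edges) * cells + np.digitize(coords[:, 1], y_edges)
    occupied = np.unique(cell_id)
    per_cell = max(1, n_samples // len(occupied))
    chosen: list[int] = []
    for cid in occupied:
        members = np.flatnonzero(cell_id == cid)
        take = min(per_cell, len(members))
        chosen.extend(rng.choice(members, size=take, replace=False).tolist())
    return np.asarray(chosen)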

