-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add initial clustering.extrapolate functionality
- Loading branch information
1 parent
1bd7c54
commit 1508721
Showing
2 changed files
with
189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import numpy as np | ||
from numpy.typing import NDArray | ||
|
||
|
||
def extrapolate_labels( | ||
sampled_features: NDArray[float], sampled_labels: NDArray[int], full_features: NDArray[float] | ||
) -> NDArray[int]: | ||
"""Extrapolates cluster labels for a number of sampled features to the full set of features.""" | ||
if sampled_features.shape[1] != full_features.shape[1]: | ||
raise ValueError( | ||
f"Number of features must be the same in sampled_features ({sampled_features.shape[1]}) and full_features ({full_features.shape[1]})" | ||
) | ||
|
||
n_full_samples = full_features.shape[0] | ||
cluster_centers = get_cluster_centers(features=sampled_features, labels=sampled_labels) | ||
full_labels = np.zeros(n_full_samples, dtype=int) | ||
for i_sample in range(n_full_samples): | ||
distances = np.linalg.norm(cluster_centers - full_features[i_sample], axis=1) | ||
full_labels[i_sample] = np.argmin(distances) | ||
return full_labels | ||
|
||
|
||
def get_cluster_centers(features: NDArray[float], labels: NDArray[int]) -> NDArray[float]: | ||
"""Returns the cluster centers for the given features and labels. | ||
:param features: The features for each sample (n_samples, n_features). | ||
:param labels: The cluster labels for each sample (n_samples,). | ||
""" | ||
# normalize | ||
n_clusters = labels.max() - labels.min() + 1 | ||
labels = labels - labels.min() | ||
|
||
# compute cluster centers | ||
n_dim = features.shape[1] | ||
cluster_centers = np.zeros((n_clusters, n_dim)) | ||
for i_cluster in range(n_clusters): | ||
cluster_centers[i_cluster] = np.mean(features[labels == i_cluster], axis=0) | ||
return cluster_centers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import numpy as np | ||
import pytest | ||
from depiction.clustering.extrapolate import get_cluster_centers, extrapolate_labels | ||
from numpy.testing import assert_array_almost_equal, assert_array_equal | ||
|
||
|
||
@pytest.fixture | ||
def basic_features(): | ||
return np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) | ||
|
||
|
||
@pytest.fixture | ||
def basic_labels(): | ||
return np.array([0, 0, 1, 1]) | ||
|
||
|
||
@pytest.fixture | ||
def basic_expected_centers(): | ||
return np.array([[2, 3], [6, 7]]) | ||
|
||
|
||
@pytest.fixture | ||
def high_dim_features(): | ||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) | ||
|
||
|
||
@pytest.fixture | ||
def full_features(basic_features): | ||
return np.vstack([basic_features, [[9, 10]]]) | ||
|
||
|
||
def test_extrapolate_labels_basic(mocker, basic_features, basic_labels, full_features, basic_expected_centers): | ||
mocker.patch("depiction.clustering.extrapolate.get_cluster_centers", return_value=basic_expected_centers) | ||
|
||
expected_labels = np.array([0, 0, 1, 1, 1]) | ||
result = extrapolate_labels(basic_features, basic_labels, full_features) | ||
|
||
assert_array_equal(result, expected_labels) | ||
|
||
|
||
def test_extrapolate_labels_single_cluster(mocker, basic_features, full_features): | ||
mocker.patch("depiction.clustering.extrapolate.get_cluster_centers", return_value=np.array([[4, 5]])) | ||
|
||
sample_labels = np.array([0, 0, 0, 0]) | ||
expected_labels = np.array([0, 0, 0, 0, 0]) | ||
|
||
result = extrapolate_labels(basic_features, sample_labels, full_features) | ||
assert_array_equal(result, expected_labels) | ||
|
||
|
||
def test_extrapolate_labels_high_dimensionality(mocker, high_dim_features): | ||
sample_features = high_dim_features[:3] | ||
sample_labels = np.array([0, 1, 1]) | ||
full_features = high_dim_features | ||
mock_centers = np.array([[1, 2, 3], [5.5, 6.5, 7.5]]) | ||
|
||
mocker.patch("depiction.clustering.extrapolate.get_cluster_centers", return_value=mock_centers) | ||
|
||
expected_labels = np.array([0, 1, 1, 1]) | ||
result = extrapolate_labels(sample_features, sample_labels, full_features) | ||
|
||
assert_array_equal(result, expected_labels) | ||
|
||
|
||
def test_extrapolate_labels_empty_full_features(mocker, basic_features, basic_labels): | ||
mocker.patch("depiction.clustering.extrapolate.get_cluster_centers", return_value=np.array([[2, 3], [6, 7]])) | ||
|
||
full_features = np.empty((0, 2)) | ||
expected_labels = np.array([], dtype=int) | ||
|
||
result = extrapolate_labels(basic_features, basic_labels, full_features) | ||
assert_array_equal(result, expected_labels) | ||
|
||
|
||
def test_extrapolate_labels_different_feature_count(basic_features, basic_labels): | ||
with pytest.raises(ValueError, match="Number of features must be the same"): | ||
full_features = np.array([[1, 2, 3], [4, 5, 6]]) | ||
extrapolate_labels(basic_features, basic_labels, full_features) | ||
|
||
|
||
def test_extrapolate_labels_mock_calls(mocker, basic_features, basic_labels, full_features): | ||
mock_get_centers = mocker.patch("depiction.clustering.extrapolate.get_cluster_centers") | ||
mock_norm = mocker.patch("numpy.linalg.norm") | ||
mock_argmin = mocker.patch("numpy.argmin") | ||
|
||
mock_get_centers.return_value = np.array([[2, 3], [6, 7]]) | ||
mock_norm.return_value = np.array([1, 2]) | ||
mock_argmin.return_value = 0 | ||
|
||
extrapolate_labels(basic_features, basic_labels, full_features) | ||
|
||
mock_get_centers.assert_called_once_with(features=basic_features, labels=basic_labels) | ||
assert mock_norm.call_count == full_features.shape[0] | ||
assert mock_argmin.call_count == full_features.shape[0] | ||
|
||
|
||
def test_get_cluster_centers_basic(basic_features, basic_labels, basic_expected_centers): | ||
result = get_cluster_centers(basic_features, basic_labels) | ||
assert_array_almost_equal(result, basic_expected_centers) | ||
|
||
|
||
def test_get_cluster_centers_single_cluster(): | ||
features = np.array([[1, 2], [3, 4], [5, 6]]) | ||
labels = np.array([0, 0, 0]) | ||
expected_centers = np.array([[3, 4]]) | ||
|
||
result = get_cluster_centers(features, labels) | ||
assert_array_almost_equal(result, expected_centers) | ||
|
||
|
||
def test_get_cluster_centers_non_zero_min_label(basic_features, basic_expected_centers): | ||
labels = np.array([1, 1, 2, 2]) | ||
|
||
result = get_cluster_centers(basic_features, labels) | ||
assert_array_almost_equal(result, basic_expected_centers) | ||
|
||
|
||
def test_get_cluster_centers_high_dimensionality(high_dim_features): | ||
labels = np.array([0, 0, 1, 1]) | ||
expected_centers = np.array([[2.5, 3.5, 4.5], [8.5, 9.5, 10.5]]) | ||
|
||
result = get_cluster_centers(high_dim_features, labels) | ||
assert_array_almost_equal(result, expected_centers) | ||
|
||
|
||
def test_get_cluster_centers_empty_cluster(): | ||
features = np.array([[1, 2], [3, 4], [5, 6]]) | ||
labels = np.array([0, 2, 2]) | ||
expected_centers = np.array([[1, 2], [np.nan, np.nan], [4, 5]]) | ||
|
||
result = get_cluster_centers(features, labels) | ||
assert_array_almost_equal(result, expected_centers) | ||
|
||
|
||
def test_get_cluster_centers_input_validation(basic_features): | ||
with pytest.raises(IndexError): | ||
get_cluster_centers(basic_features, np.array([0, 1, 2])) | ||
|
||
|
||
def test_get_cluster_centers_mock_mean(mocker, basic_features, basic_labels): | ||
mock_mean = mocker.patch("numpy.mean") | ||
mock_mean.return_value = np.array([10, 20]) | ||
|
||
expected_centers = np.array([[10, 20], [10, 20]]) | ||
|
||
result = get_cluster_centers(basic_features, basic_labels) | ||
assert_array_almost_equal(result, expected_centers) | ||
assert mock_mean.call_count == 2 | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main() |