Some robustifying fixes on Atol fit and tests for vectorizers #1096
Changes from all commits
@@ -756,7 +756,7 @@ def __init__(
        self,
        quantiser=KMeans(n_clusters=2, n_init="auto"),
        weighting_method="cloud",
        contrast="gaussian"
        contrast="gaussian",
    ):
        """
        Constructor for the Atol measure vectorisation class.
@@ -794,7 +794,8 @@ def get_weighting_method(self):

    def fit(self, X, y=None, sample_weight=None):
        """
        Calibration step: fit centers to the sample measures and derive inertias between centers.
        Calibration step: fit centers to the target sample measures and derive inertias between centers. If the target
        does not contain enough points for creating the intended number of centers, we fill in with bogus centers.

        Parameters:
            X (list N x d numpy arrays): input measures in R^d from which to learn center locations and inertias
@@ -806,32 +807,48 @@ def fit(self, X, y=None, sample_weight=None):
        Returns:
            self
        """
        if not hasattr(self.quantiser, 'fit'):
            raise TypeError("quantiser %s has no `fit` attribute." % (self.quantiser))

        # In fitting we remove infinite death time points so that every center is finite
        X = [dgm[~np.isinf(dgm).any(axis=1), :] for dgm in X]
        n_clusters = self.quantiser.n_clusters

        if not len(X):
            raise ValueError("Cannot fit Atol on empty target.")
        measures_concat = np.concatenate(X)
        if sample_weight is None:
            sample_weight = [self.get_weighting_method()(measure) for measure in X]

        measures_concat = np.concatenate(X)
        weights_concat = np.concatenate(sample_weight)

        self.quantiser.fit(X=measures_concat, sample_weight=weights_concat)
        # In fitting we remove infinite birth/death time points so that every center is finite. We do not care about duplicates.
        filtered_measures_concat = measures_concat[~np.isinf(measures_concat).any(axis=1), :] if len(measures_concat) else measures_concat
        filtered_weights_concat = weights_concat[~np.isinf(measures_concat).any(axis=1)] if len(measures_concat) else weights_concat

        n_points = len(filtered_measures_concat)
        if not n_points:
            raise ValueError("Cannot fit Atol on measure with infinite components only.")
        if n_points < n_clusters:
            self.quantiser.n_clusters = n_points

        self.quantiser.fit(X=filtered_measures_concat, sample_weight=filtered_weights_concat)
Review discussion on this change:

"When the number of points equals the number of requested clusters, are clustering algorithms usually good at finding all of the points as centers? Another possibility, in the 'too few points' case, would be to write to self.centers directly and skip the use of the quantiser."

"First of all, thanks for all the great remarks. Will work on that!"

"Yeah, I don't think there is a function that begins sorting or hash-tabling and stops once it has found k distinct elements... I think the common case will be that the first [...]"

"Actually, do we really care? Having a duplicated feature is not much worse than having a feature that is always 0, and this is only for a degenerate case that really shouldn't happen anyway, so I don't think we should make too much effort beyond 'warn and don't crash'. I only mentioned the possibility of duplicate points in passing, but it does not bother me."

"As you like. Unless someone else chimes in, I don't really care what points you use as filler."

"I agree that we do not really care, so I have dropped this!"
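For reference, a minimal standalone sketch of the alternative suggested above (a hypothetical helper, not what the PR ends up doing): use the points themselves as centers when there are too few of them, and only otherwise run the quantiser.

import numpy as np
from sklearn.cluster import KMeans

def choose_centers(filtered_points, quantiser, n_clusters):
    # Hypothetical alternative: with fewer points than requested clusters,
    # every point becomes a center and the quantiser is skipped entirely.
    if len(filtered_points) < n_clusters:
        return filtered_points.copy()
    quantiser.fit(X=filtered_points)
    return quantiser.cluster_centers_

# Example: two points, three requested clusters -> the two points are the centers.
print(choose_centers(np.array([[0., 1.], [0., 2.]]), KMeans(n_clusters=3, n_init="auto"), 3))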
        self.centers = self.quantiser.cluster_centers_

        # Hack, but some people are unhappy if the order depends on the version of sklearn
        self.centers = self.centers[np.lexsort(self.centers.T)]
        if self.quantiser.n_clusters == 1:
            dist_centers = pairwise.pairwise_distances(measures_concat)
            dist_centers = pairwise.pairwise_distances(filtered_measures_concat)
            np.fill_diagonal(dist_centers, 0)
            best_inertia = np.max(dist_centers)/2 if np.max(dist_centers)/2 > 0 else 1
            self.inertias = np.array([best_inertia])
        else:
            dist_centers = pairwise.pairwise_distances(self.centers)
            dist_centers[dist_centers == 0] = np.inf
            self.inertias = np.min(dist_centers, axis=0)/2

        if n_points < n_clusters:
            # There weren't enough points to fit n_clusters, so we arbitrarily put centers as [-np.inf]^measure_dim.
            print(f"[Atol] after filtering had only {n_points=} to fit {n_clusters=}, adding meaningless centers.")
            fill_center = np.repeat(np.inf, repeats=X[0].shape[1])
            fill_inertia = 0
            self.centers = np.concatenate([self.centers, np.repeat([fill_center], repeats=n_clusters-n_points, axis=0)])
            self.inertias = np.concatenate([self.inertias, np.repeat(fill_inertia, repeats=n_clusters-n_points)])
            self.quantiser.n_clusters = n_clusters
        return self

    def __call__(self, measure, sample_weight=None):
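As a quick illustration of the new behaviour, here is a minimal sketch (hypothetical diagram values, assuming a gudhi build that includes this patch): fitting on a diagram that, after removing infinite components, has fewer points than the requested number of clusters now pads the centers instead of crashing.

import numpy as np
from sklearn.cluster import KMeans
from gudhi.representations import Atol

# A single diagram with one finite point and one infinite point: after the
# infinite component is filtered out, only one point is left for two clusters.
diags = [np.array([[0., 2.], [0., np.inf]])]

atol = Atol(quantiser=KMeans(n_clusters=2, n_init="auto"))
atol.fit(diags)        # prints "[Atol] after filtering had only n_points=1 to fit n_clusters=2, adding meaningless centers."
print(atol.centers)    # one real center plus one [inf, inf] filler center
print(atol.inertias)   # the filler center gets inertia 0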
@@ -0,0 +1,85 @@
# The following tests only check that the program runs, not what it outputs

import numpy as np

from sklearn.base import clone
from sklearn.cluster import KMeans

from gudhi.representations import (Atol, Landscape, Silhouette, BettiCurve, ComplexPolynomial, \
    TopologicalVector, PersistenceImage, Entropy)

vectorizers = {
    "atol": Atol(quantiser=KMeans(n_clusters=2, random_state=202312, n_init="auto")),
    # "betti": BettiCurve(),
}

diag1 = [np.array([[0., np.inf],
                   [0., 8.94427191],
                   [0., 7.28010989],
                   [0., 6.08276253],
                   [0., 5.83095189],
                   [0., 5.38516481],
                   [0., 5.]]),
         np.array([[11., np.inf],
                   [6.32455532, 6.70820393]]),
         np.empty(shape=[0, 2])]

diag2 = [np.array([[0., np.inf],
                   [0., 8.94427191],
                   [0., 7.28010989],
                   [0., 6.08276253],
                   [0., 5.83095189],
                   [0., 5.38516481],
                   [0., 5.]]),
         np.array([[11., np.inf],
                   [6.32455532, 6.70820393]]),
         np.array([[0., np.inf],
                   [0., 1]])]

diag3 = [np.empty(shape=[0, 2])]


def test_fit():
    print(f" > Testing `fit`.")
    for name, vectorizer in vectorizers.items():
        print(f" >> Testing {name}")
        clone(vectorizer).fit(X=[diag1[0], diag2[0]])


def test_transform():
    print(f" > Testing `transform`.")
    for name, vectorizer in vectorizers.items():
        print(f" >> Testing {name}")
        clone(vectorizer).fit_transform(X=[diag1[0], diag2[0], diag3[0]])


def test_transform_empty():
    print(f" > Testing `transform_empty`.")
    for name, vectorizer in vectorizers.items():
        print(f" >> Testing {name}")
        copy_vec = clone(vectorizer).fit(X=[diag1[0], diag2[0]])
        copy_vec.transform(X=[diag3[0], diag3[0]])


def test_set_output():
    print(f" > Testing `set_output`.")
    try:
        import pandas
        for name, vectorizer in vectorizers.items():
            print(f" >> Testing {name}")
            clone(vectorizer).set_output(transform="pandas")
    except ImportError:
        print("Missing pandas, skipping set_output test")


def test_compose():
    print(f" > Testing composition with `sklearn.compose.ColumnTransformer`.")
    from sklearn.compose import ColumnTransformer
    for name, vectorizer in vectorizers.items():
        print(f" >> Testing {name}")
        ct = ColumnTransformer([
            (f"{name}-0", clone(vectorizer), 0),
            (f"{name}-1", clone(vectorizer), 1),
            (f"{name}-2", clone(vectorizer), 2)]
        )
        ct.fit_transform(X=[diag1, diag2])
Review discussion on handling infinite points:

"I am not sure if we should ask the user to remove infinite points before calling Atol (with DiagramSelector(use=True, point_type="finite") for instance), or if all representations that require no infinite birth/death time points should do the same. To be discussed."

"I don't have a strong opinion on this, but a comment: Atol is slightly different from the other vectorizers in that it vectorizes d-dimensional measures, not just diagrams, so the filtering done in Atol.fit is specific and needed, and hopefully guarantees to an extent that Atol.fit will work in that general case. Perhaps there is a case to be made for filtering-but-printing/warning if a vectorizer requiring no infinite death points receives one?"

"No strong opinion here either... @mglisse or @MathieuCarriere?"

"Warning if the filtering removed any infinite points looks like a good idea (I think, not 100% sure). (And here the default of warnings.warn to print the warning only once looks good, since this is about explaining that users may want to filter themselves before Atol, not something specific to one dataset.) It is true that it would be cleaner not to filter in Atol, but it should be cheap enough, so for convenience it doesn't really bother me."

"Today's thought: I'm not sure it isn't simpler to just have vectorizers handle infinite time points (for instance by filtering). Since diagrams are the primary focus of vectorizers, I think it's OK that vectorizers handle diagrams in the more general case (that is, including infinite points), and leave the special treatments (using DiagramSelector) to special cases."
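For reference, a minimal sketch of the pre-filtering option mentioned in this thread (diagram values are made up): chaining DiagramSelector before Atol so that Atol never sees infinite points.

import numpy as np
from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from gudhi.representations import Atol, DiagramSelector

diags = [np.array([[0., np.inf], [0., 2.], [1., 3.]]),
         np.array([[0., np.inf], [0., 1.]])]

# Drop infinite points explicitly before vectorizing, instead of relying on Atol's own filtering.
pipe = Pipeline([
    ("finite", DiagramSelector(use=True, point_type="finite")),
    ("atol", Atol(quantiser=KMeans(n_clusters=2, n_init="auto"))),
])
print(pipe.fit_transform(diags))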