Converting constraints to a nice object
A big refactor, but I think it makes things cleaner.
perimosocordiae committed Jul 13, 2016
1 parent e31fb50 commit ce5f238
Showing 12 changed files with 151 additions and 118 deletions.
24 changes: 11 additions & 13 deletions examples/sandwich.py
@@ -7,8 +7,7 @@
 from sklearn.metrics import pairwise_distances
 from sklearn.neighbors import NearestNeighbors
 
-import metric_learn.constraints as C
-from metric_learn import ITML, LMNN, LSML, SDML
+from metric_learn import LMNN, ITML_Supervised, LSML_Supervised, SDML_Supervised
 
 
 def sandwich_demo():
@@ -22,22 +21,21 @@ def sandwich_demo():
   ax.set_xticks([])
   ax.set_yticks([])
 
-  num_constraints = 60
   mls = [
-    (LMNN(), (x, y)),
-    (ITML(), (x, C.positive_negative_pairs(y, len(x), num_constraints))),
-    (SDML(), (x, C.adjacency_matrix(y, len(x), num_constraints))),
-    (LSML(), (x, C.relative_quadruplets(y, num_constraints)))
+    LMNN(),
+    ITML_Supervised(num_constraints=200),
+    SDML_Supervised(num_constraints=200),
+    LSML_Supervised(num_constraints=200),
   ]
 
-  for ax_num, (ml,args) in zip(range(3,7), mls):
-    ml.fit(*args)
+  for ax_num, ml in enumerate(mls, start=3):
+    ml.fit(x, y)
     tx = ml.transform()
     ml_knn = nearest_neighbors(tx, k=2)
-    ax = plt.subplot(3,2,ax_num)
-    plot_sandwich_data(tx, y, ax)
-    plot_neighborhood_graph(tx, ml_knn, y, ax)
-    ax.set_title('%s space' % ml.__class__.__name__)
+    ax = plt.subplot(3, 2, ax_num)
+    plot_sandwich_data(tx, y, axis=ax)
+    plot_neighborhood_graph(tx, ml_knn, y, axis=ax)
+    ax.set_title(ml.__class__.__name__)
     ax.set_xticks([])
     ax.set_yticks([])
   plt.show()
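The sandwich demo captures the user-facing effect of the refactor: constraint generation moves inside the new *_Supervised wrappers, so every learner is fit with the same (x, y) call. A minimal before/after sketch, assuming x and y as produced by the demo's sandwich_data() (the constraint counts are arbitrary):

    from metric_learn import ITML_Supervised

    # Before: sample constraints by hand, then fit on them.
    #   import metric_learn.constraints as C
    #   pairs = C.positive_negative_pairs(y, len(x), 60)
    #   ITML().fit(x, pairs)

    # After: the wrapper samples its own constraints from y at fit time.
    ml = ITML_Supervised(num_constraints=200)
    ml.fit(x, y)
    tx = ml.transform()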
1 change: 1 addition & 0 deletions metric_learn/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import absolute_import
 
+from .constraints import Constraints
 from .covariance import Covariance
 from .itml import ITML, ITML_Supervised
 from .lmnn import LMNN
127 changes: 74 additions & 53 deletions metric_learn/constraints.py
@@ -4,65 +4,86 @@
 """
 import numpy as np
 import random
+import warnings
 from six.moves import xrange
+from scipy.sparse import coo_matrix
 
-# @TODO: consider creating a stateful class
-# https://github.com/all-umass/metric-learn/pull/19#discussion_r67386226
+__all__ = ['Constraints']
 
 
-def adjacency_matrix(labels, num_points, num_constraints):
-  a, c = np.random.randint(len(labels), size=(2,num_constraints))
-  b, d = np.empty((2, num_constraints), dtype=int)
-  for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
-    b[i] = random.choice(np.nonzero(labels == al)[0])
-    d[i] = random.choice(np.nonzero(labels != cl)[0])
-  W = np.zeros((num_points,num_points))
-  W[a,b] = 1
-  W[c,d] = -1
-  # make W symmetric
-  W[b,a] = 1
-  W[d,c] = -1
-  return W
+class Constraints(object):
+  def __init__(self, partial_labels):
+    '''partial_labels : int arraylike, -1 indicating unknown label'''
+    partial_labels = np.asanyarray(partial_labels)
+    self.num_points, = partial_labels.shape
+    self.known_label_idx, = np.where(partial_labels >= 0)
+    self.known_labels = partial_labels[self.known_label_idx]
+
+  def adjacency_matrix(self, num_constraints):
+    a, b, c, d = self.positive_negative_pairs(num_constraints)
+    row = np.concatenate((a, c))
+    col = np.concatenate((b, d))
+    data = np.ones_like(row, dtype=int)
+    data[len(a):] = -1
+    adj = coo_matrix((data, (row, col)), shape=(self.num_points,)*2)
+    # symmetrize
+    return adj + adj.T
 
-def positive_negative_pairs(labels, num_points, num_constraints):
-  ac,bd = np.random.randint(num_points, size=(2,num_constraints))
-  pos = labels[ac] == labels[bd]
-  a,c = ac[pos], ac[~pos]
-  b,d = bd[pos], bd[~pos]
-  return a,b,c,d
+  def positive_negative_pairs(self, num_constraints, same_length=False):
+    a, b = self._pairs(num_constraints, same_label=True)
+    c, d = self._pairs(num_constraints, same_label=False)
+    if same_length and len(a) != len(c):
+      n = min(len(a), len(c))
+      return a[:n], b[:n], c[:n], d[:n]
+    return a, b, c, d
+
+  def _pairs(self, num_constraints, same_label=True, max_iter=10):
+    num_labels = len(self.known_labels)
+    ab = set()
+    it = 0
+    while it < max_iter and len(ab) < num_constraints:
+      nc = num_constraints - len(ab)
+      for aidx in np.random.randint(num_labels, size=nc):
+        if same_label:
+          mask = self.known_labels[aidx] == self.known_labels
+          mask[aidx] = False  # avoid identity pairs
+        else:
+          mask = self.known_labels[aidx] != self.known_labels
+        b_choices, = np.where(mask)
+        if len(b_choices) > 0:
+          ab.add((aidx, np.random.choice(b_choices)))
+      it += 1
+    if len(ab) < num_constraints:
+      warnings.warn("Only generated %d %s constraints (requested %d)" % (
+          len(ab), 'positive' if same_label else 'negative', num_constraints))
+    ab = np.array(list(ab)[:num_constraints], dtype=int)
+    return self.known_label_idx[ab.T]
 
-def relative_quadruplets(labels, num_constraints):
-  C = np.empty((num_constraints,4), dtype=int)
-  a, c = np.random.randint(len(labels), size=(2,num_constraints))
-  for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
-    C[i,1] = random.choice(np.nonzero(labels == al)[0])
-    C[i,3] = random.choice(np.nonzero(labels != cl)[0])
-  C[:,0] = a
-  C[:,2] = c
-  return C
+  def chunks(self, num_chunks=100, chunk_size=2):
+    chunks = -np.ones_like(self.known_label_idx, dtype=int)
+    uniq, lookup = np.unique(self.known_labels, return_inverse=True)
+    all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))]
+    idx = 0
+    while idx < num_chunks and all_inds:
+      c = random.randint(0, len(all_inds)-1)
+      inds = all_inds[c]
+      if len(inds) < chunk_size:
+        del all_inds[c]
+        continue
+      ii = random.sample(inds, chunk_size)
+      inds.difference_update(ii)
+      chunks[ii] = idx
+      idx += 1
+    if idx < num_chunks:
+      raise ValueError('Unable to make %d chunks of %d examples each' %
+                       (num_chunks, chunk_size))
+    return chunks
 
-
-def chunks(Y, num_chunks=100, chunk_size=2, seed=None):
-  # @TODO: remove seed from params and use numpy RandomState
-  # https://github.com/all-umass/metric-learn/pull/19#discussion_r67386666
-  random.seed(seed)
-  chunks = -np.ones_like(Y, dtype=int)
-  uniq, lookup = np.unique(Y, return_inverse=True)
-  all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))]
-  idx = 0
-  while idx < num_chunks and all_inds:
-    c = random.randint(0, len(all_inds)-1)
-    inds = all_inds[c]
-    if len(inds) < chunk_size:
-      del all_inds[c]
-      continue
-    ii = random.sample(inds, chunk_size)
-    inds.difference_update(ii)
-    chunks[ii] = idx
-    idx += 1
-  if idx < num_chunks:
-    raise ValueError('Unable to make %d chunks of %d examples each' %
-                     (num_chunks, chunk_size))
-  return chunks
+  @staticmethod
+  def random_subset(all_labels, num_preserved=np.inf):
+    n = len(all_labels)
+    num_ignored = max(0, n - num_preserved)
+    idx = np.random.randint(n, size=num_ignored)
+    partial_labels = np.array(all_labels, copy=True)
+    partial_labels[idx] = -1
+    return Constraints(partial_labels)
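For orientation, a minimal sketch of the new object's API, pieced together from the diff above (the toy labels are made up for illustration):

    import numpy as np
    from metric_learn.constraints import Constraints

    y = np.array([0, 0, 0, 1, 1, 2, 2, 2])   # toy labels, all known

    cons = Constraints(y)
    a, b, c, d = cons.positive_negative_pairs(num_constraints=10)
    # (a[i], b[i]) share a label; (c[i], d[i]) have different labels.

    adj = cons.adjacency_matrix(num_constraints=10)       # sparse +1/-1 matrix
    chunk_ids = cons.chunks(num_chunks=3, chunk_size=2)   # for chunk-based learners

    # Semi-supervised entry point: keep roughly 5 labels, mark the rest -1.
    partial = Constraints.random_subset(y, num_preserved=5)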
1 change: 1 addition & 0 deletions metric_learn/covariance.py
@@ -10,6 +10,7 @@
 
 from __future__ import absolute_import
 import numpy as np
+
 from .base_metric import BaseMetricLearner
 
 
25 changes: 13 additions & 12 deletions metric_learn/itml.py
@@ -16,8 +16,8 @@
 from six.moves import xrange
 from sklearn.metrics import pairwise_distances
 
-from . import constraints
 from .base_metric import BaseMetricLearner
+from .constraints import Constraints
 
 
 class ITML(BaseMetricLearner):
@@ -70,7 +70,7 @@ def fit(self, X, constraints, bounds=None, A0=None):
     ----------
     X : (n x d) data matrix
         each row corresponds to a single instance
-    constraints : tuple of arrays
+    constraints : 4-tuple of arrays
         (a,b,c,d) indices into X, such that d(X[a],X[b]) < d(X[c],X[d])
     bounds : list (pos,neg) pairs, optional
         bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg
@@ -142,7 +142,8 @@ def _vector_norm(X):
 class ITML_Supervised(ITML):
   """Information Theoretic Metric Learning (ITML)"""
   def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3,
-               num_constraints=None, bounds=None, A0=None, verbose=False):
+               num_labeled=np.inf, num_constraints=None, bounds=None, A0=None,
+               verbose=False):
     """Initialize the learner.
 
@@ -151,17 +152,17 @@ def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3,
         value for slack variables
     max_iters : int, optional
     convergence_threshold : float, optional
-    num_constraints: int, needed for .fit()
+    num_labeled : int, optional
+        number of labels to preserve for training
+    num_constraints: int, optional
+        number of constraints to generate
     verbose : bool, optional
         if True, prints information while learning
     """
     ITML.__init__(self, gamma=gamma, max_iters=max_iters,
                   convergence_threshold=convergence_threshold, verbose=verbose)
-    self.params.update({
-      'num_constraints': num_constraints,
-      'bounds': bounds,
-      'A0': A0,
-    })
+    self.params.update(num_labeled=num_labeled, num_constraints=num_constraints,
+                       bounds=bounds, A0=A0)
 
   def fit(self, X, labels):
     """Create constraints from labels and learn the ITML model.
@@ -178,6 +179,6 @@ def fit(self, X, labels):
     num_classes = np.unique(labels)
     num_constraints = 20*(len(num_classes))**2
 
-    C = constraints.positive_negative_pairs(labels, X.shape[0], num_constraints)
-    return ITML.fit(self, X, C, bounds=self.params['bounds'],
-                    A0=self.params['A0'])
+    c = Constraints.random_subset(labels, self.params['num_labeled'])
+    return ITML.fit(self, X, c.positive_negative_pairs(num_constraints),
+                    bounds=self.params['bounds'], A0=self.params['A0'])
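A quick sketch of the supervised wrapper after this change, with made-up data (the num_constraints and num_labeled values are arbitrary):

    import numpy as np
    from metric_learn import ITML_Supervised

    X = np.random.randn(40, 3)         # 40 made-up points in 3-D
    y = np.repeat(np.arange(4), 10)    # four classes of ten points each

    # num_labeled=30 hides the remaining labels via Constraints.random_subset
    # before any constraints are sampled. (random_subset draws the indices to
    # ignore with replacement, so at least 30 labels actually survive.)
    itml = ITML_Supervised(num_constraints=100, num_labeled=30)
    itml.fit(X, y)
    X_itml = itml.transform()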
1 change: 1 addition & 0 deletions metric_learn/lfda.py
@@ -15,6 +15,7 @@
 import scipy
 from six.moves import xrange
 from sklearn.metrics import pairwise_distances
+
 from .base_metric import BaseMetricLearner
 
 
7 changes: 5 additions & 2 deletions metric_learn/lmnn.py
@@ -14,6 +14,7 @@
 from collections import Counter
 from six.moves import xrange
 from sklearn.metrics import pairwise_distances
+
 from .base_metric import BaseMetricLearner
 
 
@@ -237,10 +238,12 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None):
 
 class LMNN(_base_LMNN):
   def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
-               regularization=0.5, convergence_tol=0.001, use_pca=True, verbose=False):
+               regularization=0.5, convergence_tol=0.001, use_pca=True,
+               verbose=False):
     _base_LMNN.__init__(self, k=k, min_iter=min_iter, max_iter=max_iter,
                         learn_rate=learn_rate, regularization=regularization,
-                        convergence_tol=convergence_tol, use_pca=use_pca, verbose=verbose)
+                        convergence_tol=convergence_tol, use_pca=use_pca,
+                        verbose=verbose)
 
   def fit(self, X, labels):
     self.X = X
33 changes: 18 additions & 15 deletions metric_learn/lsml.py
@@ -12,8 +12,8 @@
 import scipy.linalg
 from six.moves import xrange
 
-from . import constraints
 from .base_metric import BaseMetricLearner
+from .constraints import Constraints
 
 
 class LSML(BaseMetricLearner):
@@ -35,10 +35,12 @@ def __init__(self, tol=1e-3, max_iter=1000, verbose=False):
 
   def _prepare_inputs(self, X, constraints, weights, prior):
     self.X = X
-    self.vab = np.diff(X[constraints[:,:2]], axis=1)[:,0]
-    self.vcd = np.diff(X[constraints[:,2:]], axis=1)[:,0]
+    a,b,c,d = constraints
+    self.vab = X[a] - X[b]
+    self.vcd = X[c] - X[d]
+    assert self.vab.shape == self.vcd.shape, 'Constraints must have same length'
     if weights is None:
-      self.w = np.ones(constraints.shape[0])
+      self.w = np.ones(self.vab.shape[0])
     else:
       self.w = weights
     self.w /= self.w.sum()  # weights must sum to 1
@@ -57,7 +59,7 @@ def fit(self, X, constraints, weights=None, prior=None):
     ----------
     X : (n x d) data matrix
         each row corresponds to a single instance
-    constraints : (m x 4) matrix of ints
+    constraints : 4-tuple of arrays
         (a,b,c,d) indices into X, such that d(X[a],X[b]) < d(X[c],X[d])
     weights : (m,) array of floats, optional
         scale factor for each constraint
@@ -130,8 +132,8 @@ def _regularization_loss(metric, prior_inv):
 
 
 class LSML_Supervised(LSML):
-  def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_constraints=None,
-               weights=None, verbose=False):
+  def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
+               num_constraints=None, weights=None, verbose=False):
     """Initialize the learner.
 
     Parameters
@@ -140,18 +142,18 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_constraints=None,
     max_iter : int, optional
     prior : (d x d) matrix, optional
         guess at a metric [default: covariance(X)]
-    num_constraints: int, needed for .fit()
+    num_labeled : int, optional
+        number of labels to preserve for training
+    num_constraints: int, optional
+        number of constraints to generate
     weights : (m,) array of floats, optional
         scale factor for each constraint
     verbose : bool, optional
         if True, prints information while learning
     """
     LSML.__init__(self, tol=tol, max_iter=max_iter, verbose=verbose)
-    self.params.update({
-      'prior': prior,
-      'num_constraints': num_constraints,
-      'weights': weights,
-    })
+    self.params.update(prior=prior, num_labeled=num_labeled,
+                       num_constraints=num_constraints, weights=weights)
 
   def fit(self, X, labels):
     """Create constraints from labels and learn the LSML model.
@@ -168,6 +170,7 @@ def fit(self, X, labels):
     num_classes = np.unique(labels)
     num_constraints = 20*(len(num_classes))**2
 
-    C = constraints.relative_quadruplets(labels, num_constraints)
-    return LSML.fit(self, X, C, weights=self.params['weights'],
+    c = Constraints.random_subset(labels, self.params['num_labeled'])
+    pairs = c.positive_negative_pairs(num_constraints, same_length=True)
+    return LSML.fit(self, X, pairs, weights=self.params['weights'],
                     prior=self.params['prior'])
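One detail worth calling out: LSML consumes these pairs as quadruplets, comparing the i-th similar pair (a[i], b[i]) against the i-th dissimilar pair (c[i], d[i]), so the four index arrays must line up. same_length=True truncates both sides to the shorter one, which is exactly what the new assert in _prepare_inputs checks. A toy illustration (labels made up):

    import numpy as np
    from metric_learn.constraints import Constraints

    y = np.array([0, 0, 0, 1, 1, 1])
    cons = Constraints(y)
    a, b, c, d = cons.positive_negative_pairs(10, same_length=True)
    assert len(a) == len(c)   # guaranteed here; LSML._prepare_inputs relies on it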
1 change: 1 addition & 0 deletions metric_learn/nca.py
@@ -6,6 +6,7 @@
 from __future__ import absolute_import
 import numpy as np
 from six.moves import xrange
+
 from .base_metric import BaseMetricLearner
 
 
(Diffs for the remaining 3 changed files are not shown.)