diff --git a/examples/plot_metric_learning_examples.py b/examples/plot_metric_learning_examples.py
index b46d1adc..0d602cbb 100644
--- a/examples/plot_metric_learning_examples.py
+++ b/examples/plot_metric_learning_examples.py
@@ -139,7 +139,7 @@ def plot_tsne(X, y, colormap=plt.cm.Paired):
 #
 
 # setting up LMNN
-lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6, init='random')
+lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6)
 
 # fit the data!
 lmnn.fit(X, y)
diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py
index 600d55c0..1feec167 100644
--- a/metric_learn/lmnn.py
+++ b/metric_learn/lmnn.py
@@ -1,8 +1,6 @@
 """
 Large Margin Nearest Neighbor Metric learning (LMNN)
 """
-# TODO: periodic recalculation of impostors, PCA initialization
-
 from __future__ import print_function, absolute_import
 import numpy as np
 import warnings
@@ -208,31 +206,19 @@ def fit(self, X, y):
                        ' (smallest class has %d)' % required_k)
 
     target_neighbors = self._select_targets(X, label_inds)
-    impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds)
-    if len(impostors) == 0:
-      # L has already been initialized to an identity matrix
-      return
 
     # sum outer products
     dfG = _sum_outer_products(X, target_neighbors.flatten(),
                               np.repeat(np.arange(X.shape[0]), k))
-    df = np.zeros_like(dfG)
-
-    # storage
-    a1 = [None]*k
-    a2 = [None]*k
-    for nn_idx in xrange(k):
-      a1[nn_idx] = np.array([])
-      a2[nn_idx] = np.array([])
 
     # initialize L
     L = self.transformer_
 
     # first iteration: we compute variables (including objective and gradient)
     # at initialization point
-    G, objective, total_active, df, a1, a2 = (
-        self._loss_grad(X, L, dfG, impostors, 1, k, reg, target_neighbors, df,
-                        a1, a2))
+    G, objective, total_active = self._loss_grad(X, L, dfG, k,
+                                                 reg, target_neighbors,
+                                                 label_inds)
 
     it = 1  # we already made one iteration
 
@@ -246,10 +232,9 @@ def fit(self, X, y):
       # we compute the objective at next point
      # we copy variables that can be modified by _loss_grad, because if we
       # retry we don t want to modify them several times
-      (G_next, objective_next, total_active_next, df_next, a1_next,
-       a2_next) = (
-          self._loss_grad(X, L_next, dfG, impostors, it, k, reg,
-                          target_neighbors, df.copy(), list(a1), list(a2)))
+      (G_next, objective_next, total_active_next) = (
+          self._loss_grad(X, L_next, dfG, k, reg, target_neighbors,
+                          label_inds))
       assert not np.isnan(objective)
       delta_obj = objective_next - objective
       if delta_obj > 0:
@@ -264,8 +249,7 @@ def fit(self, X, y):
         # old variables to these new ones before next iteration and we
         # slightly increase the learning rate
         L = L_next
-        G, df, objective, total_active, a1, a2 = (
-            G_next, df_next, objective_next, total_active_next, a1_next, a2_next)
+        G, objective, total_active = G_next, objective_next, total_active_next
         learn_rate *= 1.01
 
       if self.verbose:
@@ -285,46 +269,37 @@ def fit(self, X, y):
     self.n_iter_ = it
     return self
 
-  def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
-                 a1, a2):
+  def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
     # Compute pairwise distances under current metric
     Lx = L.dot(X.T).T
-    g0 = _inplace_paired_L2(*Lx[impostors])
+
+    # we need to find the furthest neighbor:
     Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :])
+    furthest_neighbors = np.take_along_axis(target_neighbors,
+                                            Ni.argmax(axis=1)[:, None], 1)
+    impostors = self._find_impostors(furthest_neighbors.ravel(), X,
+                                     label_inds, L)
+
+    g0 = _inplace_paired_L2(*Lx[impostors])
+
+    # we reorder the target neighbors
     g1, g2 = Ni[impostors]
     # compute the gradient
     total_active = 0
-    for nn_idx in reversed(xrange(k)):
+    df = np.zeros((X.shape[1], X.shape[1]))
+    for nn_idx in reversed(xrange(k)):  # note: reverse not useful here
       act1 = g0 < g1[:, nn_idx]
       act2 = g0 < g2[:, nn_idx]
       total_active += act1.sum() + act2.sum()
 
-      if it > 1:
-        plus1 = act1 & ~a1[nn_idx]
-        minus1 = a1[nn_idx] & ~act1
-        plus2 = act2 & ~a2[nn_idx]
-        minus2 = a2[nn_idx] & ~act2
-      else:
-        plus1 = act1
-        plus2 = act2
-        minus1 = np.zeros(0, dtype=int)
-        minus2 = np.zeros(0, dtype=int)
-
       targets = target_neighbors[:, nn_idx]
-      PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
+      PLUS, pweight = _count_edges(act1, act2, impostors, targets)
       df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight)
-      MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
-      df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight)
 
       in_imp, out_imp = impostors
-      df += _sum_outer_products(X, in_imp[minus1], out_imp[minus1])
-      df += _sum_outer_products(X, in_imp[minus2], out_imp[minus2])
-
-      df -= _sum_outer_products(X, in_imp[plus1], out_imp[plus1])
-      df -= _sum_outer_products(X, in_imp[plus2], out_imp[plus2])
+      df -= _sum_outer_products(X, in_imp[act1], out_imp[act1])
+      df -= _sum_outer_products(X, in_imp[act2], out_imp[act2])
 
-      a1[nn_idx] = act1
-      a2[nn_idx] = act2
-
     # do the gradient update
     assert not np.isnan(df).any()
     G = dfG * reg + df * (1 - reg)
@@ -332,7 +307,7 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
     # compute the objective function
     objective = total_active * (1 - reg)
     objective += G.flatten().dot(L.flatten())
-    return 2 * G, objective, total_active, df, a1, a2
+    return 2 * G, objective, total_active
 
   def _select_targets(self, X, label_inds):
     target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
@@ -344,8 +319,8 @@ def _select_targets(self, X, label_inds):
       target_neighbors[inds] = inds[nn]
     return target_neighbors
 
-  def _find_impostors(self, furthest_neighbors, X, label_inds):
-    Lx = self.transform(X)
+  def _find_impostors(self, furthest_neighbors, X, label_inds, L):
+    Lx = X.dot(L.T)
     margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
     impostors = []
     for label in self.labels_[:-1]:
diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py
index c49c9ef5..0f5ee2fa 100644
--- a/test/metric_learn_test.py
+++ b/test/metric_learn_test.py
@@ -2,9 +2,10 @@
 import re
 import pytest
 import numpy as np
+import scipy
 from scipy.optimize import check_grad, approx_fprime
 from six.moves import xrange
-from sklearn.metrics import pairwise_distances
+from sklearn.metrics import pairwise_distances, euclidean_distances
 from sklearn.datasets import (load_iris, make_classification, make_regression,
                               make_spd_matrix)
 from numpy.testing import (assert_array_almost_equal, assert_array_equal,
@@ -242,25 +243,15 @@ def test_loss_grad_lbfgs(self):
     lmnn.transformer_ = np.eye(n_components)
 
     target_neighbors = lmnn._select_targets(X, label_inds)
-    impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
 
     # sum outer products
     dfG = _sum_outer_products(X, target_neighbors.flatten(),
                               np.repeat(np.arange(X.shape[0]), k))
-    df = np.zeros_like(dfG)
-
-    # storage
-    a1 = [None]*k
-    a2 = [None]*k
-    for nn_idx in xrange(k):
-      a1[nn_idx] = np.array([])
-      a2[nn_idx] = np.array([])
 
     # initialize L
     def loss_grad(flat_L):
-      return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG, impostors,
-                             1, k, reg, target_neighbors, df.copy(),
-                             list(a1), list(a2))
+      return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG,
+                             k, reg, target_neighbors, label_inds)
 
     def fun(x):
       return loss_grad(x)[1]
@@ -292,6 +283,141 @@ def test_changed_behaviour_warning(self):
     assert any(msg == str(wrn.message) for wrn in raised_warning)
 
 
+def test_loss_func(capsys):
+  """Test the loss function (and its gradient) on a simple example,
+  by comparing the results with the actual implementation of metric-learn,
+  with a very simple (but nonperformant) implementation"""
+
+  # toy dataset to use
+  X, y = make_classification(n_samples=10, n_classes=2,
+                             n_features=6,
+                             n_redundant=0, shuffle=True,
+                             scale=[1, 1, 20, 20, 20, 20], random_state=42)
+
+  def hinge(a):
+    if a > 0:
+      return a, 1
+    else:
+      return 0, 0
+
+  def loss_fn(L, X, y, target_neighbors, reg):
+    L = L.reshape(-1, X.shape[1])
+    Lx = np.dot(X, L.T)
+    loss = 0
+    total_active = 0
+    grad = np.zeros_like(L)
+    for i in range(X.shape[0]):
+      for j in target_neighbors[i]:
+        loss += (1 - reg) * np.sum((Lx[i] - Lx[j]) ** 2)
+        grad += (1 - reg) * np.outer(Lx[i] - Lx[j], X[i] - X[j])
+        for l in range(X.shape[0]):
+          if y[i] != y[l]:
+            hin, active = hinge(1 + np.sum((Lx[i] - Lx[j])**2) -
+                                np.sum((Lx[i] - Lx[l])**2))
+            total_active += active
+            if active:
+              loss += reg * hin
+              grad += (reg * (np.outer(Lx[i] - Lx[j], X[i] - X[j]) -
+                              np.outer(Lx[i] - Lx[l], X[i] - X[l])))
+    grad = 2 * grad
+    return grad, loss, total_active
+
+  # we check that the gradient we have computed in the non-performant implem
+  # is indeed the true gradient on a toy example:
+
+  def _select_targets(X, y, k):
+    target_neighbors = np.empty((X.shape[0], k), dtype=int)
+    for label in np.unique(y):
+      inds, = np.nonzero(y == label)
+      dd = euclidean_distances(X[inds], squared=True)
+      np.fill_diagonal(dd, np.inf)
+      nn = np.argsort(dd)[..., :k]
+      target_neighbors[inds] = inds[nn]
+    return target_neighbors
+
+  target_neighbors = _select_targets(X, y, 2)
+  regularization = 0.5
+  n_features = X.shape[1]
+  x0 = np.random.randn(1, n_features)
+
+  def loss(x0):
+    return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors,
+                   regularization)[1]
+
+  def grad(x0):
+    return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors,
+                   regularization)[0].ravel()
+
+  scipy.optimize.check_grad(loss, grad, x0.ravel())
+
+  class LMNN_with_callback(LMNN):
+    """ We will use a callback to get the gradient (see later)
+    """
+
+    def __init__(self, callback, *args, **kwargs):
+      self.callback = callback
+      super(LMNN_with_callback, self).__init__(*args, **kwargs)
+
+    def _loss_grad(self, *args, **kwargs):
+      grad, objective, total_active = (
+          super(LMNN_with_callback, self)._loss_grad(*args, **kwargs))
+      self.callback.append(grad)
+      return grad, objective, total_active
+
+  class LMNN_nonperformant(LMNN_with_callback):
+
+    def fit(self, X, y):
+      self.y = y
+      return super(LMNN_nonperformant, self).fit(X, y)
+
+    def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
+      grad, loss, total_active = loss_fn(L.ravel(), X, self.y,
+                                         target_neighbors, self.regularization)
+      self.callback.append(grad)
+      return grad, loss, total_active
+
+  mem1, mem2 = [], []
+  lmnn_perf = LMNN_with_callback(verbose=True, random_state=42,
+                                 init='identity', max_iter=30, callback=mem1)
+  lmnn_nonperf = LMNN_nonperformant(verbose=True, random_state=42,
+                                    init='identity', max_iter=30,
+                                    callback=mem2)
+  objectives, obj_diffs, learn_rate, total_active = (dict(), dict(), dict(),
+                                                     dict())
+  for algo, name in zip([lmnn_perf, lmnn_nonperf], ['perf', 'nonperf']):
+    algo.fit(X, y)
+    out, _ = capsys.readouterr()
+    lines = re.split("\n+", out)
+    # we get every variable that is printed from the algorithm in verbose
+    num = '(-?\d+.?\d*(e[+|-]\d+)?)'
+    strings = [re.search("\d+ (?:{}) (?:{}) (?:(\d+)) (?:{})"
+               .format(num, num, num), s) for s in lines]
+    objectives[name] = [float(match.group(1)) for match in strings if match is
+                        not None]
+    obj_diffs[name] = [float(match.group(3)) for match in strings if match is
+                       not None]
+    total_active[name] = [float(match.group(5)) for match in strings if
+                          match is not
+                          None]
+    learn_rate[name] = [float(match.group(6)) for match in strings if match is
+                        not None]
+    assert len(strings) >= 10  # we ensure that we actually did more than 10
+    # iterations
+    assert total_active[name][0] >= 2  # we ensure that we have some active
+    # constraints (that's the case we want to test)
+    # we remove the last element because it can be equal to the penultimate
+    # if the last gradient update is null
+  for i in range(len(mem1)):
+    np.testing.assert_allclose(lmnn_perf.callback[i],
+                               lmnn_nonperf.callback[i],
+                               err_msg='Gradient different at position '
+                                       '{}'.format(i))
+  np.testing.assert_allclose(objectives['perf'], objectives['nonperf'])
+  np.testing.assert_allclose(obj_diffs['perf'], obj_diffs['nonperf'])
+  np.testing.assert_allclose(total_active['perf'], total_active['nonperf'])
+  np.testing.assert_allclose(learn_rate['perf'], learn_rate['nonperf'])
+
+
 @pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]),
                                          [1, 1, 0, 0], 3.0),
                                         (np.array([[0], [1], [2], [3]]),
@@ -312,7 +438,7 @@ def test_toy_ex_lmnn(X, y, loss):
   lmnn.transformer_ = np.eye(n_components)
 
   target_neighbors = lmnn._select_targets(X, label_inds)
-  impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
+  impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds, L)
 
   # sum outer products
   dfG = _sum_outer_products(X, target_neighbors.flatten(),
@@ -327,9 +453,8 @@ def test_toy_ex_lmnn(X, y, loss):
     a2[nn_idx] = np.array([])
 
   # assert that the loss equals the one computed by hand
-  assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, k,
-                         reg, target_neighbors, df, a1, a2)[1] == loss
-
+  assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, k,
+                         reg, target_neighbors, label_inds)[1] == loss
 
 def test_convergence_simple_example(capsys):
   # LMNN should converge on this simple example, which it did not with