diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py
index dea12f0c..f58bc00a 100644
--- a/metric_learn/lmnn.py
+++ b/metric_learn/lmnn.py
@@ -90,83 +90,49 @@ def fit(self, X, y):
       a1[nn_idx] = np.array([])
       a2[nn_idx] = np.array([])

-    # initialize gradient and L
-    G = dfG * reg + df * (1-reg)
+    # initialize L
     L = self.L_
-    objective = np.inf
-
-    # main loop
-    for it in xrange(1, self.max_iter):
-      df_old = df.copy()
-      a1_old = [a.copy() for a in a1]
-      a2_old = [a.copy() for a in a2]
-      objective_old = objective
-      # Compute pairwise distances under current metric
-      Lx = L.dot(self.X_.T).T
-      g0 = _inplace_paired_L2(*Lx[impostors])
-      Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:,None,:])
-      g1,g2 = Ni[impostors]
-
-      # compute the gradient
-      total_active = 0
-      for nn_idx in reversed(xrange(k)):
-        act1 = g0 < g1[:,nn_idx]
-        act2 = g0 < g2[:,nn_idx]
-        total_active += act1.sum() + act2.sum()
-
-        if it > 1:
-          plus1 = act1 & ~a1[nn_idx]
-          minus1 = a1[nn_idx] & ~act1
-          plus2 = act2 & ~a2[nn_idx]
-          minus2 = a2[nn_idx] & ~act2
-        else:
-          plus1 = act1
-          plus2 = act2
-          minus1 = np.zeros(0, dtype=int)
-          minus2 = np.zeros(0, dtype=int)
-
-        targets = target_neighbors[:,nn_idx]
-        PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
-        df += _sum_outer_products(self.X_, PLUS[:,0], PLUS[:,1], pweight)
-        MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
-        df -= _sum_outer_products(self.X_, MINUS[:,0], MINUS[:,1], mweight)
-
-        in_imp, out_imp = impostors
-        df += _sum_outer_products(self.X_, in_imp[minus1], out_imp[minus1])
-        df += _sum_outer_products(self.X_, in_imp[minus2], out_imp[minus2])
-
-        df -= _sum_outer_products(self.X_, in_imp[plus1], out_imp[plus1])
-        df -= _sum_outer_products(self.X_, in_imp[plus2], out_imp[plus2])
-
-        a1[nn_idx] = act1
-        a2[nn_idx] = act2
-
-      # do the gradient update
-      assert not np.isnan(df).any()
-      G = dfG * reg + df * (1-reg)
-      # compute the objective function
-      objective = total_active * (1-reg)
-      objective += G.flatten().dot(L.T.dot(L).flatten())
-      assert not np.isnan(objective)
-      delta_obj = objective - objective_old
+    # first iteration: we compute variables (including objective and gradient)
+    # at the initialization point
+    G, objective, total_active, df, a1, a2 = (
+        self._loss_grad(L, dfG, impostors, 1, k, reg, target_neighbors, df, a1,
+                        a2))
+
+    for it in xrange(2, self.max_iter):
+      # then at each iteration, we try to find a value of L that has a better
+      # objective than the previous L, following the gradient:
+      while True:
+        # the next point L_next to try out is found by a gradient step
+        L_next = L - 2 * learn_rate * G
+        # we compute the objective at the next point
+        # we copy variables that can be modified by _loss_grad, because if we
+        # retry we don't want to modify them several times
+        (G_next, objective_next, total_active_next, df_next, a1_next,
+         a2_next) = (
+            self._loss_grad(L_next, dfG, impostors, it, k, reg,
+                            target_neighbors, df.copy(), list(a1), list(a2)))
+        assert not np.isnan(objective_next)
+        delta_obj = objective_next - objective
+        if delta_obj > 0:
+          # if we did not find a better objective, we retry with an L closer
+          # to the starting point, by decreasing the learning rate (making
+          # the gradient step smaller)
+          learn_rate /= 2
+        else:
+          # otherwise, if we indeed found a better objective, we get out of
+          # the loop
+          break
+      # when the better L is found (and the related variables), we set the
+      # old variables to these new ones before the next iteration, and we
+      # slightly increase the learning rate
+      L = L_next
+      G, df, objective, total_active, a1, a2 = (
+          G_next, df_next, objective_next, total_active_next, a1_next, a2_next)
+      learn_rate *= 1.01

       if self.verbose:
         print(it, objective, delta_obj, total_active, learn_rate)

-      # update step size
-      if delta_obj > 0:
-        # we're getting worse... roll back!
-        learn_rate /= 2.0
-        df = df_old
-        a1 = a1_old
-        a2 = a2_old
-        objective = objective_old
-      else:
-        # update L
-        L -= learn_rate * 2 * L.dot(G)
-        learn_rate *= 1.01
-
       # check for convergence
       if it > self.min_iter and abs(delta_obj) < self.convergence_tol:
         if self.verbose:
@@ -181,6 +147,54 @@ def fit(self, X, y):
     self.n_iter_ = it
     return self

+  def _loss_grad(self, L, dfG, impostors, it, k, reg, target_neighbors, df, a1,
+                 a2):
+    # Compute pairwise distances under the current metric
+    Lx = L.dot(self.X_.T).T
+    g0 = _inplace_paired_L2(*Lx[impostors])
+    Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :])
+    g1, g2 = Ni[impostors]
+    # compute the gradient
+    total_active = 0
+    for nn_idx in reversed(xrange(k)):
+      act1 = g0 < g1[:, nn_idx]
+      act2 = g0 < g2[:, nn_idx]
+      total_active += act1.sum() + act2.sum()
+
+      if it > 1:
+        plus1 = act1 & ~a1[nn_idx]
+        minus1 = a1[nn_idx] & ~act1
+        plus2 = act2 & ~a2[nn_idx]
+        minus2 = a2[nn_idx] & ~act2
+      else:
+        plus1 = act1
+        plus2 = act2
+        minus1 = np.zeros(0, dtype=int)
+        minus2 = np.zeros(0, dtype=int)
+
+      targets = target_neighbors[:, nn_idx]
+      PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
+      df += _sum_outer_products(self.X_, PLUS[:, 0], PLUS[:, 1], pweight)
+      MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
+      df -= _sum_outer_products(self.X_, MINUS[:, 0], MINUS[:, 1], mweight)
+
+      in_imp, out_imp = impostors
+      df += _sum_outer_products(self.X_, in_imp[minus1], out_imp[minus1])
+      df += _sum_outer_products(self.X_, in_imp[minus2], out_imp[minus2])
+
+      df -= _sum_outer_products(self.X_, in_imp[plus1], out_imp[plus1])
+      df -= _sum_outer_products(self.X_, in_imp[plus2], out_imp[plus2])
+
+      a1[nn_idx] = act1
+      a2[nn_idx] = act2
+    # assemble the full gradient
+    assert not np.isnan(df).any()
+    G = dfG * reg + df * (1 - reg)
+    # compute the objective function
+    objective = total_active * (1 - reg)
+    objective += G.flatten().dot(L.T.dot(L).flatten())
+    return G, objective, total_active, df, a1, a2
+
   def _select_targets(self):
     target_neighbors = np.empty((self.X_.shape[0], self.k), dtype=int)
     for label in self.labels_:
diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py
index 1f2af2f7..1d0a5d02 100644
--- a/test/metric_learn_test.py
+++ b/test/metric_learn_test.py
@@ -1,5 +1,5 @@
-import re
 import unittest
+import re
 import pytest
 import numpy as np
 from scipy.optimize import check_grad
@@ -76,6 +76,37 @@ def test_iris(self):
     self.assertLess(csep, 0.25)


+def test_convergence_simple_example(capsys):
+  # LMNN should converge on this simple example, which it did not with
+  # this issue: https://github.com/metric-learn/metric-learn/issues/88
+  X, y = make_classification(random_state=0)
+  lmnn = python_LMNN(verbose=True)
+  lmnn.fit(X, y)
+  out, _ = capsys.readouterr()
+  assert "LMNN converged with objective" in out
+
+
+def test_no_twice_same_objective(capsys):
+  # test that the objective function never has twice the same value
+  # see https://github.com/metric-learn/metric-learn/issues/88
+  X, y = make_classification(random_state=0)
+  lmnn = python_LMNN(verbose=True)
+  lmnn.fit(X, y)
+  out, _ = capsys.readouterr()
+  lines = re.split("\n+", out)
+  # we keep only the objective from each line:
+  # the regexp matches a float that follows an integer (the iteration
+  # number), and which is followed by a (signed) float (delta obj). It
+  # matches for instance:
+  # 3 **1113.7665747189938** -3.182774197440267 46431.0 2.00999999999998e-06
+  objectives = [re.search(r"\d* (?:(\d*.\d*))[ | -]\d*.\d*", s)
+                for s in lines]
+  objectives = [match.group(1) for match in objectives if match is not None]
+  # we remove the last element because it can be equal to the penultimate
+  # one if the last gradient update is null
+  assert len(objectives[:-1]) == len(set(objectives[:-1]))
+
+
 class TestSDML(MetricTestCase):
   def test_iris(self):
     # Note: this is a flaky test, which fails for certain seeds.
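
A note on the optimization change above: the new loop replaces the old "step, then roll back if worse" logic with a backtracking scheme. Take a gradient step; if the objective did not improve, halve the learning rate and retry from the same starting point; once an improving step is found, accept it and grow the rate by 1%. Below is a minimal self-contained sketch of that pattern on a toy quadratic — the function name, `f`, and `grad_f` are illustrative stand-ins, not metric-learn's API:

```python
import numpy as np

def gradient_descent_with_backtracking(f, grad_f, x0, learn_rate=0.1,
                                       max_iter=1000, tol=1e-3):
  # Same retry scheme as in the patch: shrink the step until the objective
  # stops getting worse, then grow it slightly before the next iteration.
  x, obj = x0, f(x0)
  for _ in range(max_iter):
    while True:
      x_next = x - learn_rate * grad_f(x)  # tentative gradient step
      obj_next = f(x_next)
      if obj_next - obj > 0:
        learn_rate /= 2  # worse objective: take a smaller step and retry
      else:
        break  # objective improved (or stayed equal): accept the step
    delta_obj = obj_next - obj
    x, obj = x_next, obj_next
    learn_rate *= 1.01  # cautiously re-grow the step size
    if abs(delta_obj) < tol:
      break  # converged
  return x

# Toy usage: minimizing f(x) = ||x||^2 converges to [0, 0].
x_min = gradient_descent_with_backtracking(
    lambda x: float(np.dot(x, x)), lambda x: 2 * x, np.array([3.0, -4.0]))
```

The property this buys, and which `test_no_twice_same_objective` checks for LMNN itself, is that the sequence of accepted objective values never increases, and never repeats except possibly at the final, null update.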
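To make the regexp in `test_no_twice_same_objective` concrete: each verbose line comes from `print(it, objective, delta_obj, total_active, learn_rate)`, and the unescaped `.` in the trailing `\d*.\d*` is what lets the pattern absorb the minus sign of `delta_obj`. A quick sanity check on a made-up line (the numbers are illustrative, not real output):

```python
import re

line = "3 1113.7665747189938 -3.182774197440267 46431.0 2.00999999999998e-06"
match = re.search(r"\d* (?:(\d*.\d*))[ | -]\d*.\d*", line)
assert match is not None
assert match.group(1) == "1113.7665747189938"  # the objective is captured
```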