
[MRG] update impostors, closer to original implem #228


Merged 14 commits on Jul 4, 2019
2 changes: 1 addition & 1 deletion examples/plot_metric_learning_examples.py
@@ -139,7 +139,7 @@ def plot_tsne(X, y, colormap=plt.cm.Paired):
#

# setting up LMNN
lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6, init='random')
lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6)

# fit the data!
lmnn.fit(X, y)
77 changes: 26 additions & 51 deletions metric_learn/lmnn.py
@@ -1,8 +1,6 @@
"""
Large Margin Nearest Neighbor Metric learning (LMNN)
"""
# TODO: periodic recalculation of impostors, PCA initialization

from __future__ import print_function, absolute_import
import numpy as np
import warnings
@@ -208,31 +206,19 @@ def fit(self, X, y):
' (smallest class has %d)' % required_k)

target_neighbors = self._select_targets(X, label_inds)
impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds)
if len(impostors) == 0:
# L has already been initialized to an identity matrix
return

# sum outer products
dfG = _sum_outer_products(X, target_neighbors.flatten(),
np.repeat(np.arange(X.shape[0]), k))
df = np.zeros_like(dfG)

# storage
a1 = [None]*k
a2 = [None]*k
for nn_idx in xrange(k):
a1[nn_idx] = np.array([])
a2[nn_idx] = np.array([])

# initialize L
L = self.transformer_

# first iteration: we compute variables (including objective and gradient)
# at initialization point
G, objective, total_active, df, a1, a2 = (
self._loss_grad(X, L, dfG, impostors, 1, k, reg, target_neighbors, df,
a1, a2))
G, objective, total_active = self._loss_grad(X, L, dfG, k,
reg, target_neighbors,
label_inds)

it = 1 # we already made one iteration

@@ -246,10 +232,9 @@ def fit(self, X, y):
# we compute the objective at next point
# we copy variables that can be modified by _loss_grad, because if we
# retry we don't want to modify them several times
(G_next, objective_next, total_active_next, df_next, a1_next,
a2_next) = (
self._loss_grad(X, L_next, dfG, impostors, it, k, reg,
target_neighbors, df.copy(), list(a1), list(a2)))
(G_next, objective_next, total_active_next) = (
self._loss_grad(X, L_next, dfG, k, reg, target_neighbors,
label_inds))
assert not np.isnan(objective)
delta_obj = objective_next - objective
if delta_obj > 0:
@@ -264,8 +249,7 @@ def fit(self, X, y):
# old variables to these new ones before next iteration and we
# slightly increase the learning rate
L = L_next
G, df, objective, total_active, a1, a2 = (
G_next, df_next, objective_next, total_active_next, a1_next, a2_next)
G, objective, total_active = G_next, objective_next, total_active_next
learn_rate *= 1.01

if self.verbose:
@@ -285,54 +269,45 @@ def fit(self, X, y):
self.n_iter_ = it
return self

def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
a1, a2):
def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
# Compute pairwise distances under current metric
Lx = L.dot(X.T).T
g0 = _inplace_paired_L2(*Lx[impostors])

# we need to find the furthest neighbor:
Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :])
furthest_neighbors = np.take_along_axis(target_neighbors,
Ni.argmax(axis=1)[:, None], 1)
impostors = self._find_impostors(furthest_neighbors.ravel(), X,
label_inds, L)

g0 = _inplace_paired_L2(*Lx[impostors])

# we reorder the target neighbors
g1, g2 = Ni[impostors]
# compute the gradient
total_active = 0
for nn_idx in reversed(xrange(k)):
df = np.zeros((X.shape[1], X.shape[1]))
for nn_idx in reversed(xrange(k)): # note: reverse not useful here
act1 = g0 < g1[:, nn_idx]
act2 = g0 < g2[:, nn_idx]
total_active += act1.sum() + act2.sum()

if it > 1:
plus1 = act1 & ~a1[nn_idx]
minus1 = a1[nn_idx] & ~act1
plus2 = act2 & ~a2[nn_idx]
minus2 = a2[nn_idx] & ~act2
else:
plus1 = act1
plus2 = act2
minus1 = np.zeros(0, dtype=int)
minus2 = np.zeros(0, dtype=int)

targets = target_neighbors[:, nn_idx]
PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
PLUS, pweight = _count_edges(act1, act2, impostors, targets)
df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight)
MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight)

in_imp, out_imp = impostors
df += _sum_outer_products(X, in_imp[minus1], out_imp[minus1])
df += _sum_outer_products(X, in_imp[minus2], out_imp[minus2])

df -= _sum_outer_products(X, in_imp[plus1], out_imp[plus1])
df -= _sum_outer_products(X, in_imp[plus2], out_imp[plus2])
df -= _sum_outer_products(X, in_imp[act1], out_imp[act1])
df -= _sum_outer_products(X, in_imp[act2], out_imp[act2])

a1[nn_idx] = act1
a2[nn_idx] = act2
# do the gradient update
assert not np.isnan(df).any()
G = dfG * reg + df * (1 - reg)
G = L.dot(G)
# compute the objective function
objective = total_active * (1 - reg)
objective += G.flatten().dot(L.flatten())
return 2 * G, objective, total_active, df, a1, a2
return 2 * G, objective, total_active

def _select_targets(self, X, label_inds):
target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
@@ -344,8 +319,8 @@ def _select_targets(self, X, label_inds):
target_neighbors[inds] = inds[nn]
return target_neighbors

def _find_impostors(self, furthest_neighbors, X, label_inds):
Lx = self.transform(X)
def _find_impostors(self, furthest_neighbors, X, label_inds, L):
Lx = X.dot(L.T)
margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
impostors = []
for label in self.labels_[:-1]:
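The core of this change to `metric_learn/lmnn.py` is that impostors are now re-selected inside `_loss_grad`, using each point's furthest target neighbor under the *current* transformation `L`, instead of being computed once from the initial metric and reused for every iteration. Below is a minimal sketch of that furthest-neighbor selection step on made-up toy data (the array sizes, target-neighbor indices, and variable names are illustrative only, not code from this PR):

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(6, 2)                      # 6 samples, 2 features
L = np.eye(2)                            # current linear transformation
Lx = X.dot(L.T)                          # samples in the transformed space

# pretend each sample has k=2 target neighbors (indices picked by hand here)
target_neighbors = np.array([[1, 2], [0, 2], [0, 1],
                             [4, 5], [3, 5], [3, 4]])

# squared distances from each sample to its target neighbors under L
dists = ((Lx[:, None, :] - Lx[target_neighbors]) ** 2).sum(axis=2)

# pick, per row, the target neighbor that is currently furthest away;
# its distance (plus the unit margin) defines the radius inside which
# differently-labeled points count as impostors
furthest = np.take_along_axis(target_neighbors,
                              dists.argmax(axis=1)[:, None], axis=1).ravel()
print(furthest)   # one furthest target neighbor index per sample
```

In the diff above, this is the role of `np.take_along_axis(target_neighbors, Ni.argmax(axis=1)[:, None], 1)`, after which `_find_impostors` is called with the current `L` rather than the transformer stored on the estimator.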
159 changes: 142 additions & 17 deletions test/metric_learn_test.py
@@ -2,9 +2,10 @@
import re
import pytest
import numpy as np
import scipy
from scipy.optimize import check_grad, approx_fprime
from six.moves import xrange
from sklearn.metrics import pairwise_distances
from sklearn.metrics import pairwise_distances, euclidean_distances
from sklearn.datasets import (load_iris, make_classification, make_regression,
make_spd_matrix)
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
@@ -242,25 +243,15 @@ def test_loss_grad_lbfgs(self):
lmnn.transformer_ = np.eye(n_components)

target_neighbors = lmnn._select_targets(X, label_inds)
impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)

# sum outer products
dfG = _sum_outer_products(X, target_neighbors.flatten(),
np.repeat(np.arange(X.shape[0]), k))
df = np.zeros_like(dfG)

# storage
a1 = [None]*k
a2 = [None]*k
for nn_idx in xrange(k):
a1[nn_idx] = np.array([])
a2[nn_idx] = np.array([])

# initialize L
def loss_grad(flat_L):
return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG, impostors,
1, k, reg, target_neighbors, df.copy(),
list(a1), list(a2))
return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG,
k, reg, target_neighbors, label_inds)

def fun(x):
return loss_grad(x)[1]
@@ -292,6 +283,141 @@ def test_changed_behaviour_warning(self):
assert any(msg == str(wrn.message) for wrn in raised_warning)


def test_loss_func(capsys):
"""Test the loss function (and its gradient) on a simple example,
by comparing the results with the actual implementation of metric-learn,
with a very simple (but nonperformant) implementation"""

# toy dataset to use
X, y = make_classification(n_samples=10, n_classes=2,
n_features=6,
n_redundant=0, shuffle=True,
scale=[1, 1, 20, 20, 20, 20], random_state=42)

def hinge(a):
if a > 0:
return a, 1
else:
return 0, 0

def loss_fn(L, X, y, target_neighbors, reg):
L = L.reshape(-1, X.shape[1])
Lx = np.dot(X, L.T)
loss = 0
total_active = 0
grad = np.zeros_like(L)
for i in range(X.shape[0]):
for j in target_neighbors[i]:
loss += (1 - reg) * np.sum((Lx[i] - Lx[j]) ** 2)
grad += (1 - reg) * np.outer(Lx[i] - Lx[j], X[i] - X[j])
for l in range(X.shape[0]):
if y[i] != y[l]:
hin, active = hinge(1 + np.sum((Lx[i] - Lx[j])**2) -
np.sum((Lx[i] - Lx[l])**2))
total_active += active
if active:
loss += reg * hin
grad += (reg * (np.outer(Lx[i] - Lx[j], X[i] - X[j]) -
np.outer(Lx[i] - Lx[l], X[i] - X[l])))
grad = 2 * grad
return grad, loss, total_active

# we check that the gradient we have computed in the non-performant implem
# is indeed the true gradient on a toy example:

def _select_targets(X, y, k):
target_neighbors = np.empty((X.shape[0], k), dtype=int)
for label in np.unique(y):
inds, = np.nonzero(y == label)
dd = euclidean_distances(X[inds], squared=True)
np.fill_diagonal(dd, np.inf)
nn = np.argsort(dd)[..., :k]
target_neighbors[inds] = inds[nn]
return target_neighbors

target_neighbors = _select_targets(X, y, 2)
regularization = 0.5
n_features = X.shape[1]
x0 = np.random.randn(1, n_features)

def loss(x0):
return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors,
regularization)[1]

def grad(x0):
return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors,
regularization)[0].ravel()

scipy.optimize.check_grad(loss, grad, x0.ravel())

class LMNN_with_callback(LMNN):
""" We will use a callback to get the gradient (see later)
"""

def __init__(self, callback, *args, **kwargs):
self.callback = callback
super(LMNN_with_callback, self).__init__(*args, **kwargs)

def _loss_grad(self, *args, **kwargs):
grad, objective, total_active = (
super(LMNN_with_callback, self)._loss_grad(*args, **kwargs))
self.callback.append(grad)
return grad, objective, total_active

class LMNN_nonperformant(LMNN_with_callback):

def fit(self, X, y):
self.y = y
return super(LMNN_nonperformant, self).fit(X, y)

def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
grad, loss, total_active = loss_fn(L.ravel(), X, self.y,
target_neighbors, self.regularization)
self.callback.append(grad)
return grad, loss, total_active

mem1, mem2 = [], []
lmnn_perf = LMNN_with_callback(verbose=True, random_state=42,
init='identity', max_iter=30, callback=mem1)
lmnn_nonperf = LMNN_nonperformant(verbose=True, random_state=42,
init='identity', max_iter=30,
callback=mem2)
objectives, obj_diffs, learn_rate, total_active = (dict(), dict(), dict(),
dict())
for algo, name in zip([lmnn_perf, lmnn_nonperf], ['perf', 'nonperf']):
algo.fit(X, y)
out, _ = capsys.readouterr()
lines = re.split("\n+", out)
# we get every variable that is printed from the algorithm in verbose
num = '(-?\d+.?\d*(e[+|-]\d+)?)'
strings = [re.search("\d+ (?:{}) (?:{}) (?:(\d+)) (?:{})"
.format(num, num, num), s) for s in lines]
objectives[name] = [float(match.group(1)) for match in strings if match is
not None]
obj_diffs[name] = [float(match.group(3)) for match in strings if match is
not None]
total_active[name] = [float(match.group(5)) for match in strings if
match is not
None]
learn_rate[name] = [float(match.group(6)) for match in strings if match is
not None]
assert len(strings) >= 10 # we ensure that we actually did more than 10
# iterations
assert total_active[name][0] >= 2 # we ensure that we have some active
# constraints (that's the case we want to test)
# we remove the last element because it can be equal to the penultimate
# if the last gradient update is null
for i in range(len(mem1)):
np.testing.assert_allclose(lmnn_perf.callback[i],
lmnn_nonperf.callback[i],
err_msg='Gradient different at position '
'{}'.format(i))
np.testing.assert_allclose(objectives['perf'], objectives['nonperf'])
np.testing.assert_allclose(obj_diffs['perf'], obj_diffs['nonperf'])
np.testing.assert_allclose(total_active['perf'], total_active['nonperf'])
np.testing.assert_allclose(learn_rate['perf'], learn_rate['nonperf'])


@pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]),
[1, 1, 0, 0], 3.0),
(np.array([[0], [1], [2], [3]]),
@@ -312,7 +438,7 @@ def test_toy_ex_lmnn(X, y, loss):
lmnn.transformer_ = np.eye(n_components)

target_neighbors = lmnn._select_targets(X, label_inds)
impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds, L)

# sum outer products
dfG = _sum_outer_products(X, target_neighbors.flatten(),
@@ -327,9 +453,8 @@ def test_toy_ex_lmnn(X, y, loss):
a2[nn_idx] = np.array([])

# assert that the loss equals the one computed by hand
assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, k,
reg, target_neighbors, df, a1, a2)[1] == loss

assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, k,
reg, target_neighbors, label_inds)[1] == loss

def test_convergence_simple_example(capsys):
# LMNN should converge on this simple example, which it did not with
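For reference, the objective that both the vectorized `_loss_grad` and the naive `loss_fn` in `test_loss_func` evaluate is the usual LMNN cost. A sketch in the notation of the test, writing `reg` as $\mu$, "$j \rightsquigarrow i$" for "$j$ is a target neighbor of $i$", and $[\cdot]_+$ for the hinge:

$$
\varepsilon(L) = (1-\mu)\sum_{i,\; j \rightsquigarrow i} \lVert L(x_i - x_j)\rVert^2
\;+\; \mu \sum_{i,\; j \rightsquigarrow i}\ \sum_{l:\, y_l \neq y_i}
\big[\, 1 + \lVert L(x_i - x_j)\rVert^2 - \lVert L(x_i - x_l)\rVert^2 \,\big]_+
$$

The gradient accumulated by `loss_fn` is then twice the sum of the corresponding outer products, matching the `grad +=` lines and the final `grad = 2 * grad` in the test.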