Skip to content
35 changes: 12 additions & 23 deletions metric_learn/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import warnings
from six.moves import xrange
from scipy.sparse import coo_matrix
from sklearn.utils import check_random_state

__all__ = ['Constraints']
Expand All @@ -20,21 +19,7 @@ class Constraints(object):
def __init__(self, partial_labels):
'''partial_labels : int arraylike, -1 indicating unknown label'''
partial_labels = np.asanyarray(partial_labels, dtype=int)
self.num_points, = partial_labels.shape
self.known_label_idx, = np.where(partial_labels >= 0)
self.known_labels = partial_labels[self.known_label_idx]

def adjacency_matrix(self, num_constraints, random_state=None):
random_state = check_random_state(random_state)
a, b, c, d = self.positive_negative_pairs(num_constraints,
random_state=random_state)
row = np.concatenate((a, c))
col = np.concatenate((b, d))
data = np.ones_like(row, dtype=int)
data[len(a):] = -1
adj = coo_matrix((data, (row, col)), shape=(self.num_points,) * 2)
# symmetrize
return adj + adj.T
self.partial_labels = partial_labels

def positive_negative_pairs(self, num_constraints, same_length=False,
random_state=None):
Expand All @@ -50,17 +35,19 @@ def positive_negative_pairs(self, num_constraints, same_length=False,

def _pairs(self, num_constraints, same_label=True, max_iter=10,
random_state=np.random):
num_labels = len(self.known_labels)
known_label_idx, = np.where(self.partial_labels >= 0)
known_labels = self.partial_labels[known_label_idx]
num_labels = len(known_labels)
ab = set()
it = 0
while it < max_iter and len(ab) < num_constraints:
nc = num_constraints - len(ab)
for aidx in random_state.randint(num_labels, size=nc):
if same_label:
mask = self.known_labels[aidx] == self.known_labels
mask = known_labels[aidx] == known_labels
mask[aidx] = False # avoid identity pairs
else:
mask = self.known_labels[aidx] != self.known_labels
mask = known_labels[aidx] != known_labels
b_choices, = np.where(mask)
if len(b_choices) > 0:
ab.add((aidx, random_state.choice(b_choices)))
Expand All @@ -69,16 +56,18 @@ def _pairs(self, num_constraints, same_label=True, max_iter=10,
warnings.warn("Only generated %d %s constraints (requested %d)" % (
len(ab), 'positive' if same_label else 'negative', num_constraints))
ab = np.array(list(ab)[:num_constraints], dtype=int)
return self.known_label_idx[ab.T]
return known_label_idx[ab.T]

def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
"""
the random state object to be passed must be a numpy random seed
"""
random_state = check_random_state(random_state)
chunks = -np.ones_like(self.known_label_idx, dtype=int)
uniq, lookup = np.unique(self.known_labels, return_inverse=True)
all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))]
chunks = -np.ones_like(self.partial_labels, dtype=int)
uniq, lookup = np.unique(self.partial_labels, return_inverse=True)
unknown_uniq = np.where(uniq < 0)[0]
all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))
if c not in unknown_uniq]
max_chunks = int(np.sum([len(s) // chunk_size for s in all_inds]))
if max_chunks < num_chunks:
raise ValueError(('Not enough possible chunks of %d elements in each'
Expand Down
13 changes: 12 additions & 1 deletion metric_learn/rca.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,12 @@ def __init__(self, n_components=None, num_dims='deprecated',

def _check_dimension(self, rank, X):
d = X.shape[1]

if rank < d:
warnings.warn('The inner covariance matrix is not invertible, '
'so the transformation matrix may contain Nan values. '
'You should reduce the dimensionality of your input,'
'You should remove any linearly dependent features and/or '
'reduce the dimensionality of your input, '
'for instance using `sklearn.decomposition.PCA` as a '
'preprocessing step.')

Expand Down Expand Up @@ -241,4 +243,13 @@ def fit(self, X, y, random_state='deprecated'):
chunks = Constraints(y).chunks(num_chunks=self.num_chunks,
chunk_size=self.chunk_size,
random_state=self.random_state)

if self.num_chunks * (self.chunk_size - 1) < X.shape[1]:
warnings.warn('Due to the parameters of RCA_Supervised, '
'the inner covariance matrix is not invertible, '
'so the transformation matrix will contain Nan values. '
'Increase the number or size of the chunks to correct '
'this problem.'
)

return RCA.fit(self, X, chunks)
39 changes: 38 additions & 1 deletion test/metric_learn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,9 +1100,11 @@ def test_rank_deficient_returns_warning(self):
rca = RCA()
msg = ('The inner covariance matrix is not invertible, '
'so the transformation matrix may contain Nan values. '
'You should reduce the dimensionality of your input,'
'You should remove any linearly dependent features and/or '
'reduce the dimensionality of your input, '
'for instance using `sklearn.decomposition.PCA` as a '
'preprocessing step.')

with pytest.warns(None) as raised_warnings:
rca.fit(X, y)
assert any(str(w.message) == msg for w in raised_warnings)
Expand Down Expand Up @@ -1136,6 +1138,41 @@ def test_changed_behaviour_warning_random_state(self):
rca_supervised.fit(X, y)
assert any(msg == str(wrn.message) for wrn in raised_warning)

def test_unknown_labels(self):
    """Fitting with extra unlabeled points must not change the result.

    One model is fit on the first half of the data only; a second model is
    fit on the full data where the second half is labeled -1. Both must
    yield NaN-free and exactly equal ``components_``.
    """
    n_labeled = 200
    n_chunks = 50
    X, y = make_classification(random_state=42, n_samples=2 * n_labeled,
                               n_features=6, n_informative=6, n_redundant=0)
    X_labeled, y_labeled = X[:n_labeled], y[:n_labeled]
    # Second half of the samples gets the "unknown" marker (-1).
    y_partial = np.concatenate((y_labeled, -np.ones(n_labeled)))

    reference = RCA_Supervised(num_chunks=n_chunks, random_state=42)
    reference.fit(X_labeled, y_labeled)

    with_unknown = RCA_Supervised(num_chunks=n_chunks, random_state=42)
    with_unknown.fit(X, y_partial)

    for model in (reference, with_unknown):
        assert not np.any(np.isnan(model.components_))

    np.testing.assert_array_equal(reference.components_,
                                  with_unknown.components_)

def test_bad_parameters(self):
    """RCA_Supervised must warn when its chunk settings are too small.

    The model is fit with ``num_chunks=3`` on 6-feature data and the test
    checks that the dedicated "parameters of RCA_Supervised" warning is
    among the warnings raised during ``fit``.
    """
    n = 200
    num_chunks = 3
    X, y = make_classification(random_state=42, n_samples=n,
                               n_features=6, n_informative=6, n_redundant=0)

    rca = RCA_Supervised(num_chunks=num_chunks, random_state=42)
    msg = ('Due to the parameters of RCA_Supervised, '
           'the inner covariance matrix is not invertible, '
           'so the transformation matrix will contain Nan values. '
           'Increase the number or size of the chunks to correct '
           'this problem.'
           )
    # `pytest.warns(None)` is deprecated since pytest 7 and rejected by
    # pytest 8. The warning under test is emitted without an explicit
    # category, i.e. as a UserWarning, so record under that category; the
    # recorder still captures every warning raised inside the block.
    with pytest.warns(UserWarning) as raised_warning:
        rca.fit(X, y)
    assert any(str(w.message) == msg for w in raised_warning)


@pytest.mark.parametrize('num_dims', [None, 2])
def test_deprecation_num_dims_rca(num_dims):
Expand Down
16 changes: 12 additions & 4 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import unittest
import pytest
import numpy as np
from sklearn.utils import shuffle
Expand Down Expand Up @@ -34,7 +33,8 @@ def test_exact_num_points_for_chunks(num_chunks, chunk_size):
chunks = constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
random_state=SEED)

chunk_no, size_each_chunk = np.unique(chunks, return_counts=True)
chunk_no, size_each_chunk = np.unique(chunks[chunks >= 0],
return_counts=True)

np.testing.assert_array_equal(size_each_chunk, chunk_size)
assert chunk_no.shape[0] == num_chunks
Expand All @@ -59,5 +59,13 @@ def test_chunk_case_one_miss_point(num_chunks, chunk_size):
assert str(e.value) == expected_message


if __name__ == '__main__':
unittest.main()
@pytest.mark.parametrize("num_chunks, chunk_size", [(5, 10), (10, 50)])
def test_unknown_labels_not_in_chunks(num_chunks, chunk_size):
    """Points whose label is negative (unknown) must stay out of every chunk.

    A negative chunk id means "not assigned to any chunk".
    """
    labels = gen_labels_for_chunks(num_chunks, chunk_size)

    chunks = Constraints(labels).chunks(num_chunks=num_chunks,
                                        chunk_size=chunk_size,
                                        random_state=SEED)

    unknown = labels < 0
    assert np.all(chunks[unknown] < 0)