diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py
index b71c9b96..752ca6e0 100644
--- a/metric_learn/constraints.py
+++ b/metric_learn/constraints.py
@@ -5,7 +5,6 @@
 import numpy as np
 import warnings
 from six.moves import xrange
-from scipy.sparse import coo_matrix
 from sklearn.utils import check_random_state
 
 __all__ = ['Constraints']
@@ -20,21 +19,7 @@ class Constraints(object):
   def __init__(self, partial_labels):
     '''partial_labels : int arraylike, -1 indicating unknown label'''
     partial_labels = np.asanyarray(partial_labels, dtype=int)
-    self.num_points, = partial_labels.shape
-    self.known_label_idx, = np.where(partial_labels >= 0)
-    self.known_labels = partial_labels[self.known_label_idx]
-
-  def adjacency_matrix(self, num_constraints, random_state=None):
-    random_state = check_random_state(random_state)
-    a, b, c, d = self.positive_negative_pairs(num_constraints,
-                                              random_state=random_state)
-    row = np.concatenate((a, c))
-    col = np.concatenate((b, d))
-    data = np.ones_like(row, dtype=int)
-    data[len(a):] = -1
-    adj = coo_matrix((data, (row, col)), shape=(self.num_points,) * 2)
-    # symmetrize
-    return adj + adj.T
+    self.partial_labels = partial_labels
 
   def positive_negative_pairs(self, num_constraints, same_length=False,
                               random_state=None):
@@ -50,17 +35,19 @@ def positive_negative_pairs(self, num_constraints, same_length=False,
 
   def _pairs(self, num_constraints, same_label=True, max_iter=10,
              random_state=np.random):
-    num_labels = len(self.known_labels)
+    known_label_idx, = np.where(self.partial_labels >= 0)
+    known_labels = self.partial_labels[known_label_idx]
+    num_labels = len(known_labels)
     ab = set()
     it = 0
     while it < max_iter and len(ab) < num_constraints:
       nc = num_constraints - len(ab)
       for aidx in random_state.randint(num_labels, size=nc):
         if same_label:
-          mask = self.known_labels[aidx] == self.known_labels
+          mask = known_labels[aidx] == known_labels
           mask[aidx] = False  # avoid identity pairs
         else:
-          mask = self.known_labels[aidx] != self.known_labels
+          mask = known_labels[aidx] != known_labels
         b_choices, = np.where(mask)
         if len(b_choices) > 0:
           ab.add((aidx, random_state.choice(b_choices)))
@@ -69,16 +56,18 @@ def _pairs(self, num_constraints, same_label=True, max_iter=10,
       warnings.warn("Only generated %d %s constraints (requested %d)" % (
           len(ab), 'positive' if same_label else 'negative', num_constraints))
     ab = np.array(list(ab)[:num_constraints], dtype=int)
-    return self.known_label_idx[ab.T]
+    return known_label_idx[ab.T]
 
   def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
     """
     the random state object to be passed must be a numpy random seed
     """
     random_state = check_random_state(random_state)
-    chunks = -np.ones_like(self.known_label_idx, dtype=int)
-    uniq, lookup = np.unique(self.known_labels, return_inverse=True)
-    all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))]
+    chunks = -np.ones_like(self.partial_labels, dtype=int)
+    uniq, lookup = np.unique(self.partial_labels, return_inverse=True)
+    unknown_uniq = np.where(uniq < 0)[0]
+    all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))
+                if c not in unknown_uniq]
     max_chunks = int(np.sum([len(s) // chunk_size for s in all_inds]))
     if max_chunks < num_chunks:
       raise ValueError(('Not enough possible chunks of %d elements in each'
diff --git a/metric_learn/rca.py b/metric_learn/rca.py
index 2a9ab1e8..204bd360 100644
--- a/metric_learn/rca.py
+++ b/metric_learn/rca.py
@@ -93,10 +93,12 @@ def __init__(self, n_components=None, num_dims='deprecated',
 
   def _check_dimension(self, rank, X):
     d = X.shape[1]
+
     if rank < d:
       warnings.warn('The inner covariance matrix is not invertible, '
                     'so the transformation matrix may contain Nan values. '
-                    'You should reduce the dimensionality of your input,'
+                    'You should remove any linearly dependent features and/or '
+                    'reduce the dimensionality of your input, '
                     'for instance using `sklearn.decomposition.PCA` as a '
                     'preprocessing step.')
 
@@ -241,4 +243,13 @@ def fit(self, X, y, random_state='deprecated'):
     chunks = Constraints(y).chunks(num_chunks=self.num_chunks,
                                    chunk_size=self.chunk_size,
                                    random_state=self.random_state)
+
+    if self.num_chunks * (self.chunk_size - 1) < X.shape[1]:
+      warnings.warn('Due to the parameters of RCA_Supervised, '
+                    'the inner covariance matrix is not invertible, '
+                    'so the transformation matrix will contain Nan values. '
+                    'Increase the number or size of the chunks to correct '
+                    'this problem.'
+                    )
+
     return RCA.fit(self, X, chunks)
diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py
index f713a059..5a271890 100644
--- a/test/metric_learn_test.py
+++ b/test/metric_learn_test.py
@@ -1100,9 +1100,11 @@ def test_rank_deficient_returns_warning(self):
     rca = RCA()
     msg = ('The inner covariance matrix is not invertible, '
            'so the transformation matrix may contain Nan values. '
-           'You should reduce the dimensionality of your input,'
+           'You should remove any linearly dependent features and/or '
+           'reduce the dimensionality of your input, '
            'for instance using `sklearn.decomposition.PCA` as a '
            'preprocessing step.')
+
     with pytest.warns(None) as raised_warnings:
       rca.fit(X, y)
     assert any(str(w.message) == msg for w in raised_warnings)
@@ -1136,6 +1138,41 @@ def test_changed_behaviour_warning_random_state(self):
       rca_supervised.fit(X, y)
     assert any(msg == str(wrn.message) for wrn in raised_warning)
 
+  def test_unknown_labels(self):
+    n = 200
+    num_chunks = 50
+    X, y = make_classification(random_state=42, n_samples=2 * n,
+                               n_features=6, n_informative=6, n_redundant=0)
+    y2 = np.concatenate((y[:n], -np.ones(n)))
+
+    rca = RCA_Supervised(num_chunks=num_chunks, random_state=42)
+    rca.fit(X[:n], y[:n])
+
+    rca2 = RCA_Supervised(num_chunks=num_chunks, random_state=42)
+    rca2.fit(X, y2)
+
+    assert not np.any(np.isnan(rca.components_))
+    assert not np.any(np.isnan(rca2.components_))
+
+    np.testing.assert_array_equal(rca.components_, rca2.components_)
+
+  def test_bad_parameters(self):
+    n = 200
+    num_chunks = 3
+    X, y = make_classification(random_state=42, n_samples=n,
+                               n_features=6, n_informative=6, n_redundant=0)
+
+    rca = RCA_Supervised(num_chunks=num_chunks, random_state=42)
+    msg = ('Due to the parameters of RCA_Supervised, '
+           'the inner covariance matrix is not invertible, '
+           'so the transformation matrix will contain Nan values. '
+           'Increase the number or size of the chunks to correct '
+           'this problem.'
+           )
+    with pytest.warns(None) as raised_warning:
+      rca.fit(X, y)
+    assert any(str(w.message) == msg for w in raised_warning)
+
 
 @pytest.mark.parametrize('num_dims', [None, 2])
 def test_deprecation_num_dims_rca(num_dims):
diff --git a/test/test_constraints.py b/test/test_constraints.py
index a135985e..243028f6 100644
--- a/test/test_constraints.py
+++ b/test/test_constraints.py
@@ -1,4 +1,3 @@
-import unittest
 import pytest
 import numpy as np
 from sklearn.utils import shuffle
@@ -34,7 +33,8 @@ def test_exact_num_points_for_chunks(num_chunks, chunk_size):
   chunks = constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
                               random_state=SEED)
 
-  chunk_no, size_each_chunk = np.unique(chunks, return_counts=True)
+  chunk_no, size_each_chunk = np.unique(chunks[chunks >= 0],
+                                        return_counts=True)
 
   np.testing.assert_array_equal(size_each_chunk, chunk_size)
   assert chunk_no.shape[0] == num_chunks
@@ -59,5 +59,13 @@ def test_chunk_case_one_miss_point(num_chunks, chunk_size):
   assert str(e.value) == expected_message
 
 
-if __name__ == '__main__':
-  unittest.main()
+@pytest.mark.parametrize("num_chunks, chunk_size", [(5, 10), (10, 50)])
+def test_unknown_labels_not_in_chunks(num_chunks, chunk_size):
+  """Checks that unknown labels are not assigned to any chunk."""
+  labels = gen_labels_for_chunks(num_chunks, chunk_size)
+
+  constraints = Constraints(labels)
+  chunks = constraints.chunks(num_chunks=num_chunks, chunk_size=chunk_size,
+                              random_state=SEED)
+
+  assert np.all(chunks[labels < 0] < 0)
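
---

Reviewer note: below is a minimal sketch of the behaviour this patch
introduces, mirroring `test_unknown_labels` and
`test_unknown_labels_not_in_chunks` above. It assumes a metric-learn checkout
with the patch applied; all names are taken from the diff itself.

    import numpy as np
    from sklearn.datasets import make_classification
    from metric_learn import RCA_Supervised
    from metric_learn.constraints import Constraints

    # Half of the points keep their class label, the other half get -1,
    # which Constraints now treats as "unknown".
    n = 200
    X, y = make_classification(random_state=42, n_samples=2 * n,
                               n_features=6, n_informative=6, n_redundant=0)
    y_partial = np.concatenate((y[:n], -np.ones(n)))

    # Unknown points are never assigned to a chunk: their entry stays -1.
    chunks = Constraints(y_partial).chunks(num_chunks=50, chunk_size=2,
                                           random_state=42)
    assert np.all(chunks[y_partial < 0] == -1)

    # RCA_Supervised ignores the unlabeled points, so the fit matches a fit
    # on the labeled subset alone and contains no Nan values.
    rca = RCA_Supervised(num_chunks=50, random_state=42)
    rca.fit(X, y_partial)
    assert not np.any(np.isnan(rca.components_))

On the new `self.num_chunks * (self.chunk_size - 1) < X.shape[1]` check: each
chunk is mean-centered before the inner covariance matrix is computed, so a
chunk of size k contributes at most k - 1 linearly independent vectors. With
fewer than X.shape[1] such vectors in total, the covariance matrix cannot be
full rank, hence the warning.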