diff --git a/metric_learn/scml.py b/metric_learn/scml.py
index 199dfc40..db2fdf64 100644
--- a/metric_learn/scml.py
+++ b/metric_learn/scml.py
@@ -240,6 +240,12 @@ def _generate_bases_dist_diff(self, triplets, X):
       raise ValueError("n_basis should be an integer, instead it is of type %s"
                        % type(self.n_basis))
 
+    if n_features > n_triplets:
+      raise ValueError(
+        "Number of features (%s) is greater than the number of triplets(%s).\n"
+        "Consider using dimensionality reduction or using another basis "
+        "generation scheme." % (n_features, n_triplets))
+
     basis = np.zeros((n_basis, n_features))
 
     # get all positive and negative pairs with lowest index first
diff --git a/test/test_triplets_classifiers.py b/test/test_triplets_classifiers.py
index f2d5c015..515a0a33 100644
--- a/test/test_triplets_classifiers.py
+++ b/test/test_triplets_classifiers.py
@@ -2,7 +2,12 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.model_selection import train_test_split
 
-from test.test_utils import triplets_learners, ids_triplets_learners
+from metric_learn import SCML
+from test.test_utils import (
+  triplets_learners,
+  ids_triplets_learners,
+  build_triplets
+)
 from metric_learn.sklearn_shims import set_random_state
 from sklearn import clone
 import numpy as np
@@ -107,3 +112,16 @@ def test_accuracy_toy_example(estimator, build_dataset):
   # we force the transformation to be identity so that we control what it does
   estimator.components_ = np.eye(X.shape[1])
   assert estimator.score(triplets_test) == 0.25
+
+
+def test_raise_big_number_of_features():
+  triplets, _, _, X = build_triplets(with_preprocessor=False)
+  triplets = triplets[:3, :, :]
+  estimator = SCML(n_basis=320)
+  set_random_state(estimator)
+  with pytest.raises(ValueError) as exc_info:
+    estimator.fit(triplets)
+  assert exc_info.value.args[0] == \
+      "Number of features (4) is greater than the number of triplets(3)." \
+      "\nConsider using dimensionality reduction or using another basis " \
+      "generation scheme."