diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py
index ed74ffcd..2035588f 100644
--- a/metric_learn/lmnn.py
+++ b/metric_learn/lmnn.py
@@ -210,7 +210,7 @@ def fit(self, X, y):
       init = self.init
     self.components_ = _initialize_components(output_dim, X, y, init,
                                               self.verbose,
-                                              self.random_state)
+                                              random_state=self.random_state)
     required_k = np.bincount(label_inds).min()
     if self.k > required_k:
       raise ValueError('not enough class labels for specified k'
diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py
index b1b2fc7f..340e6bf2 100644
--- a/metric_learn/lsml.py
+++ b/metric_learn/lsml.py
@@ -56,9 +56,8 @@ def _fit(self, quadruplets, weights=None):
     else:
       prior = self.prior
     M, prior_inv = _initialize_metric_mahalanobis(quadruplets, prior,
-                                                  return_inverse=True,
-                                                  strict_pd=True,
-                                                  matrix_name='prior')
+        return_inverse=True, strict_pd=True, matrix_name='prior',
+        random_state=self.random_state)
 
     step_sizes = np.logspace(-10, 0, 10)
     # Keep track of the best step size and the loss at that step.
diff --git a/metric_learn/nca.py b/metric_learn/nca.py
index 37fe0923..03abdc41 100644
--- a/metric_learn/nca.py
+++ b/metric_learn/nca.py
@@ -174,7 +174,8 @@ def fit(self, X, y):
       init = 'auto'
     else:
       init = self.init
-    A = _initialize_components(n_components, X, labels, init, self.verbose)
+    A = _initialize_components(n_components, X, labels, init, self.verbose,
+                               self.random_state)
 
     # Run NCA
     mask = labels[:, np.newaxis] == labels[np.newaxis, :]
diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index cfd37955..2d67e0b8 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -69,9 +69,8 @@ def _fit(self, pairs, y):
     else:
       prior = self.prior
     _, prior_inv = _initialize_metric_mahalanobis(pairs, prior,
-                                                  return_inverse=True,
-                                                  strict_pd=True,
-                                                  matrix_name='prior')
+        return_inverse=True, strict_pd=True, matrix_name='prior',
+        random_state=self.random_state)
     diff = pairs[:, 0] - pairs[:, 1]
     loss_matrix = (diff.T * y).dot(diff)
     emp_cov = prior_inv + self.balance_param * loss_matrix
diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py
index 737d2341..a812d185 100644
--- a/test/test_mahalanobis_mixin.py
+++ b/test/test_mahalanobis_mixin.py
@@ -652,3 +652,25 @@ def test_singular_array_init_or_prior(estimator, build_dataset, w0):
   with pytest.raises(LinAlgError) as raised_err:
     model.fit(input_data, labels)
   assert str(raised_err.value) == msg
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
+def test_deterministic_initialization(estimator, build_dataset):
+  """Test that estimators that have a prior or an init are deterministic
+  when it is set to 'random' and the random_state is fixed."""
+  input_data, labels, _, X = build_dataset()
+  model = clone(estimator)
+  if hasattr(estimator, 'init'):
+    model.set_params(init='random')
+  if hasattr(estimator, 'prior'):
+    model.set_params(prior='random')
+  model1 = clone(model)
+  set_random_state(model1, 42)
+  model1 = model1.fit(input_data, labels)
+  model2 = clone(model)
+  set_random_state(model2, 42)
+  model2 = model2.fit(input_data, labels)
+  np.testing.assert_allclose(model1.get_mahalanobis_matrix(),
+                             model2.get_mahalanobis_matrix())