From 94ec950486e17e69ddb1d50a85dc001377d2d054 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Wed, 23 Jan 2019 14:51:22 +0100 Subject: [PATCH 1/5] MAINT: remove variables not needed to store --- metric_learn/itml.py | 45 ++++++++++++++++++++++++++++------ metric_learn/lmnn.py | 23 ++++++++++-------- metric_learn/lsml.py | 58 ++++++++++++++++++++++++-------------------- metric_learn/mlkr.py | 3 +++ metric_learn/mmc.py | 6 +++++ metric_learn/nca.py | 3 +++ metric_learn/sdml.py | 10 ++++---- 7 files changed, 99 insertions(+), 49 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 158ec4d3..6e643882 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -73,9 +73,9 @@ def _fit(self, pairs, y, bounds=None): self.bounds_[self.bounds_==0] = 1e-9 # init metric if self.A0 is None: - self.A_ = np.identity(pairs.shape[2]) + A = np.identity(pairs.shape[2]) else: - self.A_ = check_array(self.A0) + A = check_array(self.A0, copy=True) gamma = self.gamma pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] num_pos = len(pos_pairs) @@ -87,7 +87,6 @@ def _fit(self, pairs, y, bounds=None): neg_bhat = np.zeros(num_neg) + self.bounds_[1] pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] - A = self.A_ for it in xrange(self.max_iter): # update positives @@ -125,7 +124,7 @@ def _fit(self, pairs, y, bounds=None): print('itml converged at iter: %d, conv = %f' % (it, conv)) self.n_iter_ = it - self.transformer_ = transformer_from_metric(self.A_) + self.transformer_ = transformer_from_metric(A) return self @@ -134,6 +133,17 @@ class ITML(_BaseITML, _PairsClassifierMixin): Attributes ---------- + bounds_ : `list` of two numbers + Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for + all given pairs of similar points ``a`` and ``b``, and + ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and + ``d``, with ``bounds=[ pos, neg]``, and ``d`` the learned distance. If + not provided at initialization, these are the ones derived at train + time. + + n_iter_ : `int` + The number of iterations the solver has ran. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) @@ -151,8 +161,12 @@ def fit(self, pairs, y, bounds=None): preprocessor. y: array-like, of shape (n_constraints,) Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. - bounds : list (pos,neg) pairs, optional - bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg + bounds : `list` of two numbers + Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for + all given pairs of similar points ``a`` and ``b``, and + ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and + ``d``, with ``bounds=[pos, neg]``, and ``d`` the learned distance. + If not provided at initialization, these will be derived at train time. Returns ------- @@ -167,6 +181,17 @@ class ITML_Supervised(_BaseITML, TransformerMixin): Attributes ---------- + bounds_ : `list` of two numbers + Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for + all given pairs of similar points ``a`` and ``b``, and + ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and + ``d``, with ``bounds=[pos, neg]``, and ``d`` the learned distance. If + not provided at initialization, these are the ones derived at train + time. 
+ + n_iter_ : `int` + The number of iterations the solver has ran. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) @@ -193,8 +218,12 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, be removed in 0.6.0. num_constraints: int, optional number of constraints to generate - bounds : list (pos,neg) pairs, optional - bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg + bounds : `list` of two numbers + Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for + all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > neg`` + for all given pairs of dissimilar points ``c`` and ``d``, with + ``bounds=[pos, neg]``, and ``d`` the learned distance. If not provided at + initialization, these will be derived at train time. A0 : (d x d) matrix, optional initial regularization matrix, defaults to identity verbose : bool, optional diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 1d7ddf2a..c218ac48 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -60,20 +60,20 @@ def fit(self, X, y): X, y = self._prepare_inputs(X, y, dtype=float, ensure_min_samples=2) num_pts, num_dims = X.shape - unique_labels, self.label_inds_ = np.unique(y, return_inverse=True) - if len(self.label_inds_) != num_pts: + unique_labels, label_inds = np.unique(y, return_inverse=True) + if len(label_inds) != num_pts: raise ValueError('Must have one label per point.') self.labels_ = np.arange(len(unique_labels)) if self.use_pca: warnings.warn('use_pca does nothing for the python_LMNN implementation') self.transformer_ = np.eye(num_dims) - required_k = np.bincount(self.label_inds_).min() + required_k = np.bincount(label_inds).min() if self.k > required_k: raise ValueError('not enough class labels for specified k' ' (smallest class has %d)' % required_k) - target_neighbors = self._select_targets(X) - impostors = self._find_impostors(target_neighbors[:, -1], X) + target_neighbors = self._select_targets(X, label_inds) + impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds) if len(impostors) == 0: # L has already been initialized to an identity matrix return @@ -196,23 +196,23 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df, objective += G.flatten().dot(L.T.dot(L).flatten()) return G, objective, total_active, df, a1, a2 - def _select_targets(self, X): + def _select_targets(self, X, label_inds): target_neighbors = np.empty((X.shape[0], self.k), dtype=int) for label in self.labels_: - inds, = np.nonzero(self.label_inds_ == label) + inds, = np.nonzero(label_inds == label) dd = euclidean_distances(X[inds], squared=True) np.fill_diagonal(dd, np.inf) nn = np.argsort(dd)[..., :self.k] target_neighbors[inds] = inds[nn] return target_neighbors - def _find_impostors(self, furthest_neighbors, X): + def _find_impostors(self, furthest_neighbors, X, label_inds): Lx = self.transform(X) margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx) impostors = [] for label in self.labels_[:-1]: - in_inds, = np.nonzero(self.label_inds_ == label) - out_inds, = np.nonzero(self.label_inds_ > label) + in_inds, = np.nonzero(label_inds == label) + out_inds, = np.nonzero(label_inds > label) dist = euclidean_distances(Lx[out_inds], Lx[in_inds], squared=True) i1,j1 = np.nonzero(dist < margin_radii[out_inds][:,None]) i2,j2 = np.nonzero(dist < margin_radii[in_inds]) @@ -265,6 +265,9 @@ class 
LMNN(_base_LMNN): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has ran. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The learned linear transformation ``L``. """ diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 50fcfa3e..1699f359 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -50,32 +50,32 @@ def _fit(self, quadruplets, y=None, weights=None): type_of_inputs='tuples') # check to make sure that no two constrained vectors are identical - self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :] - self.vcd_ = quadruplets[:, 2, :] - quadruplets[:, 3, :] - if self.vab_.shape != self.vcd_.shape: + vab = quadruplets[:, 0, :] - quadruplets[:, 1, :] + vcd = quadruplets[:, 2, :] - quadruplets[:, 3, :] + if vab.shape != vcd.shape: raise ValueError('Constraints must have same length') if weights is None: - self.w_ = np.ones(self.vab_.shape[0]) + self.w_ = np.ones(vab.shape[0]) else: self.w_ = weights self.w_ /= self.w_.sum() # weights must sum to 1 if self.prior is None: X = np.vstack({tuple(row) for row in quadruplets.reshape(-1, quadruplets.shape[2])}) - self.prior_inv_ = np.atleast_2d(np.cov(X, rowvar=False)) - self.M_ = np.linalg.inv(self.prior_inv_) + prior_inv = np.atleast_2d(np.cov(X, rowvar=False)) + M = np.linalg.inv(prior_inv) else: - self.M_ = self.prior - self.prior_inv_ = np.linalg.inv(self.prior) + M = self.prior + prior_inv = np.linalg.inv(self.prior) step_sizes = np.logspace(-10, 0, 10) # Keep track of the best step size and the loss at that step. l_best = 0 - s_best = self._total_loss(self.M_) + s_best = self._total_loss(M, vab, vcd, prior_inv) if self.verbose: print('initial loss', s_best) for it in xrange(1, self.max_iter+1): - grad = self._gradient(self.M_) + grad = self._gradient(M, vab, vcd, prior_inv) grad_norm = scipy.linalg.norm(grad) if grad_norm < self.tol: break @@ -84,10 +84,10 @@ def _fit(self, quadruplets, y=None, weights=None): M_best = None for step_size in step_sizes: step_size /= grad_norm - new_metric = self.M_ - step_size * grad + new_metric = M - step_size * grad w, v = scipy.linalg.eigh(new_metric) new_metric = v.dot((np.maximum(w, 1e-8) * v).T) - cur_s = self._total_loss(new_metric) + cur_s = self._total_loss(new_metric, vab, vcd, prior_inv) if cur_s < s_best: l_best = step_size s_best = cur_s @@ -96,36 +96,36 @@ def _fit(self, quadruplets, y=None, weights=None): print('iter', it, 'cost', s_best, 'best step', l_best * grad_norm) if M_best is None: break - self.M_ = M_best + M = M_best else: if self.verbose: print("Didn't converge after", it, "iterations. 
Final loss:", s_best) self.n_iter_ = it - self.transformer_ = transformer_from_metric(self.M_) + self.transformer_ = transformer_from_metric(M) return self - def _comparison_loss(self, metric): - dab = np.sum(self.vab_.dot(metric) * self.vab_, axis=1) - dcd = np.sum(self.vcd_.dot(metric) * self.vcd_, axis=1) + def _comparison_loss(self, metric, vab, vcd): + dab = np.sum(vab.dot(metric) * vab, axis=1) + dcd = np.sum(vcd.dot(metric) * vcd, axis=1) violations = dab > dcd return self.w_[violations].dot((np.sqrt(dab[violations]) - np.sqrt(dcd[violations]))**2) - def _total_loss(self, metric): + def _total_loss(self, metric, vab, vcd, prior_inv): # Regularization loss sign, logdet = np.linalg.slogdet(metric) - reg_loss = np.sum(metric * self.prior_inv_) - sign * logdet - return self._comparison_loss(metric) + reg_loss + reg_loss = np.sum(metric * prior_inv) - sign * logdet + return self._comparison_loss(metric, vab, vcd) + reg_loss - def _gradient(self, metric): - dMetric = self.prior_inv_ - np.linalg.inv(metric) - dabs = np.sum(self.vab_.dot(metric) * self.vab_, axis=1) - dcds = np.sum(self.vcd_.dot(metric) * self.vcd_, axis=1) + def _gradient(self, metric, vab, vcd, prior_inv): + dMetric = prior_inv - np.linalg.inv(metric) + dabs = np.sum(vab.dot(metric) * vab, axis=1) + dcds = np.sum(vcd.dot(metric) * vcd, axis=1) violations = dabs > dcds # TODO: vectorize - for vab, dab, vcd, dcd in zip(self.vab_[violations], dabs[violations], - self.vcd_[violations], dcds[violations]): + for vab, dab, vcd, dcd in zip(vab[violations], dabs[violations], + vcd[violations], dcds[violations]): dMetric += ((1-np.sqrt(dcd/dab))*np.outer(vab, vab) + (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd)) return dMetric @@ -136,6 +136,9 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has ran. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) @@ -169,6 +172,9 @@ class LSML_Supervised(_BaseLSML, TransformerMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has ran. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py index 6b79638e..84d3b76a 100644 --- a/metric_learn/mlkr.py +++ b/metric_learn/mlkr.py @@ -30,6 +30,9 @@ class MLKR(MahalanobisMixin, TransformerMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has ran. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The learned linear transformation ``L``. """ diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index b806a97e..1c874a89 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -353,6 +353,9 @@ class MMC(_BaseMMC, _PairsClassifierMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has ran. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) @@ -384,6 +387,9 @@ class MMC_Supervised(_BaseMMC, TransformerMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has ran. 
+ transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 81045287..08888d18 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -24,6 +24,9 @@ class NCA(MahalanobisMixin, TransformerMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has ran. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The learned linear transformation ``L``. """ diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 40fd5727..82f715ac 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -58,18 +58,18 @@ def _fit(self, pairs, y): # set up prior M if self.use_cov: X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) - self.M_ = pinvh(np.cov(X, rowvar = False)) + M = pinvh(np.cov(X, rowvar = False)) else: - self.M_ = np.identity(pairs.shape[2]) + M = np.identity(pairs.shape[2]) diff = pairs[:, 0] - pairs[:, 1] loss_matrix = (diff.T * y).dot(diff) - P = self.M_ + self.balance_param * loss_matrix + P = M + self.balance_param * loss_matrix emp_cov = pinvh(P) # hack: ensure positive semidefinite emp_cov = emp_cov.T.dot(emp_cov) - _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose) + _, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose) - self.transformer_ = transformer_from_metric(self.M_) + self.transformer_ = transformer_from_metric(M) return self From 2d276c89d879c081dde8fe30262b688fa78fa598 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 24 Jan 2019 10:55:51 +0100 Subject: [PATCH 2/5] Address review https://github.com/metric-learn/metric-learn/pull/159#pullrequestreview-195570695 --- metric_learn/itml.py | 12 ++++++------ metric_learn/lmnn.py | 2 +- metric_learn/lsml.py | 4 ++-- metric_learn/mlkr.py | 2 +- metric_learn/mmc.py | 4 ++-- metric_learn/nca.py | 2 +- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 6e643882..b4ea3079 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -133,16 +133,16 @@ class ITML(_BaseITML, _PairsClassifierMixin): Attributes ---------- - bounds_ : `list` of two numbers + bounds_ : array-like, shape=(2,) Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and ``d``, with ``bounds=[ pos, neg]``, and ``d`` the learned distance. If - not provided at initialization, these are the ones derived at train + not provided at initialization, these are derived at train time. n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis @@ -181,16 +181,16 @@ class ITML_Supervised(_BaseITML, TransformerMixin): Attributes ---------- - bounds_ : `list` of two numbers + bounds_ : array-like, shape=(2,) Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and ``d``, with ``bounds=[pos, neg]``, and ``d`` the learned distance. 
If - not provided at initialization, these are the ones derived at train + not provided at initialization, these are derived at train time. n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index c218ac48..f9cd0e91 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -266,7 +266,7 @@ class LMNN(_base_LMNN): Attributes ---------- n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The learned linear transformation ``L``. diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 1699f359..312990ab 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -137,7 +137,7 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin): Attributes ---------- n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis @@ -173,7 +173,7 @@ class LSML_Supervised(_BaseLSML, TransformerMixin): Attributes ---------- n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py index 84d3b76a..74a21a82 100644 --- a/metric_learn/mlkr.py +++ b/metric_learn/mlkr.py @@ -31,7 +31,7 @@ class MLKR(MahalanobisMixin, TransformerMixin): Attributes ---------- n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The learned linear transformation ``L``. diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index 1c874a89..f9d3690b 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -354,7 +354,7 @@ class MMC(_BaseMMC, _PairsClassifierMixin): Attributes ---------- n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis @@ -388,7 +388,7 @@ class MMC_Supervised(_BaseMMC, TransformerMixin): Attributes ---------- n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 08888d18..5abe52e3 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -25,7 +25,7 @@ class NCA(MahalanobisMixin, TransformerMixin): Attributes ---------- n_iter_ : `int` - The number of iterations the solver has ran. + The number of iterations the solver has run. transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The learned linear transformation ``L``. 
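[Editor's note] The two patches above remove redundant stored state (`A_`, `M_`, `vab_`, `vcd_`, `prior_inv_`, `label_inds_`): the public `transformer_` attribute already carries everything a user needs, since the learned metric matrix is recoverable as M = LᵀL (the contract of `transformer_from_metric`). A minimal sketch illustrating why nothing is lost, assuming the post-patch API; the toy data and variable names here are made up for illustration:

```python
import numpy as np
from metric_learn import ITML_Supervised

# Toy data: two 3-d Gaussian blobs (illustrative only).
rng = np.random.RandomState(42)
X = np.vstack([rng.randn(20, 3), rng.randn(20, 3) + 1.])
y = np.repeat([0, 1], 20)

itml = ITML_Supervised(num_constraints=100)
itml.fit(X, y)

# The metric matrix is recoverable from transformer_ (M = L^T L),
# which is why also storing A_ / M_ on the estimator was redundant.
L = itml.transformer_
M = L.T.dot(L)

# Mahalanobis distance computed from M ...
a, b = X[0], X[1]
d_M = np.sqrt((a - b).dot(M).dot(a - b))
# ... matches the Euclidean distance after applying L.
d_L = np.linalg.norm(L.dot(a) - L.dot(b))
assert np.isclose(d_M, d_L)
```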
From 41227a60d43152bc85bc4cb505334eebac78a3bf Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 29 Jan 2019 11:46:18 +0100 Subject: [PATCH 3/5] DOC: add more precise docstring --- metric_learn/itml.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index b4ea3079..fb7615a9 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -138,8 +138,9 @@ class ITML(_BaseITML, _PairsClassifierMixin): all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and ``d``, with ``bounds=[ pos, neg]``, and ``d`` the learned distance. If - not provided at initialization, these are derived at train - time. + not provided at initialization, bounds_[0] and bounds_[1] are set at + train time to the 5th and 95th percentile of the pairwise distances among + all points present in the input `pairs`. n_iter_ : `int` The number of iterations the solver has run. @@ -163,10 +164,12 @@ def fit(self, pairs, y, bounds=None): Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. bounds : `list` of two numbers Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for - all given pairs of similar points ``a`` and ``b``, and - ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and - ``d``, with ``bounds=[pos, neg]``, and ``d`` the learned distance. - If not provided at initialization, these will be derived at train time. + all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > + neg`` for all given pairs of dissimilar points ``c`` and ``d``, with + ``bounds=[pos, neg]``, and ``d`` the learned distance. If not provided + at initialization, bounds_[0] and bounds_[1] will be set to the 5th and + 95th percentile of the pairwise distances among all points present in + the input `pairs`. Returns ------- @@ -185,9 +188,10 @@ class ITML_Supervised(_BaseITML, TransformerMixin): Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and - ``d``, with ``bounds=[pos, neg]``, and ``d`` the learned distance. If - not provided at initialization, these are derived at train - time. + ``d``, with ``bounds=[pos, neg]``, and ``d`` the learned distance. If not + provided at initialization, bounds_[0] and bounds_[1] are set at train + time to the 5th and 95th percentile of the pairwise distances among all + points in the training data `X`. n_iter_ : `int` The number of iterations the solver has run. @@ -223,7 +227,9 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > neg`` for all given pairs of dissimilar points ``c`` and ``d``, with ``bounds=[pos, neg]``, and ``d`` the learned distance. If not provided at - initialization, these will be derived at train time. + initialization, bounds_[0] and bounds_[1] will be set to the 5th and 95th + percentile of the pairwise distances among all points in the training + data `X`. 
A0 : (d x d) matrix, optional initial regularization matrix, defaults to identity verbose : bool, optional From b93b4cd44ab402b7a4ba620fe4524e9fdece2a6b Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 29 Jan 2019 12:06:16 +0100 Subject: [PATCH 4/5] API: put parameter in fit, deprecate it in init, and also change previous deprecation tests names --- metric_learn/itml.py | 36 ++++++++++++++++++++++++------------ test/metric_learn_test.py | 37 +++++++++++++++++++++++++++---------- test/test_base_metric.py | 6 +++--- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index fb7615a9..a9019e4e 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -202,8 +202,8 @@ class ITML_Supervised(_BaseITML, TransformerMixin): """ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, - num_labeled='deprecated', num_constraints=None, bounds=None, - A0=None, verbose=False, preprocessor=None): + num_labeled='deprecated', num_constraints=None, + bounds='deprecated', A0=None, verbose=False, preprocessor=None): """Initialize the supervised version of `ITML`. `ITML_Supervised` creates pairs of similar sample by taking same class @@ -222,14 +222,11 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, be removed in 0.6.0. num_constraints: int, optional number of constraints to generate - bounds : `list` of two numbers - Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for - all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > neg`` - for all given pairs of dissimilar points ``c`` and ``d``, with - ``bounds=[pos, neg]``, and ``d`` the learned distance. If not provided at - initialization, bounds_[0] and bounds_[1] will be set to the 5th and 95th - percentile of the pairwise distances among all points in the training - data `X`. + bounds : Not used + .. deprecated:: 0.5.0 + `bounds` was deprecated in version 0.5.0 and will + be removed in 0.6.0. Set `bounds` at fit time instead : + `itml_supervised.fit(X, y, bounds=...)` A0 : (d x d) matrix, optional initial regularization matrix, defaults to identity verbose : bool, optional @@ -245,7 +242,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, self.num_constraints = num_constraints self.bounds = bounds - def fit(self, X, y, random_state=np.random): + def fit(self, X, y, random_state=np.random, bounds=None): """Create constraints from labels and learn the ITML model. @@ -259,11 +256,26 @@ def fit(self, X, y, random_state=np.random): random_state : numpy.random.RandomState, optional If provided, controls random number generation. + + bounds : `list` of two numbers + Bounds on similarity, aside slack variables, s.t. ``d(a, b) < pos`` for + all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > neg`` + for all given pairs of dissimilar points ``c`` and ``d``, with + ``bounds=[pos, neg]``, and ``d`` the learned distance. If not provided at + initialization, bounds_[0] and bounds_[1] will be set to the 5th and 95th + percentile of the pairwise distances among all points in the training + data `X`. """ + # TODO: remove these in v0.6.0 if self.num_labeled != 'deprecated': warnings.warn('"num_labeled" parameter is not used.' ' It has been deprecated in version 0.5.0 and will be' 'removed in 0.6.0', DeprecationWarning) + if self.bounds != 'deprecated': + warnings.warn('"bounds" parameter from initialization is not used.' 
+ ' It has been deprecated in version 0.5.0 and will be' + 'removed in 0.6.0. Use the "bounds" parameter of this ' + 'fit method instead.', DeprecationWarning) X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: @@ -274,4 +286,4 @@ def fit(self, X, y, random_state=np.random): pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) pairs, y = wrap_pairs(X, pos_neg) - return _BaseITML._fit(self, pairs, y, bounds=self.bounds) + return _BaseITML._fit(self, pairs, y, bounds=bounds) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index eebce1f9..b4b8c925 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -58,8 +58,9 @@ def test_iris(self): self.assertLess(csep, 0.8) # it's pretty terrible def test_deprecation(self): - # test that the right deprecation message is thrown. - # TODO: remove in v.0.5 + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) y = np.array([1, 0, 1, 0]) lsml_supervised = LSML_Supervised(num_labeled=np.inf) @@ -78,8 +79,9 @@ def test_iris(self): self.assertLess(csep, 0.2) def test_deprecation(self): - # test that the right deprecation message is thrown. - # TODO: remove in v.0.5 + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) y = np.array([1, 0, 1, 0]) itml_supervised = ITML_Supervised(num_labeled=np.inf) @@ -88,6 +90,19 @@ def test_deprecation(self): 'removed in 0.6.0') assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y) + def test_deprecation_bounds(self): + # test that a deprecation message is thrown if bounds is set at + # initialization + # TODO: remove in v.0.6 + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + itml_supervised = ITML_Supervised(bounds=None) + msg = ('"bounds" parameter from initialization is not used.' + ' It has been deprecated in version 0.5.0 and will be' + 'removed in 0.6.0. Use the "bounds" parameter of this ' + 'fit method instead.') + assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y) + class TestLMNN(MetricTestCase): def test_iris(self): @@ -143,9 +158,10 @@ def test_iris(self): csep = class_separation(sdml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) - def test_deprecation(self): - # test that the right deprecation message is thrown. - # TODO: remove in v.0.5 + def test_deprecation_num_labeled(self): + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) y = np.array([1, 0, 1, 0]) sdml_supervised = SDML_Supervised(num_labeled=np.inf) @@ -368,9 +384,10 @@ def test_iris(self): csep = class_separation(mmc.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.2) - def test_deprecation(self): - # test that the right deprecation message is thrown. 
- # TODO: remove in v.0.5 + def test_deprecation_num_labeled(self): + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) y = np.array([1, 0, 1, 0]) mmc_supervised = MMC_Supervised(num_labeled=np.inf) diff --git a/test/test_base_metric.py b/test/test_base_metric.py index fdea2949..58938985 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -31,9 +31,9 @@ def test_itml(self): preprocessor=None, verbose=False) """.strip('\n')) self.assertEqual(str(metric_learn.ITML_Supervised()), """ -ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0, - max_iter=1000, num_constraints=None, num_labeled='deprecated', - preprocessor=None, verbose=False) +ITML_Supervised(A0=None, bounds='deprecated', convergence_threshold=0.001, + gamma=1.0, max_iter=1000, num_constraints=None, + num_labeled='deprecated', preprocessor=None, verbose=False) """.strip('\n')) def test_lsml(self): From a2ed22b925913742be5413cf1bbc350166daea51 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 29 Jan 2019 15:39:20 +0100 Subject: [PATCH 5/5] Change remaining test names --- test/metric_learn_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index edb83008..e1eace90 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -57,7 +57,7 @@ def test_iris(self): csep = class_separation(lsml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.8) # it's pretty terrible - def test_deprecation(self): + def test_deprecation_num_labeled(self): # test that a deprecation message is thrown if num_labeled is set at # initialization # TODO: remove in v.0.6 @@ -78,7 +78,7 @@ def test_iris(self): csep = class_separation(itml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.2) - def test_deprecation(self): + def test_deprecation_num_labeled(self): # test that a deprecation message is thrown if num_labeled is set at # initialization # TODO: remove in v.0.6
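[Editor's note] Taken together, patches 3-5 move `bounds` from `__init__` to `fit` and document its percentile-based defaults. A minimal usage sketch of the resulting API, assuming these patches are applied; the data mirrors the deprecation tests above and the bound values are illustrative:

```python
import numpy as np
from metric_learn import ITML_Supervised

X = np.array([[0., 0.], [0., 1.], [2., 0.], [2., 1.]])
y = np.array([1, 0, 1, 0])

# New API: pass bounds to fit; explicit values override the
# 5th/95th-percentile defaults derived from the training data.
itml = ITML_Supervised()
itml.fit(X, y, bounds=[0.1, 2.0])
print(itml.bounds_)  # the bounds actually used, stored at fit time

# Deprecated: bounds given at initialization is ignored and only
# triggers a DeprecationWarning at fit time.
itml_old = ITML_Supervised(bounds=[0.1, 2.0])
itml_old.fit(X, y)
```

Routing `bounds` through `fit` rather than `__init__` follows the scikit-learn convention that constructor arguments must be dataset-independent, which is what makes estimators clonable in grid searches.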