Skip to content

Commit 675dfcf

Browse files
author
mvargas33
committed
First draft of refactoring BaseMetricLearner and Mahalanobis Learner
1 parent 319bc5d commit 675dfcf

File tree

1 file changed

+127
-2
lines changed

1 file changed

+127
-2
lines changed

metric_learn/base_metric.py

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
from abc import ABCMeta, abstractmethod
1111
from ._util import ArrayIndexer, check_input, validate_vector
12+
import warnings
1213

1314

1415
class BaseMetricLearner(BaseEstimator, metaclass=ABCMeta):
@@ -27,7 +28,8 @@ def __init__(self, preprocessor=None):
2728

2829
@abstractmethod
2930
def score_pairs(self, pairs):
30-
"""Returns the score between pairs
31+
"""Deprecated.
32+
Returns the score between pairs
3133
(can be a similarity, or a distance/metric depending on the algorithm)
3234
3335
Parameters
@@ -49,6 +51,57 @@ def score_pairs(self, pairs):
4951
learner is.
5052
"""
5153

54+
@abstractmethod
def pair_similarity(self, pairs):
    """Returns the similarity score between pairs.

    Depending on the algorithm, this is either the similarity score that was
    learned directly, or a score derived from the learned distance (for
    instance `MahalanobisMixin` returns the negated distance).

    Parameters
    ----------
    pairs : `numpy.ndarray`, shape=(n_pairs, 2, n_features)
      3D array of pairs.

    Returns
    -------
    scores : `numpy.ndarray` of shape=(n_pairs,)
      The similarity score of every pair.

    See Also
    --------
    get_metric : a method that returns a function to compute the metric between
      two points. The difference with `pair_similarity` is that it works on two
      1D arrays and cannot use a preprocessor. Besides, the returned function
      is independent of the metric learner and hence is not modified if the
      metric learner is.
    """
78+
79+
@abstractmethod
def pair_distance(self, pairs):
    """Returns the distance score between pairs.

    Depending on the algorithm, this is either the distance (or
    pseudo-distance) that was learned directly, or a score obtained by
    inverting the learned similarity between the two points of each pair.

    Parameters
    ----------
    pairs : `numpy.ndarray`, shape=(n_pairs, 2, n_features)
      3D array of pairs.

    Returns
    -------
    scores : `numpy.ndarray` of shape=(n_pairs,)
      The distance score of every pair.

    See Also
    --------
    get_metric : a method that returns a function to compute the metric between
      two points. The difference with `pair_distance` is that it works on two
      1D arrays and cannot use a preprocessor. Besides, the returned function
      is independent of the metric learner and hence is not modified if the
      metric learner is.
    """
104+
52105
def _check_preprocessor(self):
53106
"""Initializes the preprocessor"""
54107
if _is_arraylike(self.preprocessor):
@@ -182,7 +235,79 @@ class MahalanobisMixin(BaseMetricLearner, MetricTransformer,
182235
"""
183236

184237
def score_pairs(self, pairs):
    r"""Deprecated: use `pair_distance` or `pair_similarity` instead.

    Returns the learned Mahalanobis distance between pairs. This method only
    emits a ``FutureWarning`` and delegates to `pair_distance`.

    This distance is defined as: :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}`
    where ``M`` is the learned Mahalanobis matrix, for every pair of points
    ``x`` and ``x'``. This corresponds to the euclidean distance between
    embeddings of the points in a new space, obtained through a linear
    transformation. Indeed, we have also: :math:`d_M(x, x') = \sqrt{(x_e -
    x_e')^T (x_e- x_e')}`, with :math:`x_e = L x` (See
    :class:`MahalanobisMixin`).

    Parameters
    ----------
    pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
      3D Array of pairs to score, with each row corresponding to two points,
      or 2D array of indices of pairs if the metric learner uses a
      preprocessor.

    Returns
    -------
    scores : `numpy.ndarray` of shape=(n_pairs,)
      The learned Mahalanobis distance for every pair.

    Warns
    -----
    FutureWarning
      Always, to announce the upcoming deprecation of this method.

    See Also
    --------
    get_metric : a method that returns a function to compute the metric between
      two points. The difference with `score_pairs` is that it works on two 1D
      arrays and cannot use a preprocessor. Besides, the returned function is
      independent of the metric learner and hence is not modified if the metric
      learner is.

    :ref:`mahalanobis_distances` : The section of the project documentation
      that describes Mahalanobis Distances.
    """
    # The replacement method is `pair_distance` (singular); the message must
    # name it exactly so users can find it.
    dpr_msg = ("score_pairs will be deprecated in the next release. "
               "Use pair_similarity to compute similarities, or "
               "pair_distance to compute distances.")
    # stacklevel=2 attributes the warning to the caller's line rather than to
    # this library function.
    warnings.warn(dpr_msg, category=FutureWarning, stacklevel=2)
    return self.pair_distance(pairs)
278+
279+
def pair_similarity(self, pairs):
    """
    Returns the negated learned Mahalanobis distance between pairs.

    The score is ``-1 * pair_distance(pairs)``, so the closer the two points
    of a pair are under the learned metric, the higher their score.

    Parameters
    ----------
    pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
      3D Array of pairs to score, with each row corresponding to two points,
      or 2D array of indices of pairs if the metric learner uses a
      preprocessor.

    Returns
    -------
    scores : `numpy.ndarray` of shape=(n_pairs,)
      The negated learned Mahalanobis distance for every pair.

    See Also
    --------
    get_metric : a method that returns a function to compute the metric between
      two points. The difference with `pair_similarity` is that it works on two
      1D arrays and cannot use a preprocessor. Besides, the returned function
      is independent of the metric learner and hence is not modified if the
      metric learner is.

    :ref:`mahalanobis_distances` : The section of the project documentation
      that describes Mahalanobis Distances.
    """
    distances = self.pair_distance(pairs)
    return -1 * distances
307+
308+
def pair_distance(self, pairs):
309+
"""
310+
Returns the learned Mahalanobis distance between pairs.
186311
187312
This distance is defined as: :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}`
188313
where ``M`` is the learned Mahalanobis matrix, for every pair of points

0 commit comments

Comments
 (0)