[MRG] Learning on Triplets #279
Merged: perimosocordiae merged 19 commits into scikit-learn-contrib:master from grudloff:add_triplets on Mar 4, 2020
Commits
965cc5c add _TripletsClassifierMixin (grudloff)
ecd22f6 added doc (grudloff)
f0df4ad remove redundant code (grudloff)
b8054f3 added tests (grudloff)
de7aa11 triplets added to doc autosumary (grudloff)
cb64e2f rephrasing, added docstring and small changes (grudloff)
102e120 small rephrasing (grudloff)
3b421ab small flake8 fix (grudloff)
3e164e4 Handle low number of neighbors for knn triplets (grudloff)
8e45abe add tests for knn triplet generation (grudloff)
aaeb458 fixed typos and rephrasing (grudloff)
43cfd6c added more tests for knn triplet construction (grudloff)
0d5134a sorted triplet & fix test_generate_knntriplets_k (grudloff)
758bf14 added over the edge knn triplets test (grudloff)
f41fea1 multiple small code refactoring (grudloff)
14cf03d more refactoring (grudloff)
088b59a Fix & test unlabeled handling triplet generation (grudloff)
59d253a closer unlabeled point (grudloff)
fb30673 small clarity enhancement & repmat replacement (grudloff)
@@ -592,14 +592,122 @@ points, while constrains the sum of distances between dissimilar points: | |
-with-side-information.pdf>`_. NIPS 2002 | ||
.. [2] Adapted from Matlab code http://www.cs.cmu.edu/%7Eepxing/papers/Old_papers/code_Metric_online.tar.gz | ||
|
||
.. _learning_on_triplets: | ||
|
||
Learning on triplets | ||
==================== | ||
|
||
Some metric learning algorithms learn on triplets of samples. In this case, | ||
one should provide the algorithm with `n_samples` triplets of points. The | ||
semantics of each triplet is that the first point should be closer to the | ||
second point than to the third one. | ||
|
||
Fitting | ||
------- | ||
Here is an example for fitting on triplets (see :ref:`fit_ws` for more | ||
details on the input data format and how to fit, in the general case of | ||
learning on tuples). | ||
|
||
>>> import numpy as np | ||
>>> from metric_learn import SCML | ||
>>> triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]], | ||
...                      [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]]]) | ||
>>> scml = SCML(random_state=42) | ||
>>> scml.fit(triplets) | ||
SCML(beta=1e-5, B=None, max_iter=100000, verbose=False, | ||
preprocessor=None, random_state=None) | ||
Review comment: This will have to be updated with the final API for SCML in #278
|
||
Or alternatively (using a preprocessor): | ||
|
||
>>> X = np.array([[1.2, 3.2], | ||
...               [2.3, 5.5], | ||
...               [2.1, 0.6], | ||
...               [4.5, 2.3], | ||
...               [2.1, 2.3], | ||
...               [7.3, 3.4]]) | ||
>>> triplets_indices = np.array([[0, 1, 2], [3, 4, 5]]) | ||
>>> scml = SCML(preprocessor=X, random_state=42) | ||
>>> scml.fit(triplets_indices) | ||
SCML(beta=1e-5, B=None, max_iter=100000, verbose=False, | ||
preprocessor=array([[1.2, 3.2], | ||
[2.3, 5.5], | ||
[2.1, 0.6], | ||
[4.5, 2.3], | ||
[2.1, 2.3], | ||
[7.3, 3.4]]), | ||
random_state=None) | ||
|
||
|
||
Here, we want to learn a metric that, for each of the two | ||
`triplets`, will make the first point closer to the | ||
second point than to the third one. | ||
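
As an illustration of how the two input formats relate, the following NumPy-only sketch indexes the point array with the index triplets and recovers the explicit triplets from the first example. It only shows what the indices refer to; it is not a description of how metric-learn implements its preprocessor internally.

```python
import numpy as np

# Same points and index triplets as in the preprocessor example.
X = np.array([[1.2, 3.2],
              [2.3, 5.5],
              [2.1, 0.6],
              [4.5, 2.3],
              [2.1, 2.3],
              [7.3, 3.4]])
triplets_indices = np.array([[0, 1, 2], [3, 4, 5]])

# Fancy indexing maps each index to its point, giving an array of
# shape (n_triplets, 3, n_features).
triplets_from_indices = X[triplets_indices]

# The explicit triplets from the first fitting example.
triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]],
                     [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]]])
same = np.array_equal(triplets_from_indices, triplets)
```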
|
||
.. _triplets_predicting: | ||
|
||
Prediction | ||
---------- | ||
|
||
Once a triplets learner is fitted, it can also predict, for a new triplet, | ||
whether the first point is closer to the second point | ||
than to the third one (+1), or not (-1). | ||
|
||
>>> triplets_test = np.array( | ||
... [[[5.6, 5.3], [2.2, 2.1], [1.2, 3.4]], | ||
... [[6.0, 4.2], [4.3, 1.2], [0.1, 7.8]]]) | ||
>>> scml.predict(triplets_test) | ||
array([-1., 1.]) | ||
|
||
.. _triplets_scoring: | ||
|
||
Scoring | ||
------- | ||
|
||
Triplet metric learners can also return a `decision_function` for a set of triplets, | ||
which corresponds to the distance between the first and last points minus the distance | ||
between the first two points of the triplet (the higher the value, the more similar | ||
the first point is to the second point compared to the last one). This "score" | ||
can be interpreted as a measure of how likely a +1 prediction is for this | ||
triplet. | ||
|
||
>>> scml.decision_function(triplets_test) | ||
array([-1.75700306, 4.98982131]) | ||
|
||
In the above example, for the first triplet in `triplets_test`, the first | ||
point is predicted less similar to the second point than to the last point | ||
(they are further away in the transformed space). | ||
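
To make the sign convention concrete, here is a minimal sketch that computes such a score with the plain Euclidean distance instead of a learned metric. The helper `euclidean_triplet_scores` and the toy data are hypothetical and not part of metric-learn; the score is taken as the distance from the first point to the last point minus the distance from the first point to the second, so that a higher value means the first point is closer to the second.

```python
import numpy as np

# Hypothetical toy triplets: each row is (first, second, third) point.
toy_triplets = np.array([[[0., 0.], [1., 0.], [3., 0.]],
                         [[0., 0.], [4., 0.], [1., 0.]]])


def euclidean_triplet_scores(triplets):
    """Illustrative score: d(first, third) - d(first, second), Euclidean."""
    d_first_second = np.linalg.norm(triplets[:, 0] - triplets[:, 1], axis=1)
    d_first_third = np.linalg.norm(triplets[:, 0] - triplets[:, 2], axis=1)
    return d_first_third - d_first_second


scores = euclidean_triplet_scores(toy_triplets)
# The sign of the score gives the +1 / -1 prediction
# (np.sign returns 0 for an exact tie, which this sketch does not handle).
predictions = np.sign(scores)
```

With these toy points the first triplet gets a positive score (the first point is closer to the second) and the second triplet a negative one.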
|
||
Unlike pairs learners, triplets learners do not accept a `y` when fitting: we | ||
assume that the ordering of points within triplets is such that the training triplets | ||
are all positive. Therefore, it is not possible to use scikit-learn scoring functions | ||
(such as 'f1_score') for triplets learners. | ||
|
||
However, triplets learners do have a default scoring function, which | ||
returns the accuracy score on a given test set, i.e. the proportion | ||
of triplets that have the right predicted ordering. | ||
|
||
>>> scml.score(triplets_test) | ||
0.5 | ||
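
The accuracy can be reproduced by hand from the decision values shown earlier. This is a sketch assuming every test triplet is expected to be positive, so the accuracy is the fraction of strictly positive scores; the helper `triplet_accuracy` is hypothetical and not a metric-learn function.

```python
import numpy as np


def triplet_accuracy(scores):
    """Fraction of triplets whose score is positive (predicted +1)."""
    return float(np.mean(np.asarray(scores) > 0))


# Decision values copied from the `decision_function` example.
accuracy = triplet_accuracy([-1.75700306, 4.98982131])
```

One of the two scores is positive, which matches the 0.5 returned by `score`.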
|
||
.. note:: | ||
See :ref:`fit_ws` for more details on metric learners functions that are | ||
not specific to learning on triplets, like `transform`, `score_pairs`, | ||
`get_metric` and `get_mahalanobis_matrix`. | ||
|
||
|
||
|
||
|
||
Algorithms | ||
---------- | ||
|
||
|
||
.. _learning_on_quadruplets: | ||
|
||
Learning on quadruplets | ||
======================= | ||
|
||
Some metric learning algorithms learn on quadruplets of samples. In this case, | ||
one should provide the algorithm with `n_samples` quadruplets of points. Th | ||
one should provide the algorithm with `n_samples` quadruplets of points. The | ||
semantics of each quadruplet is that the first two points should be closer | ||
together than the last two points. | ||
|
||
|
@@ -666,14 +774,12 @@ array([-1., 1.]) | |
Scoring | ||
------- | ||
|
||
Quadruplet metric learners can also | ||
return a `decision_function` for a set of pairs. This is basically the "score" | ||
which sign will be taken to find the prediction for the pair, which | ||
corresponds to the difference between the distance between the two last points, | ||
and the distance between the two last points of the quadruplet (higher | ||
score means the two last points are more likely to be more dissimilar than | ||
the two first points (i.e. more likely to have a +1 prediction since it's | ||
the right ordering)). | ||
Quadruplet metric learners can also return a `decision_function` for a set of | ||
quadruplets, which corresponds to the distance between the last pair of points minus | ||
the distance between the first pair of points of the quadruplet (the higher the value, | ||
the more similar the first pair is compared to the last pair). | ||
This "score" can be interpreted as a measure of how likely a +1 prediction is | ||
for this quadruplet. | ||
|
||
>>> lsml.decision_function(quadruplets_test) | ||
array([-1.75700306, 4.98982131]) | ||
|
@@ -682,17 +788,10 @@ In the above example, for the first quadruplet in `quadruplets_test`, the | |
two first points are predicted less similar than the two last points (they | ||
are further away in the transformed space). | ||
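
The same idea can be sketched for quadruplets with the plain Euclidean distance in place of a learned metric. The helper `euclidean_quadruplet_scores` and the toy data are hypothetical; the score is taken as the distance between the last two points minus the distance between the first two, so that a positive value corresponds to the expected ordering.

```python
import numpy as np

# Hypothetical toy quadruplets: each row holds four points; the first two
# are expected to be closer together than the last two.
toy_quadruplets = np.array(
    [[[0., 0.], [1., 0.], [0., 0.], [3., 0.]],
     [[0., 0.], [4., 0.], [2., 0.], [3., 0.]]])


def euclidean_quadruplet_scores(quadruplets):
    """Illustrative score: d(third, fourth) - d(first, second), Euclidean."""
    d_first_pair = np.linalg.norm(quadruplets[:, 0] - quadruplets[:, 1], axis=1)
    d_last_pair = np.linalg.norm(quadruplets[:, 2] - quadruplets[:, 3], axis=1)
    return d_last_pair - d_first_pair


quad_scores = euclidean_quadruplet_scores(toy_quadruplets)
quad_predictions = np.sign(quad_scores)  # +1 when the ordering is respected
```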
|
||
Unlike for pairs learners, quadruplets learners don't allow to give a `y` | ||
when fitting, which does not allow to use scikit-learn scoring functions | ||
like: | ||
|
||
>>> from sklearn.model_selection import cross_val_score | ||
>>> cross_val_score(lsml, quadruplets, scoring='f1_score') # this won't work | ||
|
||
(This is actually intentional, for more details | ||
about that, see | ||
`this comment <https://github.com/scikit-learn-contrib/metric-learn/pull/168#pullrequestreview-203730742>`_ | ||
on github.) | ||
Like triplet learners, quadruplets learners do not accept a `y` when fitting: we | ||
assume that the ordering of points within quadruplets is such that the training quadruplets | ||
are all positive. Therefore, it is not possible to use scikit-learn scoring functions | ||
(such as 'f1_score') for quadruplets learners. | ||
|
||
However, quadruplets learners do have a default scoring function, which will | ||
basically return the accuracy score on a given test set, i.e. the proportion | ||
|