Commit 43a60c9

SCML: Sparse Compositional Metric Learning (#278)

* scml first commit
* add scml to __init__.py
* fix in components calculation
* remove triplet generator, added in triplets PR
* change init & fit interface, faster compute & others
* added comments & docstrings, small code changes
* typos and added choice of gamma & output_iter
* some small improvements
* lda tail handling rollback
* performance improvement by precomputing rand_ints
* small fix in components computation
* flake8 fix
* SCML_global fit fix & other small changes
* proper use of init vars and unsupervised bases generation
* triplet dataset format & remove_y for triplets
* adaptation with dataset format
* remove labels for triplets and quadruplets
* remove labels
* remove labels & old fit random_state assignment
* compliant with older numpy versions
* small typo and fix order
* fix n_basis check
* initialize_basis_supervised and some refactoring
* proper n_basis handling
* scml specific tests
* remove small mistake
* test user input basis
* changed names and messages and some refactoring
* triplets in features form passed to _fit
* change indices handling and edge case fix
* name change and typos
* improve test_components_is_2D
* replace triplet_diffs option by better approach
* some comments, docstrings and refactoring
* fix bad triplet set
* flake8 fix
* SCML doc first draft
* find neighbors for every class only once
* improve some docstrings and warnings
* add sklearn compat test
* changes to doc
* fix and improve tests
* use components_from_metric
* change TestSCML to object and parametrize tests
* fix test_iris
* use model._authorized_basis and other fixes
* verbose test
* revert sum_where
* small n_basis warning instead of error
* add test iris on triplet_diffs
* test lda & triplet_diffs
* improved messages
* remove quadruplets and triplets from pipeline test
* test big n_features
* correct output iters
* output_iter on supervised and improved verbose
* flake8 fix
* bases generation test comments
* change big_n_basis_lda error msg
* test generated n_basis and basis shape
* add mini batch optimization
* correct iter convention
* eliminate n_samples = 1000
* batch grad refactored
* adagrad adaptive learning
* int input checks and tests
* flake8 fix
* no double division and smaller triplets arrays
* minor grammar fixes
* minor formatting tweaks

Co-authored-by: CJ Carey <[email protected]>
1 parent c15f1c3 commit 43a60c9

10 files changed: +1080 additions, −174 deletions
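The commit list above mentions "mini batch optimization" and "adagrad adaptive learning". As a rough illustration only (this is not the repository's code; `adagrad_step` and its default hyperparameters are invented here), one AdaGrad update combined with the nonnegativity projection that SCML's weights require might look like:

```python
import numpy as np

def adagrad_step(w, grad, accum, lr=0.01, eps=1e-8):
    """One AdaGrad update: each coordinate's effective step size shrinks
    as its squared gradients accumulate in `accum`."""
    accum = accum + grad ** 2
    w_new = w - lr * grad / (np.sqrt(accum) + eps)
    # SCML constrains w >= 0, so project back onto the nonnegative orthant.
    return np.maximum(w_new, 0.0), accum

w = np.array([0.5, 0.0, 0.2])
accum = np.zeros_like(w)
w, accum = adagrad_step(w, np.array([1.0, -2.0, 0.5]), accum)
```

Coordinates with large accumulated gradients take smaller steps, which is useful when only a few of the many basis weights receive frequent updates.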

doc/metric_learn.rst
Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,7 @@ Supervised Learning Algorithms
   metric_learn.MMC_Supervised
   metric_learn.SDML_Supervised
   metric_learn.RCA_Supervised
+  metric_learn.SCML_Supervised

 Weakly Supervised Learning Algorithms
 -------------------------------------
@@ -45,6 +46,7 @@ Weakly Supervised Learning Algorithms
   metric_learn.LSML
   metric_learn.MMC
   metric_learn.SDML
+  metric_learn.SCML

 Unsupervised Learning Algorithms
 --------------------------------

doc/weakly_supervised.rst
Lines changed: 65 additions & 8 deletions

@@ -700,6 +700,63 @@ of triplets that have the right predicted ordering.
 Algorithms
 ----------

+.. _scml:
+
+:py:class:`SCML <metric_learn.SCML>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Sparse Compositional Metric Learning
+(:py:class:`SCML <metric_learn.SCML>`)
+
+`SCML` learns a squared Mahalanobis distance from triplet constraints by
+optimizing sparse positive weights assigned to a set of :math:`K` rank-one
+PSD bases. This can be formulated as an optimization problem with only
+:math:`K` parameters, which can be solved with an efficient stochastic
+composite scheme.
+
+The Mahalanobis matrix :math:`M` is built from a basis set
+:math:`B = \{b_i\}_{i=\{1,...,K\}}` weighted by a :math:`K`-dimensional
+vector :math:`w = \{w_i\}_{i=\{1,...,K\}}` as:
+
+.. math::
+
+    M = \sum_{i=1}^K w_i b_i b_i^T = B \cdot diag(w) \cdot B^T \quad w_i \geq 0
+
+Learning :math:`M` in this form makes it PSD by design, since it is a
+nonnegative sum of PSD matrices. The basis set :math:`B` is fixed in advance,
+and it is possible to construct it from the data. The optimization problem
+over :math:`w` is formulated as a classic margin-based hinge loss over the
+set :math:`C` of triplets, with an :math:`\ell_1` regularization added to
+yield a sparse combination. The formulation is the following:
+
+.. math::
+
+    \min_{w\geq 0} \sum_{(x_i,x_j,x_k)\in C} [1 + d_w(x_i,x_j)-d_w(x_i,x_k)]_+ + \beta||w||_1
+
+where :math:`[\cdot]_+` is the hinge loss.
+
+.. topic:: Example Code:
+
+::
+
+    from metric_learn import SCML
+
+    triplets = [[[1.2, 7.5], [1.3, 1.5], [6.2, 9.7]],
+                [[1.3, 4.5], [3.2, 4.6], [5.4, 5.4]],
+                [[3.2, 7.5], [3.3, 1.5], [8.2, 9.7]],
+                [[3.3, 4.5], [5.2, 4.6], [7.4, 5.4]]]
+
+    scml = SCML()
+    scml.fit(triplets)
+
+.. topic:: References:
+
+  .. [1] Y. Shi, A. Bellet and F. Sha. `Sparse Compositional Metric Learning.
+     <http://researchers.lille.inria.fr/abellet/papers/aaai14.pdf>`_.
+     (AAAI), 2014.
+
+  .. [2] Adapted from the original `Matlab implementation
+     <https://github.com/bellet/SCML>`_.

 .. _learning_on_quadruplets:

@@ -829,13 +886,13 @@ extension leads to more stable estimation when the dimension is high and
 only a small amount of constraints is given.

 The loss function of each constraint
-:math:`d(\mathbf{x}_a, \mathbf{x}_b) < d(\mathbf{x}_c, \mathbf{x}_d)` is
+:math:`d(\mathbf{x}_i, \mathbf{x}_j) < d(\mathbf{x}_k, \mathbf{x}_l)` is
 denoted as:

 .. math::

-    H(d_\mathbf{M}(\mathbf{x}_a, \mathbf{x}_b)
-    - d_\mathbf{M}(\mathbf{x}_c, \mathbf{x}_d))
+    H(d_\mathbf{M}(\mathbf{x}_i, \mathbf{x}_j)
+    - d_\mathbf{M}(\mathbf{x}_k, \mathbf{x}_l))

 where :math:`H(\cdot)` is the squared Hinge loss function defined as:

@@ -845,8 +902,8 @@ where :math:`H(\cdot)` is the squared Hinge loss function defined as:
     \,\,x^2 \qquad x>0\end{aligned}\right.\\

 The summed loss function :math:`L(C)` is the simple sum over all constraints
-:math:`C = \{(\mathbf{x}_a , \mathbf{x}_b , \mathbf{x}_c , \mathbf{x}_d)
-: d(\mathbf{x}_a , \mathbf{x}_b) < d(\mathbf{x}_c , \mathbf{x}_d)\}`. The
+:math:`C = \{(\mathbf{x}_i , \mathbf{x}_j , \mathbf{x}_k , \mathbf{x}_l)
+: d(\mathbf{x}_i , \mathbf{x}_j) < d(\mathbf{x}_k , \mathbf{x}_l)\}`. The
 original paper suggested here should be a weighted sum since the confidence
 or probability of each constraint might differ. However, for the sake of
 simplicity and assumption of no extra knowledge provided, we just deploy
@@ -858,9 +915,9 @@ knowledge:

 .. math::

-    \min_\mathbf{M}(D_{ld}(\mathbf{M, M_0}) + \sum_{(\mathbf{x}_a,
-    \mathbf{x}_b, \mathbf{x}_c, \mathbf{x}_d)\in C}H(d_\mathbf{M}(
-    \mathbf{x}_a, \mathbf{x}_b) - d_\mathbf{M}(\mathbf{x}_c, \mathbf{x}_c))\\
+    \min_\mathbf{M}(D_{ld}(\mathbf{M, M_0}) + \sum_{(\mathbf{x}_i,
+    \mathbf{x}_j, \mathbf{x}_k, \mathbf{x}_l)\in C}H(d_\mathbf{M}(
+    \mathbf{x}_i, \mathbf{x}_j) - d_\mathbf{M}(\mathbf{x}_k, \mathbf{x}_l))\\

 where :math:`\mathbf{M}_0` is the prior metric matrix, set as identity
 by default, :math:`D_{ld}(\mathbf{\cdot, \cdot})` is the LogDet divergence:
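The new doc section states that :math:`M = B \cdot diag(w) \cdot B^T` with :math:`w_i \geq 0` is PSD by construction. A quick numerical sanity check of that claim, using toy random bases rather than the data-driven bases SCML actually generates:

```python
import numpy as np

rng = np.random.default_rng(0)
K, d = 10, 2                 # K rank-one bases in d dimensions
B = rng.normal(size=(d, K))  # columns b_i are the basis directions
w = rng.uniform(size=K)      # nonnegative weights w_i >= 0

# M = sum_i w_i b_i b_i^T = B diag(w) B^T
M = B @ np.diag(w) @ B.T
eigvals = np.linalg.eigvalsh(M)  # should all be >= 0 (up to rounding)
```

Because each term `w[i] * outer(b_i, b_i)` is PSD and the weights are nonnegative, no eigenvalue of the sum can be negative, so no projection onto the PSD cone is needed during optimization.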

metric_learn/__init__.py
Lines changed: 3 additions & 1 deletion

@@ -9,10 +9,12 @@
 from .rca import RCA, RCA_Supervised
 from .mlkr import MLKR
 from .mmc import MMC, MMC_Supervised
+from .scml import SCML, SCML_Supervised

 from ._version import __version__

 __all__ = ['Constraints', 'Covariance', 'ITML', 'ITML_Supervised',
            'LMNN', 'LSML', 'LSML_Supervised', 'SDML',
            'SDML_Supervised', 'NCA', 'LFDA', 'RCA', 'RCA_Supervised',
-           'MLKR', 'MMC', 'MMC_Supervised', '__version__']
+           'MLKR', 'MMC', 'MMC_Supervised', 'SCML',
+           'SCML_Supervised', '__version__']
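The hinge-loss objective documented above can be evaluated directly. A minimal sketch under an assumed toy basis (the 2x2 identity, not the bases `SCML` actually builds), reusing the example triplets from the new doc section; `scml_objective` is an illustrative name, not part of the library's API:

```python
import numpy as np

def scml_objective(w, B, triplets, beta=1e-5):
    """Evaluate sum_{(xi,xj,xk)} [1 + d_w(xi,xj) - d_w(xi,xk)]_+ + beta*||w||_1,
    where d_w(x, y) = (x - y)^T B diag(w) B^T (x - y)."""
    M = B @ np.diag(w) @ B.T

    def d(x, y):
        diff = np.asarray(x) - np.asarray(y)
        return diff @ M @ diff

    hinge = sum(max(0.0, 1.0 + d(xi, xj) - d(xi, xk))
                for xi, xj, xk in triplets)
    return hinge + beta * np.abs(w).sum()

# Example triplets from the doc: each row is (anchor, similar, dissimilar).
triplets = [[[1.2, 7.5], [1.3, 1.5], [6.2, 9.7]],
            [[1.3, 4.5], [3.2, 4.6], [5.4, 5.4]],
            [[3.2, 7.5], [3.3, 1.5], [8.2, 9.7]],
            [[3.3, 4.5], [5.2, 4.6], [7.4, 5.4]]]
B = np.eye(2)  # toy basis: standard coordinate axes
loss = scml_objective(np.ones(2), B, triplets)
```

A triplet contributes zero loss once the anchor is closer to its similar point than to its dissimilar point by at least the unit margin; the `beta` term pushes unneeded weights to exactly zero, giving the sparse combination the doc describes.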
