Skip to content

Commit a44dfdc

Browse files
author
mvargas33
committed
Simplified prediction as suggested
1 parent abe014e commit a44dfdc

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

metric_learn/base_metric.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -602,9 +602,7 @@ def predict(self, triplets):
602602
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
603603
Predictions of the ordering of pairs, for each triplet.
604604
"""
605-
return np.array([-1 if (t <= 0) else 1 for t in
606-
self.decision_function(triplets)])
607-
# return np.sign(self.decision_function(triplets))
605+
return 2 * (self.decision_function(triplets) > 0) - 1
608606

609607
def decision_function(self, triplets):
610608
"""Predicts differences between sample distances in input triplets.

0 commit comments

Comments
 (0)