Skip to content

Commit 6df637e

Browse files
vincentvaroquauxads authored and Borda committed
fix ConfusionMatrix and StatScores for num_classes > 16 (#1521)
* fix: ConfusionMatrix & StatScores for num_classes > 16, e.g. if preds or target is uint8 and num_classes > 16, unique_mapping overflows
* unittest #1521, NUM_CLASSES=17, add multiclass case "single dim int8-logits"
* revert tests
* add byte testing
* changelog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit 6bc249d)
1 parent 6cb74aa commit 6df637e

File tree

5 files changed

+36
-3
lines changed

5 files changed

+36
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
- Fixed `multilabel` in `ExactMatch` ([#1474](https://github.com/Lightning-AI/metrics/pull/1474))
2424

2525

26+
- Fixed classification metrics for `byte` input ([#1521](https://github.com/Lightning-AI/metrics/pull/1521))
27+
28+
2629
## [0.11.1] - 2023-01-30
2730

2831
### Fixed

src/torchmetrics/functional/classification/confusion_matrix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ def _multiclass_confusion_matrix_format(
325325

326326

327327
def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int) -> Tensor:
328-
"""Computes the bins to update the confusion matrix with."""
329-
unique_mapping = (target * num_classes + preds).to(torch.long)
328+
"""Compute the bins to update the confusion matrix with."""
329+
unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
330330
bins = _bincount(unique_mapping, minlength=num_classes**2)
331331
return bins.reshape(num_classes, num_classes)
332332

src/torchmetrics/functional/classification/stat_scores.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def _multiclass_stat_scores_update(
408408
idx = target != ignore_index
409409
preds = preds[idx]
410410
target = target[idx]
411-
unique_mapping = (target * num_classes + preds).to(torch.long)
411+
unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
412412
bins = _bincount(unique_mapping, minlength=num_classes**2)
413413
confmat = bins.reshape(num_classes, num_classes)
414414
tp = confmat.diag()

tests/unittests/classification/test_confusion_matrix.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,18 @@ def test_multiclass_confusion_matrix_dtype_gpu(self, input, dtype):
214214
)
215215

216216

217+
def test_multiclass_overflow():
218+
"""Test that multiclass computations do not overflow even on byte input."""
219+
preds = torch.randint(20, (100,)).byte()
220+
target = torch.randint(20, (100,)).byte()
221+
222+
m = MulticlassConfusionMatrix(num_classes=20)
223+
res = m(preds, target)
224+
225+
compare = sk_confusion_matrix(target, preds)
226+
assert torch.allclose(res, torch.tensor(compare))
227+
228+
217229
def _sk_confusion_matrix_multilabel(preds, target, normalize=None, ignore_index=None):
218230
preds = preds.numpy()
219231
target = target.numpy()

tests/unittests/classification/test_stat_scores.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,24 @@ def test_top_k_multiclass(k, preds, target, average, expected):
323323
)
324324

325325

326+
def test_multiclass_overflow():
327+
"""Test that multiclass computations do not overflow even on byte input."""
328+
preds = torch.randint(20, (100,)).byte()
329+
target = torch.randint(20, (100,)).byte()
330+
331+
m = MulticlassStatScores(num_classes=20, average=None)
332+
res = m(preds, target)
333+
334+
confmat = sk_confusion_matrix(target, preds)
335+
fp = confmat.sum(axis=0) - np.diag(confmat)
336+
fn = confmat.sum(axis=1) - np.diag(confmat)
337+
tp = np.diag(confmat)
338+
tn = confmat.sum() - (fp + fn + tp)
339+
compare = np.stack([tp, fp, tn, fn, tp + fn]).T
340+
341+
assert torch.allclose(res, torch.tensor(compare))
342+
343+
326344
def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average):
327345
preds = preds.numpy()
328346
target = target.numpy()

0 commit comments

Comments
 (0)