Skip to content

Commit 8827e64

Browse files
podgorkipre-commit-ci[bot]SkafteNickiBordamergify[bot]
authored
Bugfix so multilabel confusion matrix can plot for 2 or more labels (#2858)
* added matplotlib to test requirements * added new test for plotting in multilabel classifier * added bugfix * fix errors * changelog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent a7284e2 commit 8827e64

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727

2828
### Fixed
2929

30+
- Fixed plotting of multilabel confusion matrix ([#2858](https://github.com/PyTorchLightning/metrics/pull/2858))
31+
32+
3033
- Delete `Device2Host` caused by comm with device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840))
3134

3235

src/torchmetrics/utilities/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def plot_confusion_matrix(
270270
fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax)
271271
axs = trim_axs(axs, nb)
272272
for i in range(nb):
273-
ax = axs[i] if rows != 1 and cols != 1 else axs
273+
ax = axs[i] if (rows != 1 or cols != 1) else axs
274274
if fig_label is not None:
275275
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
276276
im = ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach(), cmap=cmap)

tests/unittests/classification/test_confusion_matrix.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,16 @@ def test_multilabel_confusion_matrix_dtype_gpu(self, inputs, dtype):
393393
dtype=dtype,
394394
)
395395

396+
@pytest.mark.parametrize("num_labels", [2, NUM_CLASSES])
397+
def test_multilabel_confusion_matrix_plot(self, num_labels, inputs):
398+
"""Test multilabel cm plots."""
399+
multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=num_labels)
400+
preds = target = torch.ones(1, num_labels).int()
401+
multi_label_confusion_matrix.update(preds, target)
402+
fig, ax = multi_label_confusion_matrix.plot()
403+
assert fig is not None
404+
assert ax is not None
405+
396406

397407
def test_warning_on_nan():
398408
"""Test that a warning is given if division by zero happens during normalization of confusion matrix."""

0 commit comments

Comments
 (0)