Description
🐛 Bug
When instantiating the multiclass (or multilabel) accuracy metric through the Accuracy wrapper class (legacy), the default value for average is micro. When instantiating directly through MulticlassAccuracy (the new way, since 0.11 I believe), the default value is macro. This is inconsistent and can lead to very unexpected results.
The same is true for all metrics that are subclasses of MulticlassStatScores, BinaryStatScores or MultilabelStatScores, as well as their respective functional interfaces.
To Reproduce
- Instantiate the metrics directly as well as through the wrapper.
- Compare results.
Code sample
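import torch
from torchmetrics import Accuracy
from torchmetrics.classification import MulticlassAccuracy
classes = {0: "A", 1: "B", 2: "C"}
num_classes = len(classes)
num_samples = 10
multiclass_preds = torch.randn(num_samples, num_classes)
multiclass_targets = torch.randint(0, num_classes, (num_samples,))
# Wrapper class: average defaults to "micro".
# num_classes is passed by keyword so it is not taken as another positional argument.
legacy_mc_acc = Accuracy("multiclass", num_classes=num_classes)
# Task-specific class: average defaults to "macro".
new_mc_acc = MulticlassAccuracy(num_classes)
legacy_result = legacy_mc_acc(multiclass_preds, multiclass_targets)
new_result = new_mc_acc(multiclass_preds, multiclass_targets)
# Fails whenever the micro and macro averages differ.
assert new_result == legacy_result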
Expected behavior
Consistency between the different interfaces.
Environment
- TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): >=0.11 (1.3 in my case)
- Python & PyTorch Version (e.g., 1.0): irrelevant
- Any other relevant information such as OS (e.g., Linux): irrelevant
Additional context
I would argue that, in the case of accuracy, the default being macro in the task-specific classes is not only inconsistent with legacy but actually wrong. The common definition of accuracy is the number of correct predictions divided by the total number of predictions, which is how accuracy is computed when setting average="micro".
Setting average="macro" can still be useful, as it is less sensitive to class imbalance. However, I think TorchMetrics should adhere to common definitions with the default settings, and would therefore argue for making micro the default.
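To illustrate how far the two averaging modes can diverge, here is a small sketch in plain Python. The helper functions micro_accuracy and macro_accuracy and the toy data are hypothetical and not part of TorchMetrics. The macro variant here averages only over classes that occur in the targets, which is a simplification.

```python
from collections import defaultdict


def micro_accuracy(preds, targets):
    """Fraction of all samples that are predicted correctly."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


def macro_accuracy(preds, targets):
    """Unweighted mean of the per-class accuracies (per-class recall)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for p, t in zip(preds, targets):
        total[t] += 1
        correct[t] += int(p == t)
    per_class = [correct[c] / total[c] for c in total]
    return sum(per_class) / len(per_class)


# Hypothetical imbalanced data: 8 samples of class 0, 2 samples of class 1.
targets = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
preds = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # always predict the majority class

micro = micro_accuracy(preds, targets)  # 8 correct out of 10
macro = macro_accuracy(preds, targets)  # mean of 8/8 and 0/2
print(micro, macro)  # 0.8 0.5
```

A classifier that always predicts the majority class scores 0.8 under the common (micro) definition but 0.5 under macro averaging, so an unnoticed change of default can alter reported numbers substantially.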
The same largely holds for precision and recall: when they are defined globally at all, it is commonly as micro averages. More often, recall and precision are encountered as class-wise metrics.