
Inconsistent default values for average argument in classification metrics #2320

@StefanoWoerner

Description


🐛 Bug

When the multiclass (or multilabel) accuracy metric is instantiated through the legacy `Accuracy` wrapper class, `average` defaults to `"micro"`. When it is instantiated directly through `MulticlassAccuracy` (the new way since 0.11, I believe), `average` defaults to `"macro"`. This inconsistency can lead to very unexpected results.

The same holds for all metrics that subclass `MulticlassStatScores`, `BinaryStatScores` or `MultilabelStatScores`, and for their respective functional interfaces.
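For reference, here is how the two averaging modes combine the per-class stat scores $\mathrm{TP}_c$ and $\mathrm{FN}_c$ for accuracy in the single-label multiclass case, with $C$ classes. These are the standard definitions, and they are assumed, not verified against the source, to match what the `MulticlassStatScores`-based classes compute:

$$
\mathrm{Acc}_{\text{micro}} = \frac{\sum_{c=1}^{C} \mathrm{TP}_c}{\sum_{c=1}^{C} \left(\mathrm{TP}_c + \mathrm{FN}_c\right)}, \qquad
\mathrm{Acc}_{\text{macro}} = \frac{1}{C} \sum_{c=1}^{C} \frac{\mathrm{TP}_c}{\mathrm{TP}_c + \mathrm{FN}_c}
$$

Every sample has exactly one true class, so $\sum_{c} (\mathrm{TP}_c + \mathrm{FN}_c) = N$. The micro value is therefore the plain fraction of correct predictions. The macro value is the unweighted mean of per-class recall.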

To Reproduce

  1. Instantiate the metrics directly as well as through the wrapper.
  2. Compare results.
Code sample
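```python
import torch
from torchmetrics import Accuracy
from torchmetrics.classification import MulticlassAccuracy

classes = {0: "A", 1: "B", 2: "C"}
num_classes = len(classes)
num_samples = 10
multiclass_preds = torch.randn(num_samples, num_classes)  # logits, shape (N, C)
multiclass_targets = torch.randint(0, num_classes, (num_samples,))  # labels, shape (N,)

# `threshold` is the wrapper's second positional parameter, so pass `num_classes` by keyword.
legacy_mc_acc = Accuracy(task="multiclass", num_classes=num_classes)  # defaults to average="micro"
new_mc_acc = MulticlassAccuracy(num_classes=num_classes)  # defaults to average="macro"

legacy_result = legacy_mc_acc(multiclass_preds, multiclass_targets)
new_result = new_mc_acc(multiclass_preds, multiclass_targets)

# Fails whenever the micro and macro averages differ, which is the usual case.
assert new_result == legacy_result
```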

Expected behavior

Consistency between the different interfaces.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): >=0.11 (1.3 in my case)
  • Python & PyTorch Version (e.g., 1.0): irrelevant
  • Any other relevant information such as OS (e.g., Linux): irrelevant

Additional context

I would argue that for accuracy, the `"macro"` default in the task-specific classes is not only inconsistent with the legacy interface but actually wrong. The common definition of accuracy is

$$ \mathrm{Acc}(\text{preds},\text{targets}) = \frac{1}{N}\sum_{i = 1}^{N} \left[ \begin{cases} 1 & \text{if preds}_i = \text{targets}_i \\ 0 & \text{otherwise} \end{cases} \right] $$

which is how accuracy is computed when `average="micro"` is set.
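The following pure-Python sketch makes the difference concrete on a small, made-up example. The helpers `micro_accuracy` and `macro_accuracy` are hypothetical and are not TorchMetrics functions. The sketch assumes that macro accuracy means the unweighted mean of per-class accuracy over the classes present in the targets. TorchMetrics may handle edge cases such as absent classes differently.

```python
from collections import Counter


def micro_accuracy(preds: list[int], targets: list[int]) -> float:
    """Fraction of samples whose prediction equals the target (the formula above)."""
    correct = sum(1 for p, t in zip(preds, targets) if p == t)
    return correct / len(targets)


def macro_accuracy(preds: list[int], targets: list[int]) -> float:
    """Unweighted mean of per-class accuracy (per-class recall) over classes present in targets."""
    support = Counter(targets)  # number of samples per true class
    hits = Counter(t for p, t in zip(preds, targets) if p == t)  # correct predictions per class
    return sum(hits[c] / n for c, n in support.items()) / len(support)


# Hypothetical toy data: four samples of class 0, two of class 1.
targets = [0, 0, 0, 0, 1, 1]
preds = [0, 0, 0, 0, 0, 1]

print(micro_accuracy(preds, targets))  # 5/6, about 0.833
print(macro_accuracy(preds, targets))  # (4/4 + 1/2) / 2 = 0.75
```

The two defaults therefore disagree as soon as per-class accuracies differ, even on a tiny input like this.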

Setting `average="macro"` can still be useful, because it is less sensitive to class imbalance. However, I think TorchMetrics should follow the common definitions in its default settings, so I would argue for making `"micro"` the default.
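Here is a hypothetical worked example of that sensitivity. Take 100 samples, 90 of class A and 10 of class B, and a degenerate model that always predicts A:

$$
\mathrm{Acc}_{\text{micro}} = \frac{90}{100} = 0.9, \qquad
\mathrm{Acc}_{\text{macro}} = \frac{1}{2}\left(\frac{90}{90} + \frac{0}{10}\right) = 0.5
$$

The micro value rewards simply predicting the majority class. The macro value exposes that class B is never recognized.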

Roughly the same holds for precision and recall: when they are defined globally at all, they are commonly defined as micro averages. More often, though, recall and precision appear as class-wise metrics.
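In the single-label multiclass case, the micro averages of precision and recall reduce to plain accuracy. Every wrong prediction counts as one false positive (for the predicted class) and one false negative (for the true class), so the pooled ratios both equal the fraction of correct predictions. The pure-Python sketch below illustrates this on made-up data. The helpers `stat_scores`, `per_class_precision_recall` and `micro_precision_recall` are hypothetical, not the TorchMetrics API, and undefined per-class ratios are assumed to be 0.0.

```python
from collections import Counter


def stat_scores(preds: list[int], targets: list[int]) -> tuple[Counter, Counter, Counter]:
    """Per-class true positives, false positives and false negatives."""
    tp = Counter(t for p, t in zip(preds, targets) if p == t)
    fp = Counter(p for p, t in zip(preds, targets) if p != t)  # error counted against the predicted class
    fn = Counter(t for p, t in zip(preds, targets) if p != t)  # ... and against the true class
    return tp, fp, fn


def per_class_precision_recall(preds, targets, num_classes):
    """Class-wise precision and recall; 0.0 where the denominator is zero (assumption)."""
    tp, fp, fn = stat_scores(preds, targets)
    precision = [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0 for c in range(num_classes)]
    recall = [tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0 for c in range(num_classes)]
    return precision, recall


def micro_precision_recall(preds, targets):
    """Micro averages: pool TP, FP and FN over all classes before dividing."""
    tp, fp, fn = stat_scores(preds, targets)
    total_tp, total_fp, total_fn = sum(tp.values()), sum(fp.values()), sum(fn.values())
    return total_tp / (total_tp + total_fp), total_tp / (total_tp + total_fn)


targets = [0, 0, 1, 1, 2, 2]
preds = [0, 1, 1, 1, 2, 0]

precision, recall = per_class_precision_recall(preds, targets, num_classes=3)
print(precision)  # [0.5, 0.6666666666666666, 1.0]
print(recall)  # [0.5, 1.0, 0.5]
print(micro_precision_recall(preds, targets))  # both equal 4/6, the plain accuracy
```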
