Skip to content

Commit e1679ae

Browse files
simonreisepre-commit-ci[bot]SkafteNickiCopilot
authored
Add "mixed" input format to segmentation metrics (#3176)
* Add `"mixed"` input format for segmentation metrics * update changelog * better naming of inputs * Fixed mean_iou test * Fixed another test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 81b34de commit e1679ae

File tree

15 files changed

+217
-81
lines changed

15 files changed

+217
-81
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3232
- Added `Lip Vertex Error (LVE)` in multimodal domain ([3090](https://github.com/Lightning-AI/torchmetrics/pull/3090))
3333

3434

35+
- Added `mixed` input format to segmentation metrics ([3176](https://github.com/Lightning-AI/torchmetrics/pull/3176))
36+
3537
### Changed
3638

3739
-

src/torchmetrics/functional/segmentation/dice.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torch import Tensor
1818
from typing_extensions import Literal
1919

20-
from torchmetrics.functional.segmentation.utils import _ignore_background
20+
from torchmetrics.functional.segmentation.utils import _check_mixed_shape, _ignore_background
2121
from torchmetrics.utilities import rank_zero_warn
2222
from torchmetrics.utilities.checks import _check_same_shape
2323
from torchmetrics.utilities.compute import _safe_divide
@@ -27,7 +27,7 @@ def _dice_score_validate_args(
2727
num_classes: int,
2828
include_background: bool,
2929
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
30-
input_format: Literal["one-hot", "index"] = "one-hot",
30+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
3131
aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
3232
) -> None:
3333
"""Validate the arguments of the metric."""
@@ -38,8 +38,10 @@ def _dice_score_validate_args(
3838
allowed_average = ["micro", "macro", "weighted", "none"]
3939
if average is not None and average not in allowed_average:
4040
raise ValueError(f"Expected argument `average` to be one of {allowed_average} or None, but got {average}.")
41-
if input_format not in ["one-hot", "index"]:
42-
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")
41+
if input_format not in ["one-hot", "index", "mixed"]:
42+
raise ValueError(
43+
f"Expected argument `input_format` to be one of 'one-hot', 'index', 'mixed', but got {input_format}."
44+
)
4345
if aggregation_level not in ("samplewise", "global"):
4446
raise ValueError(
4547
f"Expected argument `aggregation_level` to be one of `samplewise`, `global`, but got {aggregation_level}"
@@ -51,14 +53,22 @@ def _dice_score_update(
5153
target: Tensor,
5254
num_classes: int,
5355
include_background: bool,
54-
input_format: Literal["one-hot", "index"] = "one-hot",
56+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
5557
) -> tuple[Tensor, Tensor, Tensor]:
5658
"""Update the state with the current prediction and target."""
57-
_check_same_shape(preds, target)
59+
if input_format == "mixed":
60+
_check_mixed_shape(preds, target)
61+
else:
62+
_check_same_shape(preds, target)
5863

5964
if input_format == "index":
6065
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
6166
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
67+
elif input_format == "mixed":
68+
if preds.dim() == (target.dim() + 1):
69+
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
70+
elif (preds.dim() + 1) == target.dim():
71+
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
6272

6373
if preds.ndim < 3:
6474
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
@@ -117,7 +127,7 @@ def dice_score(
117127
num_classes: int,
118128
include_background: bool = True,
119129
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
120-
input_format: Literal["one-hot", "index"] = "one-hot",
130+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
121131
aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
122132
) -> Tensor:
123133
"""Compute the Dice score for semantic segmentation.
@@ -128,9 +138,10 @@ def dice_score(
128138
num_classes: Number of classes
129139
include_background: Whether to include the background class in the computation
130140
average: The method to average the dice score. Options are ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``
131-
or ``None``. This determines how to average the dice score across different classes.
132-
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
133-
or ``"index"`` for index tensors
141+
or ``None``. This determines how to average the dice score across different classes.
142+
input_format: What kind of input the function receives.
143+
Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors
144+
or ``"mixed"`` for one one-hot encoded and one index tensor
134145
aggregation_level: The level at which to aggregate the dice score. Options are ``"samplewise"`` or ``"global"``.
135146
For ``"samplewise"`` the dice score is computed for each sample and then averaged. For ``"global"`` the dice
136147
score is computed globally over all samples.

src/torchmetrics/functional/segmentation/generalized_dice.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torch import Tensor
1818
from typing_extensions import Literal
1919

20-
from torchmetrics.functional.segmentation.utils import _ignore_background
20+
from torchmetrics.functional.segmentation.utils import _check_mixed_shape, _ignore_background
2121
from torchmetrics.utilities.checks import _check_same_shape
2222
from torchmetrics.utilities.compute import _safe_divide
2323

@@ -27,7 +27,7 @@ def _generalized_dice_validate_args(
2727
include_background: bool,
2828
per_class: bool,
2929
weight_type: Literal["square", "simple", "linear"],
30-
input_format: Literal["one-hot", "index"],
30+
input_format: Literal["one-hot", "index", "mixed"],
3131
) -> None:
3232
"""Validate the arguments of the metric."""
3333
if not isinstance(num_classes, int) or num_classes <= 0:
@@ -40,8 +40,10 @@ def _generalized_dice_validate_args(
4040
raise ValueError(
4141
f"Expected argument `weight_type` to be one of 'square', 'simple', 'linear', but got {weight_type}."
4242
)
43-
if input_format not in ["one-hot", "index"]:
44-
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")
43+
if input_format not in ["one-hot", "index", "mixed"]:
44+
raise ValueError(
45+
f"Expected argument `input_format` to be one of 'one-hot', 'index', 'mixed', but got {input_format}."
46+
)
4547

4648

4749
def _generalized_dice_update(
@@ -50,14 +52,22 @@ def _generalized_dice_update(
5052
num_classes: int,
5153
include_background: bool,
5254
weight_type: Literal["square", "simple", "linear"] = "square",
53-
input_format: Literal["one-hot", "index"] = "one-hot",
55+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
5456
) -> Tuple[Tensor, Tensor]:
5557
"""Update the state with the current prediction and target."""
56-
_check_same_shape(preds, target)
58+
if input_format == "mixed":
59+
_check_mixed_shape(preds, target)
60+
else:
61+
_check_same_shape(preds, target)
5762

5863
if input_format == "index":
5964
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
6065
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
66+
elif input_format == "mixed":
67+
if preds.dim() == (target.dim() + 1):
68+
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
69+
elif (preds.dim() + 1) == target.dim():
70+
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
6171

6272
if preds.ndim < 3:
6373
raise ValueError(f"Expected both `preds` and `target` to have at least 3 dimensions, but got {preds.ndim}.")
@@ -109,7 +119,7 @@ def generalized_dice_score(
109119
include_background: bool = True,
110120
per_class: bool = False,
111121
weight_type: Literal["square", "simple", "linear"] = "square",
112-
input_format: Literal["one-hot", "index"] = "one-hot",
122+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
113123
) -> Tensor:
114124
"""Compute the Generalized Dice Score for semantic segmentation.
115125
@@ -120,8 +130,9 @@ def generalized_dice_score(
120130
include_background: Whether to include the background class in the computation
121131
per_class: Whether to compute the score for each class separately, else average over all classes
122132
weight_type: Type of weight factor to apply to the classes. One of ``"square"``, ``"simple"``, or ``"linear"``
123-
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
124-
or ``"index"`` for index tensors
133+
input_format: What kind of input the function receives.
134+
Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors
135+
or ``"mixed"`` for one one-hot encoded and one index tensor
125136
126137
Returns:
127138
The Generalized Dice Score

src/torchmetrics/functional/segmentation/hausdorff_distance.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch import Tensor
1919

2020
from torchmetrics.functional.segmentation.utils import (
21+
_check_mixed_shape,
2122
_ignore_background,
2223
edge_surface_distance,
2324
)
@@ -30,7 +31,7 @@ def _hausdorff_distance_validate_args(
3031
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
3132
spacing: Optional[Union[Tensor, list[float]]] = None,
3233
directed: bool = False,
33-
input_format: Literal["one-hot", "index"] = "one-hot",
34+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
3435
) -> None:
3536
"""Validate the arguments of `hausdorff_distance` function."""
3637
if num_classes <= 0:
@@ -45,8 +46,10 @@ def _hausdorff_distance_validate_args(
4546
raise ValueError(f"Arg `spacing` must be a list or tensor, but got {type(spacing)}.")
4647
if not isinstance(directed, bool):
4748
raise ValueError(f"Expected argument `directed` must be a boolean, but got {directed}.")
48-
if input_format not in ["one-hot", "index"]:
49-
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")
49+
if input_format not in ["one-hot", "index", "mixed"]:
50+
raise ValueError(
51+
f"Expected argument `input_format` to be one of 'one-hot', 'index', 'mixed', but got {input_format}."
52+
)
5053

5154

5255
def hausdorff_distance(
@@ -57,7 +60,7 @@ def hausdorff_distance(
5760
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
5861
spacing: Optional[Union[Tensor, list[float]]] = None,
5962
directed: bool = False,
60-
input_format: Literal["one-hot", "index"] = "one-hot",
63+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
6164
) -> Tensor:
6265
"""Calculate `Hausdorff Distance`_ for semantic segmentation.
6366
@@ -70,8 +73,9 @@ def hausdorff_distance(
7073
`"chessboard"` or `"taxicab"`
7174
spacing: spacing between pixels along each spatial dimension. If not provided the spacing is assumed to be 1
7275
directed: whether to calculate directed or undirected Hausdorff distance
73-
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
74-
or ``"index"`` for index tensors
76+
input_format: What kind of input the function receives.
77+
Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors
78+
or ``"mixed"`` for one one-hot encoded and one index tensor
7579
7680
Returns:
7781
Hausdorff Distance for each class and batch element
@@ -89,11 +93,19 @@ def hausdorff_distance(
8993
9094
"""
9195
_hausdorff_distance_validate_args(num_classes, include_background, distance_metric, spacing, directed, input_format)
92-
_check_same_shape(preds, target)
96+
if input_format == "mixed":
97+
_check_mixed_shape(preds, target)
98+
else:
99+
_check_same_shape(preds, target)
93100

94101
if input_format == "index":
95102
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
96103
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
104+
elif input_format == "mixed":
105+
if preds.dim() == (target.dim() + 1):
106+
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
107+
elif (preds.dim() + 1) == target.dim():
108+
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
97109

98110
if not include_background:
99111
preds, target = _ignore_background(preds, target)

src/torchmetrics/functional/segmentation/mean_iou.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
from torch import Tensor
1919
from typing_extensions import Literal
2020

21-
from torchmetrics.functional.segmentation.utils import _ignore_background
21+
from torchmetrics.functional.segmentation.utils import _check_mixed_shape, _ignore_background
2222
from torchmetrics.utilities.checks import _check_same_shape
2323
from torchmetrics.utilities.compute import _safe_divide
2424

2525

2626
def _mean_iou_reshape_args(
2727
preds: Tensor,
2828
targets: Tensor,
29-
input_format: Literal["one-hot", "index"] = "one-hot",
29+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
3030
) -> Tuple[Tensor, Tensor]:
3131
"""Reshape tensors to 3D if needed."""
3232
if input_format == "one-hot":
@@ -48,11 +48,11 @@ def _mean_iou_validate_args(
4848
num_classes: Optional[int],
4949
include_background: bool,
5050
per_class: bool,
51-
input_format: Literal["one-hot", "index"] = "one-hot",
51+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
5252
) -> None:
5353
"""Validate the arguments of the metric."""
54-
if input_format == "index" and num_classes is None:
55-
raise ValueError("Argument `num_classes` must be provided when `input_format='index'`.")
54+
if input_format in ["index", "mixed"] and num_classes is None:
55+
raise ValueError("Argument `num_classes` must be provided when `input_format` is 'index' or 'mixed'.")
5656
if num_classes is not None and num_classes <= 0:
5757
raise ValueError(
5858
f"Expected argument `num_classes` must be `None` or a positive integer, but got {num_classes}."
@@ -61,20 +61,25 @@ def _mean_iou_validate_args(
6161
raise ValueError(f"Expected argument `include_background` must be a boolean, but got {include_background}.")
6262
if not isinstance(per_class, bool):
6363
raise ValueError(f"Expected argument `per_class` must be a boolean, but got {per_class}.")
64-
if input_format not in ["one-hot", "index"]:
65-
raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")
64+
if input_format not in ["one-hot", "index", "mixed"]:
65+
raise ValueError(
66+
f"Expected argument `input_format` to be one of 'one-hot', 'index', 'mixed', but got {input_format}."
67+
)
6668

6769

6870
def _mean_iou_update(
6971
preds: Tensor,
7072
target: Tensor,
7173
num_classes: Optional[int] = None,
7274
include_background: bool = False,
73-
input_format: Literal["one-hot", "index"] = "one-hot",
75+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
7476
) -> tuple[Tensor, Tensor]:
7577
"""Update the intersection and union counts for the mean IoU computation."""
7678
preds, target = _mean_iou_reshape_args(preds, target, input_format)
77-
_check_same_shape(preds, target)
79+
if input_format == "mixed":
80+
_check_mixed_shape(preds, target)
81+
else:
82+
_check_same_shape(preds, target)
7883

7984
if input_format == "index":
8085
if num_classes is None:
@@ -88,6 +93,13 @@ def _mean_iou_update(
8893
raise IndexError(f"Cannot determine `num_classes` from `preds` tensor: {preds}.") from err
8994
if num_classes == 0:
9095
raise ValueError(f"Expected argument `num_classes` to be a positive integer, but got {num_classes}.")
96+
elif input_format == "mixed":
97+
if num_classes is None:
98+
raise ValueError("Argument `num_classes` must be provided when `input_format='mixed'`.")
99+
if preds.dim() == (target.dim() + 1):
100+
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
101+
elif (preds.dim() + 1) == target.dim():
102+
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
91103

92104
if not include_background:
93105
preds, target = _ignore_background(preds, target)
@@ -115,7 +127,7 @@ def mean_iou(
115127
num_classes: Optional[int] = None,
116128
include_background: bool = True,
117129
per_class: bool = False,
118-
input_format: Literal["one-hot", "index"] = "one-hot",
130+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
119131
) -> Tensor:
120132
"""Calculates the mean Intersection over Union (mIoU) for semantic segmentation.
121133
@@ -127,8 +139,9 @@ def mean_iou(
127139
num_classes: Number of classes (required when input_format="index", optional when input_format="one-hot")
128140
include_background: Whether to include the background class in the computation
129141
per_class: Whether to compute the IoU for each class separately, else average over all classes
130-
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
131-
or ``"index"`` for index tensors
142+
input_format: What kind of input the function receives.
143+
Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors
144+
or ``"mixed"`` for one one-hot encoded and one index tensor
132145
133146
Returns:
134147
The mean IoU score

src/torchmetrics/functional/segmentation/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,24 @@ def _ignore_background(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
3131
return preds, target
3232

3333

34+
def _check_mixed_shape(preds: Tensor, target: Tensor) -> None:
35+
"""Check that predictions and target have the same shape, else raise error."""
36+
if preds.dim() == (target.dim() + 1):
37+
if preds.shape[0] != target.shape[0] or preds.shape[2:] != target.shape[1:]:
38+
raise RuntimeError(
39+
f"Predictions and targets are expected to have the same shape, got {preds.shape} and {target.shape}."
40+
)
41+
elif (preds.dim() + 1) == target.dim():
42+
if preds.shape[0] != target.shape[0] or preds.shape[1:] != target.shape[2:]:
43+
raise RuntimeError(
44+
f"Predictions and targets are expected to have the same shape, got {preds.shape} and {target.shape}."
45+
)
46+
else:
47+
raise RuntimeError(
48+
f"Predictions and targets are expected to have the same shape, got {preds.shape} and {target.shape}."
49+
)
50+
51+
3452
def check_if_binarized(x: Tensor) -> None:
3553
"""Check if tensor is binarized.
3654

src/torchmetrics/segmentation/dice.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@ class DiceScore(Metric):
7070
aggregation_level: The level at which to aggregate the dice score. Options are ``"samplewise"`` or ``"global"``.
7171
For ``"samplewise"`` the dice score is computed for each sample and then averaged. For ``"global"`` the dice
7272
score is computed globally over all samples.
73-
input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
74-
or ``"index"`` for index tensors.
73+
input_format: What kind of input the function receives.
74+
Choose between ``"one-hot"`` for one-hot encoded tensors, ``"index"`` for index tensors
75+
or ``"mixed"`` for one one-hot encoded and one index tensor
7576
zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan".
7677
Setting it to "warn" behaves like 0.0 but will also create a warning.
7778
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -84,7 +85,7 @@ class DiceScore(Metric):
8485
ValueError:
8586
If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``
8687
ValueError:
87-
If ``input_format`` is not one of ``"one-hot"`` or ``"index"``
88+
If ``input_format`` is not one of ``"one-hot"``, ``"index"`` or ``"mixed"``
8889
8990
Example:
9091
>>> from torch import randint
@@ -116,7 +117,7 @@ def __init__(
116117
include_background: bool = True,
117118
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
118119
aggregation_level: Optional[Literal["samplewise", "global"]] = "samplewise",
119-
input_format: Literal["one-hot", "index"] = "one-hot",
120+
input_format: Literal["one-hot", "index", "mixed"] = "one-hot",
120121
**kwargs: Any,
121122
) -> None:
122123
super().__init__(**kwargs)

0 commit comments

Comments
 (0)