Skip to content

Commit 6ba395d

Browse files
Isalia20Borda
andauthored
fix slow calculations of classification metrics (#2876)
* fix slow calculations of classification metrics * chlog --------- Co-authored-by: Jirka B <[email protected]>
1 parent 4d9c843 commit 6ba395d

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3939
- Fixed `top_k` for `multiclassf1score` with one-hot encoding ([#2839](https://github.com/Lightning-AI/torchmetrics/issues/2839))
4040

4141

42+
- Fixed slow calculations of classification metrics with MPS ([#2876](https://github.com/Lightning-AI/torchmetrics/issues/2876))
43+
4244
---
4345

4446
## [1.6.0] - 2024-11-12

src/torchmetrics/utilities/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
199199
if minlength is None:
200200
minlength = len(torch.unique(x))
201201

202-
if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE and x.is_mps:
202+
if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
203203
mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
204204
return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)
205205

0 commit comments

Comments
 (0)