Skip to content

Commit 0adf189

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] [proto] Clean-up Label.to_categories (#6419)
Summary: * [proto] Clean-up Label.to_categories * Fixed flake8 Reviewed By: datumbox Differential Revision: D38824230 fbshipit-source-id: 4bb30daccc927e7c515105a1b4b8ab9f341cad8b
1 parent 9c64c1a commit 0adf189

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchvision/prototype/features/_label.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Any, cast, Optional, Sequence, Union
3+
from typing import Any, Optional, Sequence, Union
44

55
import torch
6-
from torchvision.prototype.utils._internal import apply_recursively
6+
from torch.utils._pytree import tree_map
77

88
from ._feature import _Feature
99

@@ -43,10 +43,10 @@ def from_category(
4343
return cls(categories.index(category), categories=categories, **kwargs)
4444

4545
def to_categories(self) -> Any:
46-
if not self.categories:
47-
raise RuntimeError()
46+
if self.categories is None:
47+
raise RuntimeError("Label does not have categories")
4848

49-
return apply_recursively(lambda idx: cast(Sequence[str], self.categories)[idx], self.tolist())
49+
return tree_map(lambda idx: self.categories[idx], self.tolist())
5050

5151

5252
class OneHotLabel(_Feature):

0 commit comments

Comments
 (0)