diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py index c61419a61b6..b6b8588e478 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Any, cast, Optional, Sequence, Union +from typing import Any, Optional, Sequence, Union import torch -from torchvision.prototype.utils._internal import apply_recursively +from torch.utils._pytree import tree_map from ._feature import _Feature @@ -43,10 +43,10 @@ def from_category( return cls(categories.index(category), categories=categories, **kwargs) def to_categories(self) -> Any: - if not self.categories: - raise RuntimeError() + if self.categories is None: + raise RuntimeError("Label does not have categories") - return apply_recursively(lambda idx: cast(Sequence[str], self.categories)[idx], self.tolist()) + return tree_map(lambda idx: self.categories[idx], self.tolist()) class OneHotLabel(_Feature):