diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 82bbea5494b..62259a604a0 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -19,8 +19,8 @@ from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor +from torchvision._utils import sequence_to_str from torchvision.prototype.datasets._api import find -from torchvision.prototype.utils._internal import sequence_to_str make_tensor = functools.partial(_make_tensor, device="cpu") make_scalar = functools.partial(make_tensor, ()) diff --git a/test/test_prototype_utils.py b/test/test_internal_utils.py similarity index 88% rename from test/test_prototype_utils.py rename to test/test_internal_utils.py index 712debb607a..f5f8a040db9 100644 --- a/test/test_prototype_utils.py +++ b/test/test_internal_utils.py @@ -1,5 +1,5 @@ import pytest -from torchvision.prototype.utils._internal import sequence_to_str +from torchvision._utils import sequence_to_str @pytest.mark.parametrize( diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 673158b00cd..f7c40d432a4 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -10,8 +10,8 @@ from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.graph import traverse from torchdata.datapipes.iter import IterDataPipe, Shuffler +from torchvision._utils import sequence_to_str from torchvision.prototype import transforms, datasets -from torchvision.prototype.utils._internal import sequence_to_str assert_samples_equal = functools.partial( diff --git a/torchvision/_utils.py b/torchvision/_utils.py index da0eb923f75..8e8fe1b8a83 100644 --- a/torchvision/_utils.py +++ b/torchvision/_utils.py @@ -1,5 +1,5 @@ import enum -from typing import TypeVar, Type +from typing import Sequence, TypeVar, Type T = TypeVar("T", bound=enum.Enum) @@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T: # type: ignore[misc] class StrEnum(enum.Enum, metaclass=StrEnumMeta): pass + + +def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: + if not seq: + return "" + if len(seq) == 1: + return f"'{seq[0]}'" + + head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" + tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" + + return head + tail diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 5ee7c5ccc60..b5c6d7acb60 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -7,7 +7,8 @@ from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection from torch.utils.data import IterDataPipe -from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str +from torchvision._utils import sequence_to_str +from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion from .._home import use_sharded_dataset from ._internal import BUILTIN_DIR, _make_sharded_datapipe diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 864bff9ce02..147a7f0ff4c 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -28,9 +28,10 @@ import numpy as np import torch +from torchvision._utils import sequence_to_str + __all__ = [ - "sequence_to_str", "add_suggestion", "FrozenMapping", "make_repr", @@ -43,18 +44,6 @@ ] -def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: - if not seq: - return "" - if len(seq) == 1: - return f"'{seq[0]}'" - - head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" - tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" - - return head + tail - - def add_suggestion( msg: str, *,