Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
from torchvision._utils import sequence_to_str
from torchvision.prototype.datasets._api import find
from torchvision.prototype.utils._internal import sequence_to_str

make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from torchvision.prototype.utils._internal import sequence_to_str
from torchvision._utils import sequence_to_str


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets
from torchvision.prototype.utils._internal import sequence_to_str


assert_samples_equal = functools.partial(
Expand Down
14 changes: 13 additions & 1 deletion torchvision/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import enum
from typing import TypeVar, Type
from typing import Sequence, TypeVar, Type

T = TypeVar("T", bound=enum.Enum)

Expand All @@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T: # type: ignore[misc]

class StrEnum(enum.Enum, metaclass=StrEnumMeta):
pass


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1:
return f"'{seq[0]}'"

head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"

return head + tail
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/utils/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection

from torch.utils.data import IterDataPipe
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str
from torchvision._utils import sequence_to_str
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion

from .._home import use_sharded_dataset
from ._internal import BUILTIN_DIR, _make_sharded_datapipe
Expand Down
15 changes: 2 additions & 13 deletions torchvision/prototype/utils/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@

import numpy as np
import torch
from torchvision._utils import sequence_to_str


__all__ = [
"sequence_to_str",
"add_suggestion",
"FrozenMapping",
"make_repr",
Expand All @@ -43,18 +44,6 @@
]


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
if len(seq) == 1:
return f"'{seq[0]}'"

head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"

return head + tail


def add_suggestion(
msg: str,
*,
Expand Down