Skip to content

Commit d7490d1

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Moving sequence_to_str to torchvision._utils (#5604)
Summary: * Moving `sequence_to_str` to `torchvision._utils` * Fix linter * Rename test_prototype_utils test to test_internal_utils Reviewed By: vmoens Differential Revision: D34878983 fbshipit-source-id: 02a5425a9d25056035d28307cf307e0c5353c94f
1 parent 7baeda8 commit d7490d1

File tree

6 files changed

+20
-18
lines changed

6 files changed

+20
-18
lines changed

test/builtin_dataset_mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
2020
from torch.nn.functional import one_hot
2121
from torch.testing import make_tensor as _make_tensor
22+
from torchvision._utils import sequence_to_str
2223
from torchvision.prototype.datasets._api import find
23-
from torchvision.prototype.utils._internal import sequence_to_str
2424

2525
make_tensor = functools.partial(_make_tensor, device="cpu")
2626
make_scalar = functools.partial(make_tensor, ())

test/test_prototype_utils.py renamed to test/test_internal_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from torchvision.prototype.utils._internal import sequence_to_str
2+
from torchvision._utils import sequence_to_str
33

44

55
@pytest.mark.parametrize(

test/test_prototype_builtin_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
1111
from torch.utils.data.graph import traverse
1212
from torchdata.datapipes.iter import IterDataPipe, Shuffler
13+
from torchvision._utils import sequence_to_str
1314
from torchvision.prototype import transforms, datasets
14-
from torchvision.prototype.utils._internal import sequence_to_str
1515

1616

1717
assert_samples_equal = functools.partial(

torchvision/_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import enum
2-
from typing import TypeVar, Type
2+
from typing import Sequence, TypeVar, Type
33

44
T = TypeVar("T", bound=enum.Enum)
55

@@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T: # type: ignore[misc]
1818

1919
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
2020
pass
21+
22+
23+
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
24+
if not seq:
25+
return ""
26+
if len(seq) == 1:
27+
return f"'{seq[0]}'"
28+
29+
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
30+
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
31+
32+
return head + tail

torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection
88

99
from torch.utils.data import IterDataPipe
10-
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str
10+
from torchvision._utils import sequence_to_str
11+
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion
1112

1213
from .._home import use_sharded_dataset
1314
from ._internal import BUILTIN_DIR, _make_sharded_datapipe

torchvision/prototype/utils/_internal.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828

2929
import numpy as np
3030
import torch
31+
from torchvision._utils import sequence_to_str
32+
3133

3234
__all__ = [
33-
"sequence_to_str",
3435
"add_suggestion",
3536
"FrozenMapping",
3637
"make_repr",
@@ -43,18 +44,6 @@
4344
]
4445

4546

46-
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
47-
if not seq:
48-
return ""
49-
if len(seq) == 1:
50-
return f"'{seq[0]}'"
51-
52-
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
53-
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
54-
55-
return head + tail
56-
57-
5847
def add_suggestion(
5948
msg: str,
6049
*,

0 commit comments

Comments
 (0)