Skip to content

Commit abc6c77

Browse files
authored
fix and add test for sequence_to_str (#5213)
* fix and add test for sequence_to_str * remove manual ids
1 parent afdf126 commit abc6c77

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

test/test_prototype_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
from torchvision.prototype.utils._internal import sequence_to_str
3+
4+
5+
@pytest.mark.parametrize(
6+
("seq", "separate_last", "expected"),
7+
[
8+
([], "", ""),
9+
(["foo"], "", "'foo'"),
10+
(["foo", "bar"], "", "'foo', 'bar'"),
11+
(["foo", "bar"], "and ", "'foo' and 'bar'"),
12+
(["foo", "bar", "baz"], "", "'foo', 'bar', 'baz'"),
13+
(["foo", "bar", "baz"], "and ", "'foo', 'bar', and 'baz'"),
14+
],
15+
)
16+
def test_sequence_to_str(seq, separate_last, expected):
17+
assert sequence_to_str(seq, separate_last=separate_last) == expected

torchvision/prototype/utils/_internal.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,15 @@ class StrEnum(enum.Enum, metaclass=StrEnumMeta):
3030

3131

3232
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
33+
if not seq:
34+
return ""
3335
if len(seq) == 1:
3436
return f"'{seq[0]}'"
3537

36-
return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'"""
38+
head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
39+
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
40+
41+
return head + tail
3742

3843

3944
def add_suggestion(

0 commit comments

Comments
 (0)