Commit b322765

Merge branch 'main' into vit_h_14
2 parents cd46955 + abc6c77

File tree

4 files changed: +31, -23 lines

test/common_utils.py

Lines changed: 3 additions & 18 deletions
@@ -4,29 +4,18 @@
 import random
 import shutil
 import tempfile
-from distutils.util import strtobool
 
 import numpy as np
-import pytest
 import torch
 from PIL import Image
 from torchvision import io
 
 import __main__  # noqa: 401
 
 
-def get_bool_env_var(name, *, exist_ok=False, default=False):
-    value = os.getenv(name)
-    if value is None:
-        return default
-    if exist_ok:
-        return True
-    return bool(strtobool(value))
-
-
-IN_CIRCLE_CI = get_bool_env_var("CIRCLECI")
-IN_RE_WORKER = get_bool_env_var("INSIDE_RE_WORKER", exist_ok=True)
-IN_FBCODE = get_bool_env_var("IN_FBCODE_TORCHVISION")
+IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true"
+IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
+IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
 CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda."
 

@@ -213,7 +202,3 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
     # scriptable function test
     s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
     torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
-
-
-def run_on_env_var(name, *, skip_reason=None, exist_ok=False, default=False):
-    return pytest.mark.skipif(not get_bool_env_var(name, exist_ok=exist_ok, default=default), reason=skip_reason)
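Note that the inlined checks are stricter about values than the old strtobool-based helper, which also accepted variants like "1", "yes", and "True". A minimal standalone sketch of the new semantics:

import os

# CIRCLECI counts as CI only when set to exactly the string "true".
os.environ["CIRCLECI"] = "true"
assert os.getenv("CIRCLECI", False) == "true"

# INSIDE_RE_WORKER only needs to exist; its value is ignored
# (this mirrors the old exist_ok=True behavior).
os.environ["INSIDE_RE_WORKER"] = ""
assert os.environ.get("INSIDE_RE_WORKER") is not None

# IN_FBCODE_TORCHVISION must be set to exactly "1".
os.environ["IN_FBCODE_TORCHVISION"] = "1"
assert os.environ.get("IN_FBCODE_TORCHVISION") == "1"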

test/test_prototype_models.py

Lines changed: 5 additions & 4 deletions
@@ -1,16 +1,17 @@
 import importlib
+import os
 
 import pytest
 import test_models as TM
 import torch
-from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda
+from common_utils import cpu_and_gpu, needs_cuda
 from torchvision.prototype import models
 from torchvision.prototype.models._api import WeightsEnum, Weights
 from torchvision.prototype.models._utils import handle_legacy_interface
 
-run_if_test_with_prototype = run_on_env_var(
-    "PYTORCH_TEST_WITH_PROTOTYPE",
-    skip_reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
+run_if_test_with_prototype = pytest.mark.skipif(
+    os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
+    reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
 )
 
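For illustration, applying the inlined marker looks like this; the test function below is a hypothetical example, not part of the change:

import os

import pytest

run_if_test_with_prototype = pytest.mark.skipif(
    os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
    reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
)


@run_if_test_with_prototype
def test_prototype_example():  # hypothetical test, shown only for usage
    assert True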

test/test_prototype_utils.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import pytest
+from torchvision.prototype.utils._internal import sequence_to_str
+
+
+@pytest.mark.parametrize(
+    ("seq", "separate_last", "expected"),
+    [
+        ([], "", ""),
+        (["foo"], "", "'foo'"),
+        (["foo", "bar"], "", "'foo', 'bar'"),
+        (["foo", "bar"], "and ", "'foo' and 'bar'"),
+        (["foo", "bar", "baz"], "", "'foo', 'bar', 'baz'"),
+        (["foo", "bar", "baz"], "and ", "'foo', 'bar', and 'baz'"),
+    ],
+)
+def test_sequence_to_str(seq, separate_last, expected):
+    assert sequence_to_str(seq, separate_last=separate_last) == expected
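Unrolled, the parametrization amounts to assertions like the following (a sketch assuming the patched sequence_to_str from the next file is importable):

from torchvision.prototype.utils._internal import sequence_to_str

assert sequence_to_str([]) == ""  # the new empty-sequence case
assert sequence_to_str(["foo"]) == "'foo'"
assert sequence_to_str(["foo", "bar"], separate_last="and ") == "'foo' and 'bar'"
assert sequence_to_str(["foo", "bar", "baz"], separate_last="and ") == "'foo', 'bar', and 'baz'"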

torchvision/prototype/utils/_internal.py

Lines changed: 6 additions & 1 deletion
@@ -30,10 +30,15 @@ class StrEnum(enum.Enum, metaclass=StrEnumMeta):
 
 
 def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
+    if not seq:
+        return ""
     if len(seq) == 1:
         return f"'{seq[0]}'"
 
-    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'"""
+    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
+    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
+
+    return head + tail
 
 
 def add_suggestion(
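The tail expression drops the comma exactly when a separator is given and the sequence has two items, yielding "'foo' and 'bar'" rather than "'foo', and 'bar'", while keeping the Oxford comma for longer sequences. A standalone trace of that branch:

seq, separate_last = ["foo", "bar"], "and "

head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"  # "'foo'"
# separate_last is truthy and len(seq) == 2, so the comma is suppressed:
tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"  # " and 'bar'"

assert head + tail == "'foo' and 'bar'"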
