Skip to content

Commit 29f38f1

Browse files
authored
Download model weights in parallel for prototype CI (#4772)
* enable caching of model weights for prototype CI * syntax * syntax * make cache dir dynamic * increase verbosity * fix * use larger CI machine * revert debug output * [DEBUG] test env var usage in save_cache * retry * use checksum for caching * remove env vars because expansion is not working * syntax * cleanup * base caching on model-urls * relax regex * cleanup skips * cleanup * fix skipping logic * improve step name * benchmark without caching * benchmark with external download * debug * fix manual download location * debug again * download weights in the background * try parallel download * add missing import * use correct decorator * up resource_class * fix wording * enable stdout passthrough to see download during test * remove linebreak * move checkout up * cleanup * debug failing test * temp fix * fix * cleanup * fix regex * remove explicit install of numpy
1 parent cca1699 commit 29f38f1

File tree

5 files changed

+76
-13
lines changed

5 files changed

+76
-13
lines changed

.circleci/config.yml

Lines changed: 12 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/config.yml.in

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,23 @@ jobs:
263263
prototype_test:
264264
docker:
265265
- image: circleci/python:3.7
266+
resource_class: xlarge
266267
steps:
267268
- run:
268269
name: Install torch
269-
command: pip install --user --progress-bar=off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
270+
command: |
271+
pip install --user --progress-bar=off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
270272
- run:
271273
name: Install prototype dependencies
272274
command: pip install --user --progress-bar=off git+https://github.com/pytorch/data.git
273275
- checkout
276+
- run:
277+
name: Download model weights
278+
background: true
279+
command: |
280+
sudo apt update -qy && sudo apt install -qy parallel wget
281+
python scripts/collect_model_urls.py torchvision/prototype/models \
282+
| parallel -j0 wget --no-verbose -P ~/.cache/torch/hub/checkpoints {}
274283
- run:
275284
name: Install torchvision
276285
command: pip install --user --progress-bar off --no-build-isolation .
@@ -279,6 +288,8 @@ jobs:
279288
command: pip install --user --progress-bar=off pytest pytest-mock scipy iopath
280289
- run:
281290
name: Run tests
291+
environment:
292+
PYTORCH_TEST_WITH_PROTOTYPE: 1
282293
command: pytest --junitxml=test-results/junit.xml -v --durations 20 test/test_prototype_*.py
283294
- store_test_results:
284295
path: test-results

scripts/collect_model_urls.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pathlib
import re
import sys

# Checkpoint URLs look like https://download.pytorch.org/models/<name>-<hash>.pth;
# non-greedy ".*?" stops at the first ".pth" so a URL embedded in a longer
# source line is not over-matched.
MODEL_URL_PATTERN = re.compile(r"https://download[.]pytorch[.]org/models/.*?[.]pth")


def main(root):
    """Print every model-weight URL found in Python sources under *root*.

    Scans all ``*.py`` files recursively, skipping private modules (names
    starting with an underscore, e.g. ``__init__.py``), and writes the
    sorted, de-duplicated URLs to stdout, one per line.
    """
    model_urls = set()
    # Glob only Python sources up front instead of globbing everything and
    # filtering by suffix afterwards.
    for path in pathlib.Path(root).glob("**/*.py"):
        if path.name.startswith("_"):
            continue

        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                model_urls.update(MODEL_URL_PATTERN.findall(line))

    print("\n".join(sorted(model_urls)))


if __name__ == "__main__":
    main(sys.argv[1])

test/common_utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,29 @@
44
import random
55
import shutil
66
import tempfile
7+
from distutils.util import strtobool
78

89
import numpy as np
10+
import pytest
911
import torch
1012
from PIL import Image
1113
from torchvision import io
1214

1315
import __main__ # noqa: 401
1416

1517

16-
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true"
17-
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
18-
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
18+
def get_bool_env_var(name, *, exist_ok=False, default=False):
    """Interpret the environment variable *name* as a boolean.

    Args:
        name: name of the environment variable to look up.
        exist_ok: if True, the variable being set at all — to any value,
            even the empty string — counts as True.
        default: value returned when the variable is unset.

    Returns:
        bool: the parsed value.

    Raises:
        ValueError: if the variable is set, ``exist_ok`` is False, and the
            value is not a recognized boolean spelling.
    """
    value = os.getenv(name)
    if value is None:
        return default
    if exist_ok:
        return True
    # Same spellings as distutils.util.strtobool, inlined because distutils
    # is deprecated (PEP 632) and removed in Python 3.12.
    normalized = value.strip().lower()
    if normalized in {"y", "yes", "t", "true", "on", "1"}:
        return True
    if normalized in {"n", "no", "f", "false", "off", "0"}:
        return False
    raise ValueError(f"invalid truth value {value!r}")
25+
26+
27+
IN_CIRCLE_CI = get_bool_env_var("CIRCLECI")
28+
IN_RE_WORKER = get_bool_env_var("INSIDE_RE_WORKER", exist_ok=True)
29+
IN_FBCODE = get_bool_env_var("IN_FBCODE_TORCHVISION")
1930
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
2031
CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda."
2132

@@ -202,3 +213,7 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
202213
# scriptable function test
203214
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
204215
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
216+
217+
218+
def run_on_env_var(name, *, skip_reason=None, exist_ok=False, default=False):
    """Build a pytest marker gated on the boolean environment variable *name*.

    The returned ``skipif`` marker skips the decorated test (reporting
    *skip_reason*) unless ``get_bool_env_var(name, ...)`` evaluates truthy;
    ``exist_ok`` and ``default`` are forwarded unchanged.
    """
    enabled = get_bool_env_var(name, exist_ok=exist_ok, default=default)
    return pytest.mark.skipif(not enabled, reason=skip_reason)

test/test_prototype_models.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import importlib
2-
import os
32

43
import pytest
54
import test_models as TM
65
import torch
7-
from common_utils import cpu_and_gpu
6+
from common_utils import cpu_and_gpu, run_on_env_var
87
from torchvision.prototype import models
98

9+
# Shared marker: each prototype test below is skipped unless the
# PYTORCH_TEST_WITH_PROTOTYPE environment variable is set truthy.
run_if_test_with_prototype = run_on_env_var(
    "PYTORCH_TEST_WITH_PROTOTYPE",
    skip_reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
)
13+
1014

1115
def _get_original_model(model_fn):
1216
original_module_name = model_fn.__module__.replace(".prototype", "")
@@ -48,34 +52,34 @@ def test_get_weight(model_fn, name, weight):
4852

4953
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_classification_model(model_fn, dev):
    # Thin wrapper: runs the stable test_models assertions against each
    # prototype classification builder, parametrized over CPU/GPU devices.
    TM.test_classification_model(model_fn, dev)
5458

5559

5660
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_detection_model(model_fn, dev):
    # Delegates to the stable-API test suite so prototype detection builders
    # are checked with the exact same assertions as torchvision.models.
    TM.test_detection_model(model_fn, dev)
6165

6266

6367
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
@run_if_test_with_prototype
def test_quantized_classification_model(model_fn):
    # Delegates to the stable-API quantized test; note there is no device
    # parametrization here, mirroring the delegated test's signature.
    TM.test_quantized_classification_model(model_fn)
6771

6872

6973
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_segmentation_model(model_fn, dev):
    # Delegates to the stable-API segmentation test for each prototype
    # builder, parametrized over CPU/GPU devices.
    TM.test_segmentation_model(model_fn, dev)
7478

7579

7680
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.video))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
def test_video_model(model_fn, dev):
    # Delegates to the stable-API video-model test for each prototype
    # builder, parametrized over CPU/GPU devices.
    TM.test_video_model(model_fn, dev)
8185

@@ -89,7 +93,7 @@ def test_video_model(model_fn, dev):
8993
+ get_models_with_module_names(models.video),
9094
)
9195
@pytest.mark.parametrize("dev", cpu_and_gpu())
92-
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
96+
@run_if_test_with_prototype
9397
def test_old_vs_new_factory(model_fn, module_name, dev):
9498
defaults = {
9599
"models": {

0 commit comments

Comments
 (0)