Skip to content

Commit 898802f

Browse files
pmeierfmassa
andauthored
fix scheduled download tests (#2706)
* fix triggers for scheduled workflow * more fix * add missing repository checkout * try fix label in template * rewrite test infrastructure * trigger issue generation * try fix issue template * try remove quotes * remove buggy label * try fix title * cleanup * add more test details * reenable issue creation Co-authored-by: Francisco Massa <[email protected]>
1 parent 9e7a4b1 commit 898802f

File tree

3 files changed

+98
-86
lines changed

3 files changed

+98
-86
lines changed

.github/failed_schedule_issue_template.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
---
2-
title: Scheduled workflow {{ env.WORKFLOW }}/{{ env.JOB }} failed
3-
labels: bug, module: datasets
2+
title: Scheduled workflow failed
3+
labels:
4+
- bug
5+
- "module: datasets"
46
---
57

68
Oh no, something went wrong in the scheduled workflow {{ env.WORKFLOW }}/{{ env.JOB }}.

.github/workflows/tests-schedule.yml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ name: tests
22

33
on:
44
pull_request:
5-
- "test/test_datasets_download.py"
6-
- ".github/failed_schedule_issue_template.md"
7-
- ".github/workflows/tests-schedule.yml"
5+
paths:
6+
- "test/test_datasets_download.py"
7+
- ".github/failed_schedule_issue_template.md"
8+
- ".github/workflows/tests-schedule.yml"
89

910
schedule:
1011
- cron: "0 9 * * *"
@@ -22,20 +23,23 @@ jobs:
2223
- name: Upgrade pip
2324
run: python -m pip install --upgrade pip
2425

26+
- name: Checkout repository
27+
uses: actions/checkout@v2
28+
2529
- name: Install PyTorch from the nightlies
2630
run: |
2731
pip install numpy
2832
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
2933
3034
- name: Install tests requirements
31-
run: pip install pytest pytest-subtests
35+
run: pip install pytest
3236

3337
- name: Run tests
34-
run: pytest test/test_datasets_download.py
38+
run: pytest --durations=20 -ra test/test_datasets_download.py
3539

3640
- uses: JasonEtco/[email protected]
3741
name: Create issue if download tests failed
38-
if: failure()
42+
if: failure() && github.event_name == 'schedule'
3943
env:
4044
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
4145
REPO: ${{ github.repository }}

test/test_datasets_download.py

Lines changed: 84 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import contextlib
22
import itertools
33
import time
4-
import unittest
54
import unittest.mock
65
from datetime import datetime
76
from os import path
87
from urllib.parse import urlparse
98
from urllib.request import urlopen, Request
109

10+
import pytest
11+
1112
from torchvision import datasets
1213
from torchvision.datasets.utils import download_url, check_integrity
1314

@@ -43,89 +44,94 @@ def inner_wrapper(request, *args, **kwargs):
4344
urlopen = limit_requests_per_time()(urlopen)
4445

4546

46-
class DownloadTester(unittest.TestCase):
47-
@staticmethod
48-
@contextlib.contextmanager
49-
def log_download_attempts(patch=True):
50-
urls_and_md5s = set()
51-
with unittest.mock.patch(
52-
"torchvision.datasets.utils.download_url", wraps=None if patch else download_url
53-
) as mock:
54-
try:
55-
yield urls_and_md5s
56-
finally:
57-
for args, kwargs in mock.call_args_list:
58-
url = args[0]
59-
md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
60-
urls_and_md5s.add((url, md5))
61-
62-
@staticmethod
63-
def retry(fn, times=1, wait=5.0):
64-
msgs = []
65-
for _ in range(times + 1):
66-
try:
67-
return fn()
68-
except AssertionError as error:
69-
msgs.append(str(error))
70-
time.sleep(wait)
71-
else:
72-
raise AssertionError(
73-
"\n".join(
74-
(
75-
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
76-
*(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
77-
)
47+
@contextlib.contextmanager
48+
def log_download_attempts(patch=True):
49+
urls_and_md5s = set()
50+
with unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) as mock:
51+
try:
52+
yield urls_and_md5s
53+
finally:
54+
for args, kwargs in mock.call_args_list:
55+
url = args[0]
56+
md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
57+
urls_and_md5s.add((url, md5))
58+
59+
60+
def retry(fn, times=1, wait=5.0):
61+
msgs = []
62+
for _ in range(times + 1):
63+
try:
64+
return fn()
65+
except AssertionError as error:
66+
msgs.append(str(error))
67+
time.sleep(wait)
68+
else:
69+
raise AssertionError(
70+
"\n".join(
71+
(
72+
f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
73+
*(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
7874
)
7975
)
80-
81-
@staticmethod
82-
def assert_response_ok(response, url=None, ok=200):
83-
msg = f"The server returned status code {response.code}"
84-
if url is not None:
85-
msg += f"for the the URL {url}"
86-
assert response.code == ok, msg
87-
88-
@staticmethod
89-
def assert_is_downloadable(url):
90-
request = Request(url, headers=dict(method="HEAD"))
91-
response = urlopen(request)
92-
DownloadTester.assert_response_ok(response, url)
93-
94-
@staticmethod
95-
def assert_downloads_correctly(url, md5):
96-
with get_tmp_dir() as root:
97-
file = path.join(root, path.basename(url))
98-
with urlopen(url) as response, open(file, "wb") as fh:
99-
DownloadTester.assert_response_ok(response, url)
100-
fh.write(response.read())
101-
102-
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
103-
104-
def test_download(self):
105-
assert_fn = (
106-
lambda url, _: self.assert_is_downloadable(url)
107-
if self.only_test_downloadability
108-
else self.assert_downloads_correctly
10976
)
110-
for url, md5 in self.collect_urls_and_md5s():
111-
with self.subTest(url=url, md5=md5):
112-
self.retry(lambda: assert_fn(url, md5))
11377

114-
def collect_urls_and_md5s(self):
115-
raise NotImplementedError
11678

117-
@property
118-
def only_test_downloadability(self):
119-
return True
79+
def assert_server_response_ok(response, url=None):
80+
msg = f"The server returned status code {response.code}"
81+
if url is not None:
82+
msg += f"for the the URL {url}"
83+
assert 200 <= response.code < 300, msg
84+
85+
86+
def assert_url_is_accessible(url):
87+
request = Request(url, headers=dict(method="HEAD"))
88+
response = urlopen(request)
89+
assert_server_response_ok(response, url)
90+
91+
92+
def assert_file_downloads_correctly(url, md5):
93+
with get_tmp_dir() as root:
94+
file = path.join(root, path.basename(url))
95+
with urlopen(url) as response, open(file, "wb") as fh:
96+
assert_server_response_ok(response, url)
97+
fh.write(response.read())
98+
99+
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
100+
101+
102+
class DownloadConfig:
103+
def __init__(self, url, md5=None, id=None):
104+
self.url = url
105+
self.md5 = md5
106+
self.id = id or url
107+
108+
109+
def make_parametrize_kwargs(download_configs):
110+
argvalues = []
111+
ids = []
112+
for config in download_configs:
113+
argvalues.append((config.url, config.md5))
114+
ids.append(config.id)
115+
116+
return dict(argnames="url, md5", argvalues=argvalues, ids=ids)
117+
118+
119+
def places365():
120+
with log_download_attempts(patch=False) as urls_and_md5s:
121+
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
122+
with places365_root(split=split, small=small) as places365:
123+
root, data = places365
124+
125+
datasets.Places365(root, split=split, small=small, download=True)
126+
127+
return [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s]
120128

121129

122-
class Places365Tester(DownloadTester):
123-
def collect_urls_and_md5s(self):
124-
with self.log_download_attempts(patch=False) as urls_and_md5s:
125-
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
126-
with places365_root(split=split, small=small) as places365:
127-
root, data = places365
130+
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(),)))
131+
def test_url_is_accessible(url, md5):
132+
retry(lambda: assert_url_is_accessible(url))
128133

129-
datasets.Places365(root, split=split, small=small, download=True)
130134

131-
return urls_and_md5s
135+
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
136+
def test_file_downloads_correctly(url, md5):
137+
retry(lambda: assert_file_downloads_correctly(url, md5))

0 commit comments

Comments
 (0)