@@ -1,13 +1,14 @@
 import contextlib
 import itertools
 import time
-import unittest
 import unittest.mock
 from datetime import datetime
 from os import path
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request
 
+import pytest
+
 from torchvision import datasets
 from torchvision.datasets.utils import download_url, check_integrity
 
@@ -43,89 +44,94 @@ def inner_wrapper(request, *args, **kwargs):
 urlopen = limit_requests_per_time()(urlopen)
 
 
-class DownloadTester(unittest.TestCase):
-    @staticmethod
-    @contextlib.contextmanager
-    def log_download_attempts(patch=True):
-        urls_and_md5s = set()
-        with unittest.mock.patch(
-            "torchvision.datasets.utils.download_url", wraps=None if patch else download_url
-        ) as mock:
-            try:
-                yield urls_and_md5s
-            finally:
-                for args, kwargs in mock.call_args_list:
-                    url = args[0]
-                    md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
-                    urls_and_md5s.add((url, md5))
-
-    @staticmethod
-    def retry(fn, times=1, wait=5.0):
-        msgs = []
-        for _ in range(times + 1):
-            try:
-                return fn()
-            except AssertionError as error:
-                msgs.append(str(error))
-                time.sleep(wait)
-        else:
-            raise AssertionError(
-                "\n".join(
-                    (
-                        f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
-                        *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
-                    )
+@contextlib.contextmanager
+def log_download_attempts(patch=True):
+    urls_and_md5s = set()
+    with unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) as mock:
+        try:
+            yield urls_and_md5s
+        finally:
+            for args, kwargs in mock.call_args_list:
+                url = args[0]
+                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))
+
+
+def retry(fn, times=1, wait=5.0):
+    msgs = []
+    for _ in range(times + 1):
+        try:
+            return fn()
+        except AssertionError as error:
+            msgs.append(str(error))
+            time.sleep(wait)
+    else:
+        raise AssertionError(
+            "\n".join(
+                (
+                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
+                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                 )
             )
-
-    @staticmethod
-    def assert_response_ok(response, url=None, ok=200):
-        msg = f"The server returned status code {response.code}"
-        if url is not None:
-            msg += f"for the the URL {url}"
-        assert response.code == ok, msg
-
-    @staticmethod
-    def assert_is_downloadable(url):
-        request = Request(url, headers=dict(method="HEAD"))
-        response = urlopen(request)
-        DownloadTester.assert_response_ok(response, url)
-
-    @staticmethod
-    def assert_downloads_correctly(url, md5):
-        with get_tmp_dir() as root:
-            file = path.join(root, path.basename(url))
-            with urlopen(url) as response, open(file, "wb") as fh:
-                DownloadTester.assert_response_ok(response, url)
-                fh.write(response.read())
-
-            assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
-
-    def test_download(self):
-        assert_fn = (
-            lambda url, _: self.assert_is_downloadable(url)
-            if self.only_test_downloadability
-            else self.assert_downloads_correctly
         )
-        for url, md5 in self.collect_urls_and_md5s():
-            with self.subTest(url=url, md5=md5):
-                self.retry(lambda: assert_fn(url, md5))
 
-    def collect_urls_and_md5s(self):
-        raise NotImplementedError
 
-    @property
-    def only_test_downloadability(self):
-        return True
+def assert_server_response_ok(response, url=None):
+    msg = f"The server returned status code {response.code}"
+    if url is not None:
+        msg += f"for the the URL {url}"
+    assert 200 <= response.code < 300, msg
+
+
+def assert_url_is_accessible(url):
+    request = Request(url, headers=dict(method="HEAD"))
+    response = urlopen(request)
+    assert_server_response_ok(response, url)
+
+
+def assert_file_downloads_correctly(url, md5):
+    with get_tmp_dir() as root:
+        file = path.join(root, path.basename(url))
+        with urlopen(url) as response, open(file, "wb") as fh:
+            assert_server_response_ok(response, url)
+            fh.write(response.read())
+
+        assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
+
+
+class DownloadConfig:
+    def __init__(self, url, md5=None, id=None):
+        self.url = url
+        self.md5 = md5
+        self.id = id or url
+
+
+def make_parametrize_kwargs(download_configs):
+    argvalues = []
+    ids = []
+    for config in download_configs:
+        argvalues.append((config.url, config.md5))
+        ids.append(config.id)
+
+    return dict(argnames="url, md5", argvalues=argvalues, ids=ids)
+
+
+def places365():
+    with log_download_attempts(patch=False) as urls_and_md5s:
+        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
+            with places365_root(split=split, small=small) as places365:
+                root, data = places365
+
+                datasets.Places365(root, split=split, small=small, download=True)
+
+    return [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s]
 
 
-class Places365Tester(DownloadTester):
-    def collect_urls_and_md5s(self):
-        with self.log_download_attempts(patch=False) as urls_and_md5s:
-            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
-                with places365_root(split=split, small=small) as places365:
-                    root, data = places365
+@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(),)))
+def test_url_is_accessible(url, md5):
+    retry(lambda: assert_url_is_accessible(url))
 
-                    datasets.Places365(root, split=split, small=small, download=True)
 
-                return urls_and_md5s
+@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
+def test_file_downloads_correctly(url, md5):
+    retry(lambda: assert_file_downloads_correctly(url, md5))
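
For readers skimming the diff: `make_parametrize_kwargs` simply repackages a list of `DownloadConfig` objects into the keyword arguments that `pytest.mark.parametrize` expects. A minimal sketch of that mapping follows; the URLs and MD5 value are placeholders for illustration, not real dataset files.

# Illustration only -- made-up URLs/checksums, not part of the diff.
configs = [
    DownloadConfig("http://example.com/a.tar.gz", md5="d41d8cd98f00b204e9800998ecf8427e", id="Example, a"),
    DownloadConfig("http://example.com/b.tar.gz"),
]

kwargs = make_parametrize_kwargs(configs)
# kwargs == dict(
#     argnames="url, md5",
#     argvalues=[
#         ("http://example.com/a.tar.gz", "d41d8cd98f00b204e9800998ecf8427e"),
#         ("http://example.com/b.tar.gz", None),
#     ],
#     ids=["Example, a", "http://example.com/b.tar.gz"],
# )

pytest then runs each parametrized test once per tuple in `argvalues`, labeling the run with the matching entry from `ids`; a config without an explicit `id` falls back to its URL, and a missing `md5` becomes `None`.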
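
The `itertools.chain(places365(),)` argument is the extension point for `test_url_is_accessible`, while the empty `itertools.chain()` leaves `test_file_downloads_correctly` unparametrized for now. Below is a hedged sketch of how a further dataset could be chained in later; the `caltech101` collector and the `contextlib.suppress` guard are assumptions made for illustration, not part of this change.

# Hypothetical collector mirroring the places365() pattern above.
def caltech101():
    with log_download_attempts() as urls_and_md5s:
        # With the default patch=True, download_url is fully mocked, so nothing is
        # actually downloaded; the constructor may then fail because the files are
        # missing, hence the blanket suppress -- the URLs are recorded either way.
        with contextlib.suppress(Exception):
            datasets.Caltech101(".", download=True)

    return [DownloadConfig(url, md5=md5, id=f"Caltech101, {url}") for url, md5 in urls_and_md5s]


@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101())))
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))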