Skip to content

Commit 0985533

Browse files
slipnitskaya, datumbox, and pmeier
authored
Make download_url() follow redirects (#3235) (#3236)
* Make download_url() follow redirects Fix bug related to the incorrect processing of redirects. Follow the redirect chain until the destination is reached or the number of redirects exceeds the max allowed value (by default 10). * Parametrize value of max allowed redirect number Make max number of hops a function argument and assign its default value to 10 * Propagate the max number of hops to download_url() Add the maximum number of redirect hops parameter to download_url() * check file existence before redirect * remove print * remove recursion * add tests * Reducing max_redirect_hops Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 3b19d6f commit 0985533

File tree

2 files changed

+52
-18
lines changed

2 files changed

+52
-18
lines changed

test/test_datasets_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def test_check_integrity(self):
3636
self.assertTrue(utils.check_integrity(existing_fpath))
3737
self.assertFalse(utils.check_integrity(nonexisting_fpath))
3838

39+
def test_get_redirect_url(self):
40+
url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
41+
expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
42+
43+
actual = utils._get_redirect_url(url)
44+
assert actual == expected
45+
46+
def test_get_redirect_url_max_hops_exceeded(self):
47+
url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
48+
with self.assertRaises(RecursionError):
49+
utils._get_redirect_url(url, max_hops=0)
50+
3951
def test_download_url(self):
4052
with get_tmp_dir() as temp_dir:
4153
url = "http://github.com/pytorch/vision/archive/master.zip"

torchvision/datasets/utils.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,31 @@ def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
4242
return check_md5(fpath, md5)
4343

4444

45-
def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
45+
def _get_redirect_url(url: str, max_hops: int = 10) -> str:
    """Resolve the final URL of a (possibly redirected) resource.

    Follows the redirect chain of ``url`` until a stable destination is
    reached or the number of hops exceeds ``max_hops``.

    Args:
        url (str): URL to resolve.
        max_hops (int, optional): Maximum number of redirect hops allowed.

    Returns:
        str: The final URL after all redirects have been followed.

    Raises:
        RecursionError: If the redirect chain is longer than ``max_hops``.
    """
    import requests

    for _ in range(max_hops + 1):
        # stream=True fetches only the headers here; close() releases the
        # connection without downloading the payload (the actual download
        # happens later in download_url()).
        response = requests.get(url, stream=True)
        response.close()

        if response.url == url or response.url is None:
            return url

        url = response.url
    # Loop exhausted without converging on a stable URL.
    raise RecursionError(f"Too many redirects: {max_hops + 1}")
57+
58+
59+
def download_url(
60+
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
61+
) -> None:
4662
"""Download a file from a url and place it in root.
4763
4864
Args:
4965
url (str): URL to download file from
5066
root (str): Directory to place downloaded file in
5167
filename (str, optional): Name to save the file under. If None, use the basename of the URL
5268
md5 (str, optional): MD5 checksum of the download. If None, do not check
69+
max_redirect_hops (int, optional): Maximum number of redirect hops allowed
5370
"""
5471
import urllib
5572

@@ -63,27 +80,32 @@ def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optio
6380
# check if file is already present locally
6481
if check_integrity(fpath, md5):
6582
print('Using downloaded and verified file: ' + fpath)
66-
else: # download the file
67-
try:
68-
print('Downloading ' + url + ' to ' + fpath)
83+
return
84+
85+
# expand redirect chain if needed
86+
url = _get_redirect_url(url, max_hops=max_redirect_hops)
87+
88+
# download the file
89+
try:
90+
print('Downloading ' + url + ' to ' + fpath)
91+
urllib.request.urlretrieve(
92+
url, fpath,
93+
reporthook=gen_bar_updater()
94+
)
95+
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
96+
if url[:5] == 'https':
97+
url = url.replace('https:', 'http:')
98+
print('Failed download. Trying https -> http instead.'
99+
' Downloading ' + url + ' to ' + fpath)
69100
urllib.request.urlretrieve(
70101
url, fpath,
71102
reporthook=gen_bar_updater()
72103
)
73-
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
74-
if url[:5] == 'https':
75-
url = url.replace('https:', 'http:')
76-
print('Failed download. Trying https -> http instead.'
77-
' Downloading ' + url + ' to ' + fpath)
78-
urllib.request.urlretrieve(
79-
url, fpath,
80-
reporthook=gen_bar_updater()
81-
)
82-
else:
83-
raise e
84-
# check integrity of downloaded file
85-
if not check_integrity(fpath, md5):
86-
raise RuntimeError("File not found or corrupted.")
104+
else:
105+
raise e
106+
# check integrity of downloaded file
107+
if not check_integrity(fpath, md5):
108+
raise RuntimeError("File not found or corrupted.")
87109

88110

89111
def list_dir(root: str, prefix: bool = False) -> List[str]:

0 commit comments

Comments (0)