Make download_url() follow redirects (#3235) #3236

Merged (11 commits) on Jan 15, 2021
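
The diff below adds a private helper, _get_redirect_url(), and threads a new max_redirect_hops parameter through download_url(). As a quick orientation, here is a minimal sketch of calling the helper directly; the URL and the expected final location are taken from the new test, and actually running it assumes network access and that the Caltech server still issues the same redirect:

from torchvision.datasets import utils

final_url = utils._get_redirect_url(
    "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
    max_hops=3,  # the default download_url() uses after this PR
)
print(final_url)
# https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view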
test/test_datasets_utils.py: 12 additions, 0 deletions
@@ -36,6 +36,18 @@ def test_check_integrity(self):
         self.assertTrue(utils.check_integrity(existing_fpath))
         self.assertFalse(utils.check_integrity(nonexisting_fpath))

+    def test_get_redirect_url(self):
+        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
+        expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
+
+        actual = utils._get_redirect_url(url)
+        assert actual == expected
+
+    def test_get_redirect_url_max_hops_exceeded(self):
+        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
+        with self.assertRaises(RecursionError):
+            utils._get_redirect_url(url, max_hops=0)
+
     def test_download_url(self):
         with get_tmp_dir() as temp_dir:
             url = "http://github.com/pytorch/vision/archive/master.zip"
torchvision/datasets/utils.py: 40 additions, 18 deletions
@@ -42,14 +42,31 @@ def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
     return check_md5(fpath, md5)


-def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
+def _get_redirect_url(url: str, max_hops: int = 10) -> str:
+    import requests
+
+    for hop in range(max_hops + 1):
+        response = requests.get(url)
+
+        if response.url == url or response.url is None:
+            return url
+
+        url = response.url
+    else:
+        raise RecursionError(f"Too many redirects: {max_hops + 1}")
+
+
+def download_url(
+    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
+) -> None:
     """Download a file from a url and place it in root.

     Args:
         url (str): URL to download file from
         root (str): Directory to place downloaded file in
         filename (str, optional): Name to save the file under. If None, use the basename of the URL
         md5 (str, optional): MD5 checksum of the download. If None, do not check
+        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
     """
     import urllib
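
The helper added above leans on requests: each requests.get() call follows the whole redirect chain itself and reports the landing page as response.url, so an iteration usually already sits at the end of the chain and the loop only detects when the URL stops changing. For readers who want the hop-by-hop idea spelled out, here is a minimal alternative sketch (not the code this PR merges); resolve_redirects is a hypothetical name, and it assumes the server answers HEAD requests and sets a Location header on redirects:

from urllib.parse import urljoin

import requests


def resolve_redirects(url: str, max_hops: int = 3) -> str:
    """Follow HTTP redirects one hop at a time and return the final URL."""
    for _ in range(max_hops + 1):
        # allow_redirects=False hands back each 3xx response instead of
        # letting requests chase the chain internally, so max_hops really
        # bounds the number of HTTP round trips.
        response = requests.head(url, allow_redirects=False)
        if not response.is_redirect:
            return url
        # Location may be relative, so resolve it against the current URL.
        url = urljoin(url, response.headers["Location"])
    raise RecursionError(f"Too many redirects: {max_hops + 1}")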

@@ -63,27 +80,32 @@ def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
     # check if file is already present locally
     if check_integrity(fpath, md5):
         print('Using downloaded and verified file: ' + fpath)
-    else:  # download the file
-        try:
-            print('Downloading ' + url + ' to ' + fpath)
+        return
+
+    # expand redirect chain if needed
+    url = _get_redirect_url(url, max_hops=max_redirect_hops)
+
+    # download the file
+    try:
+        print('Downloading ' + url + ' to ' + fpath)
+        urllib.request.urlretrieve(
+            url, fpath,
+            reporthook=gen_bar_updater()
+        )
+    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
+        if url[:5] == 'https':
+            url = url.replace('https:', 'http:')
+            print('Failed download. Trying https -> http instead.'
+                  ' Downloading ' + url + ' to ' + fpath)
             urllib.request.urlretrieve(
                 url, fpath,
                 reporthook=gen_bar_updater()
             )
-        except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
-            if url[:5] == 'https':
-                url = url.replace('https:', 'http:')
-                print('Failed download. Trying https -> http instead.'
-                      ' Downloading ' + url + ' to ' + fpath)
-                urllib.request.urlretrieve(
-                    url, fpath,
-                    reporthook=gen_bar_updater()
-                )
-            else:
-                raise e
-        # check integrity of downloaded file
-        if not check_integrity(fpath, md5):
-            raise RuntimeError("File not found or corrupted.")
+        else:
+            raise e
+    # check integrity of downloaded file
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")


 def list_dir(root: str, prefix: bool = False) -> List[str]:
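
With both hunks applied, download_url() resolves the redirect chain up front and only then hands the final URL to urllib.request.urlretrieve(). A minimal usage sketch; the target directory is hypothetical, and the URL is the one exercised by the new tests (its resolved location is a Google Drive page, so this only illustrates the call pattern, not a complete dataset download):

from torchvision.datasets.utils import download_url

download_url(
    "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
    root="./data/cub200",    # hypothetical download directory
    max_redirect_hops=3,     # new parameter introduced by this PR; 3 is also the default
)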