Skip to content

Commit d234307

Browse files
ahmadsharif1bmmtstbNicolasHug
authored
Add pathlib.Path support for download utils (#8196)
Co-authored-by: Ahmad Sharif <[email protected]> Co-authored-by: Brizar <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 2afb7fa commit d234307

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

test/test_internet.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,48 @@
66
"""
77

88
import os
9+
import pathlib
910
from urllib.error import URLError
1011

1112
import pytest
1213
import torchvision.datasets.utils as utils
1314

1415

1516
class TestDatasetUtils:
16-
def test_download_url(self, tmpdir):
17+
@pytest.mark.parametrize("use_pathlib", (True, False))
18+
def test_download_url(self, tmpdir, use_pathlib):
19+
if use_pathlib:
20+
tmpdir = pathlib.Path(tmpdir)
1721
url = "http://github.com/pytorch/vision/archive/master.zip"
1822
try:
1923
utils.download_url(url, tmpdir)
2024
assert len(os.listdir(tmpdir)) != 0
2125
except URLError:
2226
pytest.skip(f"could not download test file '{url}'")
2327

24-
def test_download_url_retry_http(self, tmpdir):
28+
@pytest.mark.parametrize("use_pathlib", (True, False))
29+
def test_download_url_retry_http(self, tmpdir, use_pathlib):
30+
if use_pathlib:
31+
tmpdir = pathlib.Path(tmpdir)
2532
url = "https://github.com/pytorch/vision/archive/master.zip"
2633
try:
2734
utils.download_url(url, tmpdir)
2835
assert len(os.listdir(tmpdir)) != 0
2936
except URLError:
3037
pytest.skip(f"could not download test file '{url}'")
3138

32-
def test_download_url_dont_exist(self, tmpdir):
39+
@pytest.mark.parametrize("use_pathlib", (True, False))
40+
def test_download_url_dont_exist(self, tmpdir, use_pathlib):
41+
if use_pathlib:
42+
tmpdir = pathlib.Path(tmpdir)
3343
url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
3444
with pytest.raises(URLError):
3545
utils.download_url(url, tmpdir)
3646

37-
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir):
47+
@pytest.mark.parametrize("use_pathlib", (True, False))
48+
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir, use_pathlib):
49+
if use_pathlib:
50+
tmpdir = pathlib.Path(tmpdir)
3851
url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
3952

4053
id = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"
@@ -44,7 +57,7 @@ def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir):
4457
mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive")
4558
utils.download_url(url, tmpdir, filename, md5)
4659

47-
mocked.assert_called_once_with(id, tmpdir, filename, md5)
60+
mocked.assert_called_once_with(id, os.path.expanduser(tmpdir), filename, md5)
4861

4962

5063
if __name__ == "__main__":

torchvision/datasets/utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import urllib.request
1616
import warnings
1717
import zipfile
18-
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
18+
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
1919
from urllib.parse import urlparse
2020

2121
import numpy as np
@@ -104,7 +104,11 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
104104

105105

106106
def download_url(
107-
url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
107+
url: str,
108+
root: Union[str, pathlib.Path],
109+
filename: Optional[str] = None,
110+
md5: Optional[str] = None,
111+
max_redirect_hops: int = 3,
108112
) -> None:
109113
"""Download a file from a url and place it in root.
110114
@@ -118,7 +122,7 @@ def download_url(
118122
root = os.path.expanduser(root)
119123
if not filename:
120124
filename = os.path.basename(url)
121-
fpath = os.path.join(root, filename)
125+
fpath = os.fspath(os.path.join(root, filename))
122126

123127
os.makedirs(root, exist_ok=True)
124128

@@ -203,7 +207,9 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple
203207
return api_response, content
204208

205209

206-
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
210+
def download_file_from_google_drive(
211+
file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
212+
):
207213
"""Download a Google Drive file from and place it in root.
208214
209215
Args:
@@ -217,7 +223,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
217223
root = os.path.expanduser(root)
218224
if not filename:
219225
filename = file_id
220-
fpath = os.path.join(root, filename)
226+
fpath = os.fspath(os.path.join(root, filename))
221227

222228
os.makedirs(root, exist_ok=True)
223229

0 commit comments

Comments
 (0)