6
6
"""
7
7
8
8
import os
9
+ import pathlib
9
10
from urllib .error import URLError
10
11
11
12
import pytest
12
13
import torchvision .datasets .utils as utils
13
14
14
15
15
16
class TestDatasetUtils :
16
- def test_download_url (self , tmpdir ):
17
+ @pytest .mark .parametrize ("use_pathlib" , (True , False ))
18
+ def test_download_url (self , tmpdir , use_pathlib ):
19
+ if use_pathlib :
20
+ tmpdir = pathlib .Path (tmpdir )
17
21
url = "http://github.com/pytorch/vision/archive/master.zip"
18
22
try :
19
23
utils .download_url (url , tmpdir )
20
24
assert len (os .listdir (tmpdir )) != 0
21
25
except URLError :
22
26
pytest .skip (f"could not download test file '{ url } '" )
23
27
24
- def test_download_url_retry_http (self , tmpdir ):
28
+ @pytest .mark .parametrize ("use_pathlib" , (True , False ))
29
+ def test_download_url_retry_http (self , tmpdir , use_pathlib ):
30
+ if use_pathlib :
31
+ tmpdir = pathlib .Path (tmpdir )
25
32
url = "https://github.com/pytorch/vision/archive/master.zip"
26
33
try :
27
34
utils .download_url (url , tmpdir )
28
35
assert len (os .listdir (tmpdir )) != 0
29
36
except URLError :
30
37
pytest .skip (f"could not download test file '{ url } '" )
31
38
32
- def test_download_url_dont_exist (self , tmpdir ):
39
+ @pytest .mark .parametrize ("use_pathlib" , (True , False ))
40
+ def test_download_url_dont_exist (self , tmpdir , use_pathlib ):
41
+ if use_pathlib :
42
+ tmpdir = pathlib .Path (tmpdir )
33
43
url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
34
44
with pytest .raises (URLError ):
35
45
utils .download_url (url , tmpdir )
36
46
37
- def test_download_url_dispatch_download_from_google_drive (self , mocker , tmpdir ):
47
+ @pytest .mark .parametrize ("use_pathlib" , (True , False ))
48
+ def test_download_url_dispatch_download_from_google_drive (self , mocker , tmpdir , use_pathlib ):
49
+ if use_pathlib :
50
+ tmpdir = pathlib .Path (tmpdir )
38
51
url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
39
52
40
53
id = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"
@@ -44,7 +57,7 @@ def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir):
44
57
mocked = mocker .patch ("torchvision.datasets.utils.download_file_from_google_drive" )
45
58
utils .download_url (url , tmpdir , filename , md5 )
46
59
47
- mocked .assert_called_once_with (id , tmpdir , filename , md5 )
60
+ mocked .assert_called_once_with (id , os . path . expanduser ( tmpdir ) , filename , md5 )
48
61
49
62
50
63
if __name__ == "__main__" :
0 commit comments