Skip to content

Commit 3dd8d65

Browse files
yiwen-song authored and NicolasHug committed
[fbsync] Replace get_tmp_dir() with tmpdir fixture in tests (#4280)
Summary: * Replace in test_datasets* * Replace in test_image.py * Replace in test_transforms_tensor.py * Replace in test_internet.py and test_io.py * get_list_of_videos is a util function that still uses get_tmp_dir * Fix get_list_of_videos signature * Add get_tmp_dir import * Modify test_datasets_video_utils.py for test to pass * Fix indentation * Replace get_tmp_dir in util functions in test_dataset_sampler.py * Replace get_tmp_dir in util functions in test_dataset_video_utils.py * Move get_tmp_dir() to datasets_utils.py and refactor * Fix pylint, indentation and imports * Import shutil in common_utils.py * Fix function signature * Remove get_list_of_videos under context manager * Move get_list_of_videos to common_utils.py * Move get_tmp_dir() back to common_utils.py * Fix pylint and imports Reviewed By: NicolasHug Differential Revision: D30417192 fbshipit-source-id: fd5ae2ad7f21509dbe09f7df85f8d9006b9ed1ea Co-authored-by: Nicolas Hug <[email protected]>
1 parent 6d8f0cb commit 3dd8d65

9 files changed

+317
-374
lines changed

test/common_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from numbers import Number
1616
from torch._six import string_classes
1717
from collections import OrderedDict
18+
from torchvision import io
1819

1920
import numpy as np
2021
from PIL import Image
@@ -147,6 +148,25 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
147148
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
148149

149150

151+
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
152+
names = []
153+
for i in range(num_videos):
154+
if sizes is None:
155+
size = 5 * (i + 1)
156+
else:
157+
size = sizes[i]
158+
if fps is None:
159+
f = 5
160+
else:
161+
f = fps[i]
162+
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
163+
name = os.path.join(tmpdir, "{}.mp4".format(i))
164+
names.append(name)
165+
io.write_video(name, data, fps=f)
166+
167+
return names
168+
169+
150170
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
151171
np_pil_image = np.array(pil_image)
152172
if np_pil_image.ndim == 2:

test/test_datasets_download.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
USER_AGENT,
2323
)
2424

25-
from common_utils import get_tmp_dir
26-
2725

2826
def limit_requests_per_time(min_secs_between_requests=2.0):
2927
last_requests = {}
@@ -166,16 +164,15 @@ def assert_url_is_accessible(url, timeout=5.0):
166164
urlopen(request, timeout=timeout)
167165

168166

169-
def assert_file_downloads_correctly(url, md5, timeout=5.0):
170-
with get_tmp_dir() as root:
171-
file = path.join(root, path.basename(url))
172-
with assert_server_response_ok():
173-
with open(file, "wb") as fh:
174-
request = Request(url, headers={"User-Agent": USER_AGENT})
175-
response = urlopen(request, timeout=timeout)
176-
fh.write(response.read())
167+
def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
168+
file = path.join(tmpdir, path.basename(url))
169+
with assert_server_response_ok():
170+
with open(file, "wb") as fh:
171+
request = Request(url, headers={"User-Agent": USER_AGENT})
172+
response = urlopen(request, timeout=timeout)
173+
fh.write(response.read())
177174

178-
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
175+
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
179176

180177

181178
class DownloadConfig:

test/test_datasets_samplers.py

Lines changed: 67 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -13,104 +13,83 @@
1313
from torchvision.datasets.video_utils import VideoClips, unfold
1414
from torchvision import get_video_backend
1515

16-
from common_utils import get_tmp_dir, assert_equal
17-
18-
19-
@contextlib.contextmanager
20-
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
21-
with get_tmp_dir() as tmp_dir:
22-
names = []
23-
for i in range(num_videos):
24-
if sizes is None:
25-
size = 5 * (i + 1)
26-
else:
27-
size = sizes[i]
28-
if fps is None:
29-
f = 5
30-
else:
31-
f = fps[i]
32-
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
33-
name = os.path.join(tmp_dir, "{}.mp4".format(i))
34-
names.append(name)
35-
io.write_video(name, data, fps=f)
36-
37-
yield names
16+
from common_utils import get_list_of_videos, assert_equal
3817

3918

4019
@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
4120
class TestDatasetsSamplers:
42-
def test_random_clip_sampler(self):
43-
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
44-
video_clips = VideoClips(video_list, 5, 5)
45-
sampler = RandomClipSampler(video_clips, 3)
46-
assert len(sampler) == 3 * 3
47-
indices = torch.tensor(list(iter(sampler)))
48-
videos = torch.div(indices, 5, rounding_mode='floor')
49-
v_idxs, count = torch.unique(videos, return_counts=True)
50-
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
51-
assert_equal(count, torch.tensor([3, 3, 3]))
21+
def test_random_clip_sampler(self, tmpdir):
22+
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
23+
video_clips = VideoClips(video_list, 5, 5)
24+
sampler = RandomClipSampler(video_clips, 3)
25+
assert len(sampler) == 3 * 3
26+
indices = torch.tensor(list(iter(sampler)))
27+
videos = torch.div(indices, 5, rounding_mode='floor')
28+
v_idxs, count = torch.unique(videos, return_counts=True)
29+
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
30+
assert_equal(count, torch.tensor([3, 3, 3]))
5231

53-
def test_random_clip_sampler_unequal(self):
54-
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
55-
video_clips = VideoClips(video_list, 5, 5)
56-
sampler = RandomClipSampler(video_clips, 3)
57-
assert len(sampler) == 2 + 3 + 3
58-
indices = list(iter(sampler))
59-
assert 0 in indices
60-
assert 1 in indices
61-
# remove elements of the first video, to simplify testing
62-
indices.remove(0)
63-
indices.remove(1)
64-
indices = torch.tensor(indices) - 2
65-
videos = torch.div(indices, 5, rounding_mode='floor')
66-
v_idxs, count = torch.unique(videos, return_counts=True)
67-
assert_equal(v_idxs, torch.tensor([0, 1]))
68-
assert_equal(count, torch.tensor([3, 3]))
32+
def test_random_clip_sampler_unequal(self, tmpdir):
33+
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
34+
video_clips = VideoClips(video_list, 5, 5)
35+
sampler = RandomClipSampler(video_clips, 3)
36+
assert len(sampler) == 2 + 3 + 3
37+
indices = list(iter(sampler))
38+
assert 0 in indices
39+
assert 1 in indices
40+
# remove elements of the first video, to simplify testing
41+
indices.remove(0)
42+
indices.remove(1)
43+
indices = torch.tensor(indices) - 2
44+
videos = torch.div(indices, 5, rounding_mode='floor')
45+
v_idxs, count = torch.unique(videos, return_counts=True)
46+
assert_equal(v_idxs, torch.tensor([0, 1]))
47+
assert_equal(count, torch.tensor([3, 3]))
6948

70-
def test_uniform_clip_sampler(self):
71-
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
72-
video_clips = VideoClips(video_list, 5, 5)
73-
sampler = UniformClipSampler(video_clips, 3)
74-
assert len(sampler) == 3 * 3
75-
indices = torch.tensor(list(iter(sampler)))
76-
videos = torch.div(indices, 5, rounding_mode='floor')
77-
v_idxs, count = torch.unique(videos, return_counts=True)
78-
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
79-
assert_equal(count, torch.tensor([3, 3, 3]))
80-
assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
49+
def test_uniform_clip_sampler(self, tmpdir):
50+
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
51+
video_clips = VideoClips(video_list, 5, 5)
52+
sampler = UniformClipSampler(video_clips, 3)
53+
assert len(sampler) == 3 * 3
54+
indices = torch.tensor(list(iter(sampler)))
55+
videos = torch.div(indices, 5, rounding_mode='floor')
56+
v_idxs, count = torch.unique(videos, return_counts=True)
57+
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
58+
assert_equal(count, torch.tensor([3, 3, 3]))
59+
assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
8160

82-
def test_uniform_clip_sampler_insufficient_clips(self):
83-
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
84-
video_clips = VideoClips(video_list, 5, 5)
85-
sampler = UniformClipSampler(video_clips, 3)
86-
assert len(sampler) == 3 * 3
87-
indices = torch.tensor(list(iter(sampler)))
88-
assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
61+
def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
62+
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
63+
video_clips = VideoClips(video_list, 5, 5)
64+
sampler = UniformClipSampler(video_clips, 3)
65+
assert len(sampler) == 3 * 3
66+
indices = torch.tensor(list(iter(sampler)))
67+
assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
8968

90-
def test_distributed_sampler_and_uniform_clip_sampler(self):
91-
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
92-
video_clips = VideoClips(video_list, 5, 5)
93-
clip_sampler = UniformClipSampler(video_clips, 3)
69+
def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
70+
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
71+
video_clips = VideoClips(video_list, 5, 5)
72+
clip_sampler = UniformClipSampler(video_clips, 3)
9473

95-
distributed_sampler_rank0 = DistributedSampler(
96-
clip_sampler,
97-
num_replicas=2,
98-
rank=0,
99-
group_size=3,
100-
)
101-
indices = torch.tensor(list(iter(distributed_sampler_rank0)))
102-
assert len(distributed_sampler_rank0) == 6
103-
assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))
74+
distributed_sampler_rank0 = DistributedSampler(
75+
clip_sampler,
76+
num_replicas=2,
77+
rank=0,
78+
group_size=3,
79+
)
80+
indices = torch.tensor(list(iter(distributed_sampler_rank0)))
81+
assert len(distributed_sampler_rank0) == 6
82+
assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))
10483

105-
distributed_sampler_rank1 = DistributedSampler(
106-
clip_sampler,
107-
num_replicas=2,
108-
rank=1,
109-
group_size=3,
110-
)
111-
indices = torch.tensor(list(iter(distributed_sampler_rank1)))
112-
assert len(distributed_sampler_rank1) == 6
113-
assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))
84+
distributed_sampler_rank1 = DistributedSampler(
85+
clip_sampler,
86+
num_replicas=2,
87+
rank=1,
88+
group_size=3,
89+
)
90+
indices = torch.tensor(list(iter(distributed_sampler_rank1)))
91+
assert len(distributed_sampler_rank1) == 6
92+
assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))
11493

11594

11695
if __name__ == '__main__':

test/test_datasets_utils.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import lzma
1313
import contextlib
1414

15-
from common_utils import get_tmp_dir
1615
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
1716

1817

@@ -113,7 +112,7 @@ def test_detect_file_type_incompatible(self, file):
113112
utils._detect_file_type(file)
114113

115114
@pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
116-
def test_decompress(self, extension):
115+
def test_decompress(self, extension, tmpdir):
117116
def create_compressed(root, content="this is the content"):
118117
file = os.path.join(root, "file")
119118
compressed = f"{file}{extension}"
@@ -124,21 +123,20 @@ def create_compressed(root, content="this is the content"):
124123

125124
return compressed, file, content
126125

127-
with get_tmp_dir() as temp_dir:
128-
compressed, file, content = create_compressed(temp_dir)
126+
compressed, file, content = create_compressed(tmpdir)
129127

130-
utils._decompress(compressed)
128+
utils._decompress(compressed)
131129

132-
assert os.path.exists(file)
130+
assert os.path.exists(file)
133131

134-
with open(file, "r") as fh:
135-
assert fh.read() == content
132+
with open(file, "r") as fh:
133+
assert fh.read() == content
136134

137135
def test_decompress_no_compression(self):
138136
with pytest.raises(RuntimeError):
139137
utils._decompress("foo.tar")
140138

141-
def test_decompress_remove_finished(self):
139+
def test_decompress_remove_finished(self, tmpdir):
142140
def create_compressed(root, content="this is the content"):
143141
file = os.path.join(root, "file")
144142
compressed = f"{file}.gz"
@@ -148,12 +146,11 @@ def create_compressed(root, content="this is the content"):
148146

149147
return compressed, file, content
150148

151-
with get_tmp_dir() as temp_dir:
152-
compressed, file, content = create_compressed(temp_dir)
149+
compressed, file, content = create_compressed(tmpdir)
153150

154-
utils.extract_archive(compressed, temp_dir, remove_finished=True)
151+
utils.extract_archive(compressed, tmpdir, remove_finished=True)
155152

156-
assert not os.path.exists(compressed)
153+
assert not os.path.exists(compressed)
157154

158155
@pytest.mark.parametrize('extension', [".gz", ".xz"])
159156
@pytest.mark.parametrize('remove_finished', [True, False])
@@ -166,7 +163,7 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, m
166163

167164
mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
168165

169-
def test_extract_zip(self):
166+
def test_extract_zip(self, tmpdir):
170167
def create_archive(root, content="this is the content"):
171168
file = os.path.join(root, "dst.txt")
172169
archive = os.path.join(root, "archive.zip")
@@ -176,19 +173,18 @@ def create_archive(root, content="this is the content"):
176173

177174
return archive, file, content
178175

179-
with get_tmp_dir() as temp_dir:
180-
archive, file, content = create_archive(temp_dir)
176+
archive, file, content = create_archive(tmpdir)
181177

182-
utils.extract_archive(archive, temp_dir)
178+
utils.extract_archive(archive, tmpdir)
183179

184-
assert os.path.exists(file)
180+
assert os.path.exists(file)
185181

186-
with open(file, "r") as fh:
187-
assert fh.read() == content
182+
with open(file, "r") as fh:
183+
assert fh.read() == content
188184

189185
@pytest.mark.parametrize('extension, mode', [
190186
('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
191-
def test_extract_tar(self, extension, mode):
187+
def test_extract_tar(self, extension, mode, tmpdir):
192188
def create_archive(root, extension, mode, content="this is the content"):
193189
src = os.path.join(root, "src.txt")
194190
dst = os.path.join(root, "dst.txt")
@@ -202,15 +198,14 @@ def create_archive(root, extension, mode, content="this is the content"):
202198

203199
return archive, dst, content
204200

205-
with get_tmp_dir() as temp_dir:
206-
archive, file, content = create_archive(temp_dir, extension, mode)
201+
archive, file, content = create_archive(tmpdir, extension, mode)
207202

208-
utils.extract_archive(archive, temp_dir)
203+
utils.extract_archive(archive, tmpdir)
209204

210-
assert os.path.exists(file)
205+
assert os.path.exists(file)
211206

212-
with open(file, "r") as fh:
213-
assert fh.read() == content
207+
with open(file, "r") as fh:
208+
assert fh.read() == content
214209

215210
def test_verify_str_arg(self):
216211
assert "a" == utils.verify_str_arg("a", "arg", ("a",))

0 commit comments

Comments
 (0)