Skip to content

Commit 3a17e33

Browse files
authored
Remove p_value test for RandomHorizontalFlipVideo (#4765)
1 parent 0f770ac commit 3a17e33

File tree

1 file changed

+7
-25
lines changed

1 file changed

+7
-25
lines changed

test/test_transforms_video.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -160,34 +160,16 @@ def test_to_tensor_video(self):
160160

161161
trans.__repr__()
162162

163-
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
164-
def test_random_horizontal_flip_video(self):
165-
random_state = random.getstate()
166-
random.seed(42)
163+
@pytest.mark.parametrize("p", (0, 1))
164+
def test_random_horizontal_flip_video(self, p):
167165
clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
168166
hclip = clip.flip((-1))
169167

170-
num_samples = 250
171-
num_horizontal = 0
172-
for _ in range(num_samples):
173-
out = transforms.RandomHorizontalFlipVideo()(clip)
174-
if torch.all(torch.eq(out, hclip)):
175-
num_horizontal += 1
176-
177-
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
178-
random.setstate(random_state)
179-
assert p_value > 0.0001
180-
181-
num_samples = 250
182-
num_horizontal = 0
183-
for _ in range(num_samples):
184-
out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
185-
if torch.all(torch.eq(out, hclip)):
186-
num_horizontal += 1
187-
188-
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
189-
random.setstate(random_state)
190-
assert p_value > 0.0001
168+
out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
169+
if p == 0:
170+
torch.testing.assert_close(out, clip)
171+
elif p == 1:
172+
torch.testing.assert_close(out, hclip)
191173

192174
transforms.RandomHorizontalFlipVideo().__repr__()
193175

0 commit comments

Comments
 (0)