@@ -160,34 +160,16 @@ def test_to_tensor_video(self):
160
160
161
161
trans .__repr__ ()
162
162
163
- @pytest .mark .skipif (stats is None , reason = "scipy.stats not available" )
164
- def test_random_horizontal_flip_video (self ):
165
- random_state = random .getstate ()
166
- random .seed (42 )
163
+ @pytest .mark .parametrize ("p" , (0 , 1 ))
164
+ def test_random_horizontal_flip_video (self , p ):
167
165
clip = torch .rand ((3 , 4 , 112 , 112 ), dtype = torch .float )
168
166
hclip = clip .flip ((- 1 ))
169
167
170
- num_samples = 250
171
- num_horizontal = 0
172
- for _ in range (num_samples ):
173
- out = transforms .RandomHorizontalFlipVideo ()(clip )
174
- if torch .all (torch .eq (out , hclip )):
175
- num_horizontal += 1
176
-
177
- p_value = stats .binom_test (num_horizontal , num_samples , p = 0.5 )
178
- random .setstate (random_state )
179
- assert p_value > 0.0001
180
-
181
- num_samples = 250
182
- num_horizontal = 0
183
- for _ in range (num_samples ):
184
- out = transforms .RandomHorizontalFlipVideo (p = 0.7 )(clip )
185
- if torch .all (torch .eq (out , hclip )):
186
- num_horizontal += 1
187
-
188
- p_value = stats .binom_test (num_horizontal , num_samples , p = 0.7 )
189
- random .setstate (random_state )
190
- assert p_value > 0.0001
168
+ out = transforms .RandomHorizontalFlipVideo (p = p )(clip )
169
+ if p == 0 :
170
+ torch .testing .assert_close (out , clip )
171
+ elif p == 1 :
172
+ torch .testing .assert_close (out , hclip )
191
173
192
174
transforms .RandomHorizontalFlipVideo ().__repr__ ()
193
175
0 commit comments