diff --git a/test/test_ops.py b/test/test_ops.py index 82f8b6b6eb2..0b2e9b8340d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -46,9 +46,11 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) + @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_backward(self, device, contiguous): + def test_backward(self, seed, device, contiguous): + torch.random.manual_seed(seed) pool_size = 2 x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) if not contiguous: