diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index 469ceec2e1b2..a913ffb25bf3 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -58,7 +58,8 @@ def setUp(self): torch._dynamo.reset() def run_and_check(self, fn, args, *, expect_kernel_count=1): - expected = fn(*args) + args_cpu = [tensor.cpu().to(torch.float32) for tensor in args] + expected = fn(*args_cpu).to(torch.float16) fn = torch.compile(fn, fullgraph=True) result, (source_code,) = run_and_get_code(fn, *args) self.assertEqual(result, expected)