From d1141bdd20c8133e3f95edbed8c116a2cbf6aa8e Mon Sep 17 00:00:00 2001 From: njindal Date: Thu, 27 Apr 2023 13:42:16 +0530 Subject: [PATCH] [Slow Test]: Cuda test fixes --- tests/schedulers/test_scheduler_dpm_sde.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py index 010c4bdb1196..7906c8d5d4e9 100644 --- a/tests/schedulers/test_scheduler_dpm_sde.py +++ b/tests/schedulers/test_scheduler_dpm_sde.py @@ -65,6 +65,9 @@ def test_full_loop_no_noise(self): if torch_device in ["mps"]: assert abs(result_sum.item() - 167.47821044921875) < 1e-2 assert abs(result_mean.item() - 0.2178705964565277) < 1e-3 + elif torch_device in ["cuda"]: + assert abs(result_sum.item() - 171.59352111816406) < 1e-2 + assert abs(result_mean.item() - 0.22342906892299652) < 1e-3 else: assert abs(result_sum.item() - 162.52383422851562) < 1e-2 assert abs(result_mean.item() - 0.211619570851326) < 1e-3 @@ -94,6 +97,9 @@ def test_full_loop_with_v_prediction(self): if torch_device in ["mps"]: assert abs(result_sum.item() - 124.77149200439453) < 1e-2 assert abs(result_mean.item() - 0.16226289014816284) < 1e-3 + elif torch_device in ["cuda"]: + assert abs(result_sum.item() - 128.1663360595703) < 1e-2 + assert abs(result_mean.item() - 0.16688326001167297) < 1e-3 else: assert abs(result_sum.item() - 119.8487548828125) < 1e-2 assert abs(result_mean.item() - 0.1560530662536621) < 1e-3 @@ -122,6 +128,9 @@ def test_full_loop_device(self): if torch_device in ["mps"]: assert abs(result_sum.item() - 167.46957397460938) < 1e-2 assert abs(result_mean.item() - 0.21805934607982635) < 1e-3 + elif torch_device in ["cuda"]: + assert abs(result_sum.item() - 171.59353637695312) < 1e-2 + assert abs(result_mean.item() - 0.22342908382415771) < 1e-3 else: assert abs(result_sum.item() - 162.52383422851562) < 1e-2 assert abs(result_mean.item() - 0.211619570851326) < 1e-3 @@ -151,6 +160,9 @@ def test_full_loop_device_karras_sigmas(self): if torch_device in ["mps"]: assert abs(result_sum.item() - 176.66974135742188) < 1e-2 assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 + elif torch_device in ["cuda"]: + assert abs(result_sum.item() - 177.63653564453125) < 1e-2 + assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 else: assert abs(result_sum.item() - 170.3135223388672) < 1e-2 assert abs(result_mean.item() - 0.23003872730981811) < 1e-2