Skip to content

Commit 0b64c2c

Browse files
nipunjindalnjindal
andauthored
[Stochastic Sampler][Slow Test]: Cuda test fixes (#3257)
[Slow Test]: Cuda test fixes Co-authored-by: njindal <[email protected]>
1 parent fd512d7 commit 0b64c2c

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tests/schedulers/test_scheduler_dpm_sde.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def test_full_loop_no_noise(self):
6565
if torch_device in ["mps"]:
6666
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
6767
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
68+
elif torch_device in ["cuda"]:
69+
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
70+
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
6871
else:
6972
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
7073
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
@@ -94,6 +97,9 @@ def test_full_loop_with_v_prediction(self):
9497
if torch_device in ["mps"]:
9598
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
9699
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
100+
elif torch_device in ["cuda"]:
101+
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
102+
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
97103
else:
98104
assert abs(result_sum.item() - 119.8487548828125) < 1e-2
99105
assert abs(result_mean.item() - 0.1560530662536621) < 1e-3
@@ -122,6 +128,9 @@ def test_full_loop_device(self):
122128
if torch_device in ["mps"]:
123129
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
124130
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
131+
elif torch_device in ["cuda"]:
132+
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
133+
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
125134
else:
126135
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
127136
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
@@ -151,6 +160,9 @@ def test_full_loop_device_karras_sigmas(self):
151160
if torch_device in ["mps"]:
152161
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
153162
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
163+
elif torch_device in ["cuda"]:
164+
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
165+
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
154166
else:
155167
assert abs(result_sum.item() - 170.3135223388672) < 1e-2
156168
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2

0 commit comments

Comments
 (0)