@@ -65,6 +65,9 @@ def test_full_loop_no_noise(self):
65
65
if torch_device in ["mps" ]:
66
66
assert abs (result_sum .item () - 167.47821044921875 ) < 1e-2
67
67
assert abs (result_mean .item () - 0.2178705964565277 ) < 1e-3
68
+ elif torch_device in ["cuda" ]:
69
+ assert abs (result_sum .item () - 171.59352111816406 ) < 1e-2
70
+ assert abs (result_mean .item () - 0.22342906892299652 ) < 1e-3
68
71
else :
69
72
assert abs (result_sum .item () - 162.52383422851562 ) < 1e-2
70
73
assert abs (result_mean .item () - 0.211619570851326 ) < 1e-3
@@ -94,6 +97,9 @@ def test_full_loop_with_v_prediction(self):
94
97
if torch_device in ["mps" ]:
95
98
assert abs (result_sum .item () - 124.77149200439453 ) < 1e-2
96
99
assert abs (result_mean .item () - 0.16226289014816284 ) < 1e-3
100
+ elif torch_device in ["cuda" ]:
101
+ assert abs (result_sum .item () - 128.1663360595703 ) < 1e-2
102
+ assert abs (result_mean .item () - 0.16688326001167297 ) < 1e-3
97
103
else :
98
104
assert abs (result_sum .item () - 119.8487548828125 ) < 1e-2
99
105
assert abs (result_mean .item () - 0.1560530662536621 ) < 1e-3
@@ -122,6 +128,9 @@ def test_full_loop_device(self):
122
128
if torch_device in ["mps" ]:
123
129
assert abs (result_sum .item () - 167.46957397460938 ) < 1e-2
124
130
assert abs (result_mean .item () - 0.21805934607982635 ) < 1e-3
131
+ elif torch_device in ["cuda" ]:
132
+ assert abs (result_sum .item () - 171.59353637695312 ) < 1e-2
133
+ assert abs (result_mean .item () - 0.22342908382415771 ) < 1e-3
125
134
else :
126
135
assert abs (result_sum .item () - 162.52383422851562 ) < 1e-2
127
136
assert abs (result_mean .item () - 0.211619570851326 ) < 1e-3
@@ -151,6 +160,9 @@ def test_full_loop_device_karras_sigmas(self):
151
160
if torch_device in ["mps" ]:
152
161
assert abs (result_sum .item () - 176.66974135742188 ) < 1e-2
153
162
assert abs (result_mean .item () - 0.23003872730981811 ) < 1e-2
163
+ elif torch_device in ["cuda" ]:
164
+ assert abs (result_sum .item () - 177.63653564453125 ) < 1e-2
165
+ assert abs (result_mean .item () - 0.23003872730981811 ) < 1e-2
154
166
else :
155
167
assert abs (result_sum .item () - 170.3135223388672 ) < 1e-2
156
168
assert abs (result_mean .item () - 0.23003872730981811 ) < 1e-2
0 commit comments