@@ -1281,10 +1281,11 @@ def test_full_loop_no_noise(self):
12811281
12821282 scheduler .set_timesteps (self .num_inference_steps )
12831283
1284- generator = torch .Generator ().manual_seed (0 )
1284+ generator = torch .Generator (torch_device ).manual_seed (0 )
12851285
12861286 model = self .dummy_model ()
12871287 sample = self .dummy_sample_deter * scheduler .init_noise_sigma
1288+ sample = sample .to (torch_device )
12881289
12891290 for i , t in enumerate (scheduler .timesteps ):
12901291 sample = scheduler .scale_model_input (sample , t )
@@ -1296,7 +1297,6 @@ def test_full_loop_no_noise(self):
12961297
12971298 result_sum = torch .sum (torch .abs (sample ))
12981299 result_mean = torch .mean (torch .abs (sample ))
1299- print (result_sum , result_mean )
13001300
13011301 assert abs (result_sum .item () - 10.0807 ) < 1e-2
13021302 assert abs (result_mean .item () - 0.0131 ) < 1e-3
@@ -1308,7 +1308,7 @@ def test_full_loop_device(self):
13081308
13091309 scheduler .set_timesteps (self .num_inference_steps , device = torch_device )
13101310
1311- generator = torch .Generator ().manual_seed (0 )
1311+ generator = torch .Generator (torch_device ).manual_seed (0 )
13121312
13131313 model = self .dummy_model ()
13141314 sample = self .dummy_sample_deter * scheduler .init_noise_sigma
@@ -1324,7 +1324,6 @@ def test_full_loop_device(self):
13241324
13251325 result_sum = torch .sum (torch .abs (sample ))
13261326 result_mean = torch .mean (torch .abs (sample ))
1327- print (result_sum , result_mean )
13281327
13291328 assert abs (result_sum .item () - 10.0807 ) < 1e-2
13301329 assert abs (result_mean .item () - 0.0131 ) < 1e-3
@@ -1365,10 +1364,11 @@ def test_full_loop_no_noise(self):
13651364
13661365 scheduler .set_timesteps (self .num_inference_steps )
13671366
1368- generator = torch .Generator ().manual_seed (0 )
1367+ generator = torch .Generator (device = torch_device ).manual_seed (0 )
13691368
13701369 model = self .dummy_model ()
13711370 sample = self .dummy_sample_deter * scheduler .init_noise_sigma
1371+ sample = sample .to (torch_device )
13721372
13731373 for i , t in enumerate (scheduler .timesteps ):
13741374 sample = scheduler .scale_model_input (sample , t )
@@ -1380,9 +1380,14 @@ def test_full_loop_no_noise(self):
13801380
13811381 result_sum = torch .sum (torch .abs (sample ))
13821382 result_mean = torch .mean (torch .abs (sample ))
1383- print (result_sum , result_mean )
1384- assert abs (result_sum .item () - 152.3192 ) < 1e-2
1385- assert abs (result_mean .item () - 0.1983 ) < 1e-3
1383+
1384+ if str (torch_device ).startswith ("cpu" ):
1385+ assert abs (result_sum .item () - 152.3192 ) < 1e-2
1386+ assert abs (result_mean .item () - 0.1983 ) < 1e-3
1387+ else :
1388+ # CUDA
1389+ assert abs (result_sum .item () - 144.8084 ) < 1e-2
1390+ assert abs (result_mean .item () - 0.18855 ) < 1e-3
13861391
13871392 def test_full_loop_device (self ):
13881393 scheduler_class = self .scheduler_classes [0 ]
@@ -1391,7 +1396,7 @@ def test_full_loop_device(self):
13911396
13921397 scheduler .set_timesteps (self .num_inference_steps , device = torch_device )
13931398
1394- generator = torch .Generator ().manual_seed (0 )
1399+ generator = torch .Generator (device = torch_device ).manual_seed (0 )
13951400
13961401 model = self .dummy_model ()
13971402 sample = self .dummy_sample_deter * scheduler .init_noise_sigma
@@ -1407,14 +1412,18 @@ def test_full_loop_device(self):
14071412
14081413 result_sum = torch .sum (torch .abs (sample ))
14091414 result_mean = torch .mean (torch .abs (sample ))
1410- print ( result_sum , result_mean )
1411- if not str (torch_device ).startswith ("mps " ):
1415+
1416+ if str (torch_device ).startswith ("cpu " ):
14121417 # The following sum varies between 148 and 156 on mps. Why?
14131418 assert abs (result_sum .item () - 152.3192 ) < 1e-2
14141419 assert abs (result_mean .item () - 0.1983 ) < 1e-3
1415- else :
1420+ elif str ( torch_device ). startswith ( "mps" ) :
14161421 # Larger tolerance on mps
14171422 assert abs (result_mean .item () - 0.1983 ) < 1e-2
1423+ else :
1424+ # CUDA
1425+ assert abs (result_sum .item () - 144.8084 ) < 1e-2
1426+ assert abs (result_mean .item () - 0.18855 ) < 1e-3
14181427
14191428
14201429class IPNDMSchedulerTest (SchedulerCommonTest ):
0 commit comments