@@ -318,13 +318,14 @@ def test_full_loop_no_noise(self):

        model = self.dummy_model()
        sample = self.dummy_sample_deter
+       generator = torch.manual_seed(0)

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
-           pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
+           pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

            # if t > 0:
            #     noise = self.dummy_sample_deter
@@ -336,7 +337,7 @@ def test_full_loop_no_noise(self):
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

-       assert abs(result_sum.item() - 259.0883) < 1e-2
+       assert abs(result_sum.item() - 258.9070) < 1e-2
        assert abs(result_mean.item() - 0.3374) < 1e-3

@@ -657,7 +658,7 @@ def test_full_loop_no_noise(self):
657658class ScoreSdeVeSchedulerTest (unittest .TestCase ):
658659 # TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration)
659660 scheduler_classes = (ScoreSdeVeScheduler ,)
660- forward_default_kwargs = (( "seed" , 0 ), )
661+ forward_default_kwargs = ()
661662
662663 @property
663664 def dummy_sample (self ):
@@ -718,13 +719,19 @@ def check_over_configs(self, time_step=0, **config):
718719 scheduler .save_config (tmpdirname )
719720 new_scheduler = scheduler_class .from_config (tmpdirname )
720721
721- output = scheduler .step_pred (residual , time_step , sample , ** kwargs ).prev_sample
722- new_output = new_scheduler .step_pred (residual , time_step , sample , ** kwargs ).prev_sample
722+ output = scheduler .step_pred (
723+ residual , time_step , sample , generator = torch .manual_seed (0 ), ** kwargs
724+ ).prev_sample
725+ new_output = new_scheduler .step_pred (
726+ residual , time_step , sample , generator = torch .manual_seed (0 ), ** kwargs
727+ ).prev_sample
723728
724729 assert torch .sum (torch .abs (output - new_output )) < 1e-5 , "Scheduler outputs are not identical"
725730
726- output = scheduler .step_correct (residual , sample , ** kwargs ).prev_sample
727- new_output = new_scheduler .step_correct (residual , sample , ** kwargs ).prev_sample
731+ output = scheduler .step_correct (residual , sample , generator = torch .manual_seed (0 ), ** kwargs ).prev_sample
732+ new_output = new_scheduler .step_correct (
733+ residual , sample , generator = torch .manual_seed (0 ), ** kwargs
734+ ).prev_sample
728735
729736 assert torch .sum (torch .abs (output - new_output )) < 1e-5 , "Scheduler correction are not identical"
730737
@@ -743,13 +750,19 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)

-       output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
-       new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
+       output = scheduler.step_pred(
+           residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
+       ).prev_sample
+       new_output = new_scheduler.step_pred(
+           residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
+       ).prev_sample

        assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-       output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
-       new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
+       output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
+       new_output = new_scheduler.step_correct(
+           residual, sample, generator=torch.manual_seed(0), **kwargs
+       ).prev_sample

        assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"

@@ -779,26 +792,27 @@ def test_full_loop_no_noise(self):

        scheduler.set_sigmas(num_inference_steps)
        scheduler.set_timesteps(num_inference_steps)
+       generator = torch.manual_seed(0)

        for i, t in enumerate(scheduler.timesteps):
            sigma_t = scheduler.sigmas[i]

            for _ in range(scheduler.correct_steps):
                with torch.no_grad():
                    model_output = model(sample, sigma_t)
-               sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample
+               sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample

            with torch.no_grad():
                model_output = model(sample, sigma_t)

-           output = scheduler.step_pred(model_output, t, sample, **kwargs)
+           output = scheduler.step_pred(model_output, t, sample, generator=generator, **kwargs)
            sample, _ = output.prev_sample, output.prev_sample_mean

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

-       assert abs(result_sum.item() - 14379591680.0) < 1e-2
-       assert abs(result_mean.item() - 18723426.0) < 1e-3
+       assert np.isclose(result_sum.item(), 14372758528.0)
+       assert np.isclose(result_mean.item(), 18714530.0)

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)
@@ -817,8 +831,8 @@ def test_step_shape(self):
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

-           output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample
-           output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample
+           output_0 = scheduler.step_pred(residual, 0, sample, generator=torch.manual_seed(0), **kwargs).prev_sample
+           output_1 = scheduler.step_pred(residual, 1, sample, generator=torch.manual_seed(0), **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)
0 commit comments