 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pdb
 import tempfile
 import unittest
 
@@ -383,21 +382,22 @@ def get_scheduler_config(self, **kwargs):
 
     def check_over_configs(self, time_step=0, **config):
         kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
         sample = self.dummy_sample
         residual = 0.1 * sample
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
 
         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)
-            scheduler.set_timesteps(kwargs["num_inference_steps"])
+            scheduler.set_timesteps(num_inference_steps)
             # copy over dummy past residuals
             scheduler.ets = dummy_past_residuals[:]
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
-                new_scheduler.set_timesteps(kwargs["num_inference_steps"])
+                new_scheduler.set_timesteps(num_inference_steps)
                 # copy over dummy past residuals
                 new_scheduler.ets = dummy_past_residuals[:]
 
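The hunk above (and the matching one in check_over_forward below) applies one pattern: num_inference_steps is popped out of the shared default kwargs so that it feeds set_timesteps() and is no longer forwarded to the step methods through **kwargs. A minimal sketch of that pattern follows; run_one_step is a hypothetical helper name, not part of the diff, and the scheduler/test fixtures are assumed to behave as in the surrounding file.

def run_one_step(scheduler, residual, sample, time_step, forward_default_kwargs):
    # Work on a copy so the shared defaults are not mutated across test cases.
    kwargs = dict(forward_default_kwargs)
    # Pop the step count: it is consumed here by set_timesteps() and must not
    # leak into step_prk()/step_plms() via **kwargs.
    num_inference_steps = kwargs.pop("num_inference_steps", None)
    if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
        scheduler.set_timesteps(num_inference_steps)
    return scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]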
@@ -416,15 +416,15 @@ def test_from_pretrained_save_pretrained(self):
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)
-        kwargs.update(forward_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
         sample = self.dummy_sample
         residual = 0.1 * sample
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
 
         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
-            scheduler.set_timesteps(kwargs["num_inference_steps"])
+            scheduler.set_timesteps(num_inference_steps)
 
             # copy over dummy past residuals
             scheduler.ets = dummy_past_residuals[:]
@@ -434,7 +434,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs):
                 new_scheduler = scheduler_class.from_config(tmpdirname)
                 # copy over dummy past residuals
                 new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_timesteps(kwargs["num_inference_steps"])
+                new_scheduler.set_timesteps(num_inference_steps)
 
             output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
             new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
@@ -474,12 +474,12 @@ def test_pytorch_equal_numpy(self):
             elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                 kwargs["num_inference_steps"] = num_inference_steps
 
-            output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
-            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+            output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
             assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
 
-            output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
-            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+            output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
 
             assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
 
@@ -503,14 +503,14 @@ def test_step_shape(self):
             elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                 kwargs["num_inference_steps"] = num_inference_steps
 
-            output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
-            output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
 
             self.assertEqual(output_0.shape, sample.shape)
             self.assertEqual(output_0.shape, output_1.shape)
 
-            output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
-            output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
 
             self.assertEqual(output_0.shape, sample.shape)
             self.assertEqual(output_0.shape, output_1.shape)
@@ -541,7 +541,7 @@ def test_inference_plms_no_past_residuals(self):
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
 
-            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
+            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
 
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
@@ -555,11 +555,11 @@ def test_full_loop_no_noise(self):
 
         for i, t in enumerate(scheduler.prk_timesteps):
             residual = model(sample, t)
-            sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
+            sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
 
         for i, t in enumerate(scheduler.plms_timesteps):
             residual = model(sample, t)
-            sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
+            sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
 
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
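The preceding hunks drop num_inference_steps from the step calls themselves: once set_timesteps() has been called, the scheduler keeps that state, so step_prk()/step_plms() only need the residual, the loop index, and the sample. A rough sketch of the resulting sampling loop is shown below; sample_pndm is a hypothetical helper name, not code from the diff.

def sample_pndm(scheduler, model, sample, num_inference_steps):
    # The step count is handed to the scheduler once, up front...
    scheduler.set_timesteps(num_inference_steps)
    # ...so the Runge-Kutta warm-up loop and the linear-multistep loop
    # no longer pass it to every step call.
    for i, t in enumerate(scheduler.prk_timesteps):
        residual = model(sample, t)
        sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
    for i, t in enumerate(scheduler.plms_timesteps):
        residual = model(sample, t)
        sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
    return sample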
@@ -706,7 +706,7 @@ def test_full_loop_no_noise(self):
                 model_output = model(sample, sigma_t)
 
             output = scheduler.step_pred(model_output, t, sample, **kwargs)
-            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
+            sample, _ = output["prev_sample"], output["prev_sample_mean"]
 
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
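The last hunk only silences an unused-variable warning: step_pred() still returns both the sample and its mean, but the test checks only the sample. A tiny sketch of the same idea; predictor_step is a hypothetical name, not part of the diff.

def predictor_step(scheduler, model_output, t, sample, **kwargs):
    output = scheduler.step_pred(model_output, t, sample, **kwargs)
    # "prev_sample_mean" is still present in the output dict; it is simply unused here.
    return output["prev_sample"]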