Skip to content

Commit 3a64fbf

Browse files
committed
Add missing set_timesteps in tests
1 parent 0108066 commit 3a64fbf

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/test_scheduler_flax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,9 @@ def test_inference_plms_no_past_residuals(self):
878878
scheduler = scheduler_class(**scheduler_config)
879879
state = scheduler.create_state()
880880

881+
num_inference_steps = 10
882+
state = scheduler.set_timesteps(state, num_inference_steps, shape=self.dummy_sample.shape)
883+
881884
scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample
882885

883886
def test_full_loop_no_noise(self):

0 commit comments

Comments
 (0)