@@ -342,7 +342,7 @@ def step_prk(
342342 ),
343343 ets = jax .lax .select (
344344 (state .counter % 4 ) == 0 ,
345- state .ets .at [state .counter // 4 ].set (model_output ), # remainder 0
345+ state .ets .at [0 : 3 ]. set ( state .ets [ 1 : 4 ]). at [ 3 ].set (model_output ), # remainder 0
346346 state .ets , # remainder 1, 2, 3
347347 ),
348348 cur_sample = jax .lax .select (
@@ -388,7 +388,7 @@ def step_plms(
388388 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
389389 )
390390
391- if not self .config .skip_prk_steps and len ( state . ets ) < 3 :
391+ if not self .config .skip_prk_steps :
392392 raise ValueError (
393393 f"{ self .__class__ } can only be run AFTER scheduler has been run "
394394 "in 'prk' mode for at least 12 iterations "
@@ -427,31 +427,27 @@ def step_plms(
427427 # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
428428
429429 state = state .replace (
430- ets = jax .lax .select_n (
431- jnp . clip ( state .counter , 0 , 5 ) ,
432- state .ets .at [0 ].set (model_output ), # counter 0
430+ ets = jax .lax .select (
431+ state .counter != 1 ,
432+ state .ets .at [0 : 3 ].set (state . ets [ 1 : 4 ]). at [ 3 ]. set ( model_output ), # counter != 1
433433 state .ets , # counter 1
434- state .ets .at [1 ].set (model_output ), # counter 2
435- state .ets .at [2 ].set (model_output ), # counter 3
436- state .ets .at [3 ].set (model_output ), # counter 4
437- state .ets .at [0 :3 ].set (state .ets [1 :4 ]).at [3 ].set (model_output ), # counter >= 5
438434 ),
439435 cur_sample = jax .lax .select (
440- state .counter == 1 ,
441- state .cur_sample , # counter 1
436+ state .counter != 1 ,
442437 sample , # counter != 1
438+ state .cur_sample , # counter 1
443439 ),
444440 )
445441
446442 state = state .replace (
447443 cur_model_output = jax .lax .select_n (
448444 jnp .clip (state .counter , 0 , 4 ),
449445 model_output , # counter 0
450- (model_output + state .ets [0 ]) / 2 , # counter 1
451- (3 * state .ets [1 ] - state .ets [0 ]) / 2 , # counter 2
452- (23 * state .ets [2 ] - 16 * state .ets [1 ] + 5 * state .ets [0 ]) / 12 , # counter 3
446+ (model_output + state .ets [- 1 ]) / 2 , # counter 1
447+ (3 * state .ets [- 1 ] - state .ets [- 2 ]) / 2 , # counter 2
448+ (23 * state .ets [- 1 ] - 16 * state .ets [- 2 ] + 5 * state .ets [- 3 ]) / 12 , # counter 3
453449 (1 / 24 )
454- * (55 * state .ets [3 ] - 59 * state .ets [2 ] + 37 * state .ets [1 ] - 9 * state .ets [0 ]), # counter >= 4
450+ * (55 * state .ets [- 1 ] - 59 * state .ets [- 2 ] + 37 * state .ets [- 3 ] - 9 * state .ets [- 4 ]), # counter >= 4
455451 ),
456452 )
457453
0 commit comments