Skip to content

Commit da534b9

Browse files
committed
Fix and simplify PNDM by aligning with Pytorch
1 parent fe29f16 commit da534b9

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def step_prk(
342342
),
343343
ets=jax.lax.select(
344344
(state.counter % 4) == 0,
345-
state.ets.at[state.counter // 4].set(model_output), # remainder 0
345+
state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0
346346
state.ets, # remainder 1, 2, 3
347347
),
348348
cur_sample=jax.lax.select(
@@ -388,7 +388,7 @@ def step_plms(
388388
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
389389
)
390390

391-
if not self.config.skip_prk_steps and len(state.ets) < 3:
391+
if not self.config.skip_prk_steps:
392392
raise ValueError(
393393
f"{self.__class__} can only be run AFTER scheduler has been run "
394394
"in 'prk' mode for at least 12 iterations "
@@ -427,31 +427,27 @@ def step_plms(
427427
# model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
428428

429429
state = state.replace(
430-
ets=jax.lax.select_n(
431-
jnp.clip(state.counter, 0, 5),
432-
state.ets.at[0].set(model_output), # counter 0
430+
ets=jax.lax.select(
431+
state.counter != 1,
432+
state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter != 1
433433
state.ets, # counter 1
434-
state.ets.at[1].set(model_output), # counter 2
435-
state.ets.at[2].set(model_output), # counter 3
436-
state.ets.at[3].set(model_output), # counter 4
437-
state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter >= 5
438434
),
439435
cur_sample=jax.lax.select(
440-
state.counter == 1,
441-
state.cur_sample, # counter 1
436+
state.counter != 1,
442437
sample, # counter != 1
438+
state.cur_sample, # counter 1
443439
),
444440
)
445441

446442
state = state.replace(
447443
cur_model_output=jax.lax.select_n(
448444
jnp.clip(state.counter, 0, 4),
449445
model_output, # counter 0
450-
(model_output + state.ets[0]) / 2, # counter 1
451-
(3 * state.ets[1] - state.ets[0]) / 2, # counter 2
452-
(23 * state.ets[2] - 16 * state.ets[1] + 5 * state.ets[0]) / 12, # counter 3
446+
(model_output + state.ets[-1]) / 2, # counter 1
447+
(3 * state.ets[-1] - state.ets[-2]) / 2, # counter 2
448+
(23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12, # counter 3
453449
(1 / 24)
454-
* (55 * state.ets[3] - 59 * state.ets[2] + 37 * state.ets[1] - 9 * state.ets[0]), # counter >= 4
450+
* (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]), # counter >= 4
455451
),
456452
)
457453

0 commit comments

Comments
 (0)