Skip to content

Commit c17e823

Browse files
committed
Add note for skipping prk
1 parent da534b9 commit c17e823

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,14 @@ def step_plms(
388388
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
389389
)
390390

391-
if not self.config.skip_prk_steps:
392-
raise ValueError(
393-
f"{self.__class__} can only be run AFTER scheduler has been run "
394-
"in 'prk' mode for at least 12 iterations "
395-
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
396-
"for more information."
397-
)
391+
# NOTE: There is no way to check in the jitted runtime if the prk mode was ran before
392+
# if not self.config.skip_prk_steps and state.counter < 3:
393+
# raise ValueError(
394+
# f"{self.__class__} can only be run AFTER scheduler has been run "
395+
# "in 'prk' mode for at least 12 iterations "
396+
# "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
397+
# "for more information."
398+
# )
398399

399400
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
400401
prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)

0 commit comments

Comments
 (0)