Skip to content

Commit 0108066

Browse files
committed
Fix prk_timesteps and plms_timesteps dtype
1 parent 8722622 commit 0108066

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,19 +183,19 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
183183
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
184184
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
185185

186-
prk_timesteps = jnp.array([], dtype=self.config.dtype)
186+
prk_timesteps = jnp.array([], dtype=jnp.int32)
187187
plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1]
188188

189189
else:
190190
prk_timesteps = jnp.array(_timesteps[-self.pndm_order :]).repeat(2) + jnp.tile(
191-
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=self.config.dtype),
191+
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
192192
self.pndm_order,
193193
)
194194

195195
prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1]
196196
plms_timesteps = _timesteps[:-3][::-1]
197197

198-
timesteps = jnp.concatenate([prk_timesteps, plms_timesteps]).astype(jnp.int32)
198+
timesteps = jnp.concatenate([prk_timesteps, plms_timesteps])
199199

200200
# initial running values
201201

0 commit comments

Comments
 (0)