diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 6bdfa5eb5e0d..d06a171159ee 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -222,9 +222,13 @@ def step( t = timestep if key is None: - key = jax.random.PRNGKey(0) + key = jax.random.key(0) - if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: + if ( + len(model_output.shape) > 1 + and model_output.shape[1] == sample.shape[1] * 2 + and self.config.variance_type in ["learned", "learned_range"] + ): model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) else: predicted_variance = None @@ -264,7 +268,7 @@ def step( # 6. Add noise def random_variance(): - split_key = jax.random.split(key, num=1) + split_key = jax.random.split(key, num=1)[0] noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise