Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,13 @@ def step(
t = timestep

if key is None:
key = jax.random.PRNGKey(0)
key = jax.random.key(0)

if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
if (
len(model_output.shape) > 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we adding this len() > 1 check here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because input could have shape like x.shape == (10,). Accessing index 1 in the next line is an error in that case.

and model_output.shape[1] == sample.shape[1] * 2
and self.config.variance_type in ["learned", "learned_range"]
):
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
else:
predicted_variance = None
Expand Down Expand Up @@ -264,7 +268,7 @@ def step(

# 6. Add noise
def random_variance():
split_key = jax.random.split(key, num=1)
split_key = jax.random.split(key, num=1)[0]
noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise

Expand Down