Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def loop_body(step, args):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])

latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
Expand All @@ -189,6 +191,9 @@ def loop_body(step, args):
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

if debug:
# run with python for loop
for i in range(num_inference_steps):
Expand Down
17 changes: 17 additions & 0 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,23 @@ def __init__(
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0

def scale_model_input(
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
"""
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
sample (`jnp.ndarray`): input sample
timestep (`int`, optional): current timestep

Returns:
`jnp.ndarray`: scaled input sample
"""
return sample

def create_state(self):
return DDIMSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
Expand Down
22 changes: 21 additions & 1 deletion src/diffusers/schedulers/scheduling_pndm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def __init__(
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0

def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

Expand Down Expand Up @@ -196,14 +199,31 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
)

return state.replace(
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

double precision values are not available by default injax cf https://github.com/google/jax#current-gotchas

timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
counter=0,
# Reserve space for the state variables
cur_model_output=jnp.zeros(shape),
cur_sample=jnp.zeros(shape),
ets=jnp.zeros((4,) + shape),
)

def scale_model_input(
self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.

Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
sample (`jnp.ndarray`): input sample
timestep (`int`, optional): current timestep

Returns:
`jnp.ndarray`: scaled input sample
"""
return sample

def step(
self,
state: PNDMSchedulerState,
Expand Down