diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index f9383bad2671..1470df14cde5 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -477,6 +477,7 @@ def collate_fn(examples): noise_scheduler = FlaxDDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) + noise_scheduler_state = noise_scheduler.create_state() # Initialize our training train_rngs = jax.random.split(rng, jax.local_device_count()) @@ -513,7 +514,7 @@ def compute_loss(params): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps) # Get the text embedding for conditioning if args.train_text_encoder: diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 056ff6bad0f4..b3379226f243 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -417,6 +417,7 @@ def collate_fn(examples): noise_scheduler = FlaxDDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) + noise_scheduler_state = noise_scheduler.create_state() # Initialize our training rng = jax.random.PRNGKey(args.seed) @@ -449,7 +450,7 @@ def compute_loss(params): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder( diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index a9fa9e36931b..320b194a0d38 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -505,6 +505,7 @@ def update_fn(updates, state, params=None): noise_scheduler = FlaxDDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) + noise_scheduler_state = noise_scheduler.create_state() # Initialize our training train_rngs = jax.random.split(rng, jax.local_device_count()) @@ -531,7 +532,7 @@ def compute_loss(params): 0, noise_scheduler.config.num_train_timesteps, ) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps) encoder_hidden_states = state.apply_fn( batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True )[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index e26d71ffda23..eccb4a9b9d87 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -261,7 +261,8 @@ def loop_body(step, args): ) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + latents = latents * params["scheduler"].init_noise_sigma + if 
DEBUG: # run with python for loop for i in range(num_inference_steps): diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index a6f7eace2637..875da07adb59 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -15,7 +15,6 @@ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion -import math from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -26,51 +25,37 @@ from ..utils import deprecate from .scheduling_utils_flax import ( _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + CommonSchedulerState, FlaxSchedulerMixin, FlaxSchedulerOutput, - broadcast_to_shape_from_left, + add_noise_common, ) -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return jnp.array(betas, dtype=jnp.float32) - - @flax.struct.dataclass class DDIMSchedulerState: + common: CommonSchedulerState + final_alpha_cumprod: jnp.ndarray + # setable values + init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray - alphas_cumprod: jnp.ndarray num_inference_steps: Optional[int] = None @classmethod - def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray): - return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod) + def create( + cls, + common: CommonSchedulerState, + final_alpha_cumprod: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) @dataclass @@ -112,12 +97,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. `v-prediction` is not supported for this scheduler. - + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. 
""" _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _deprecated_kwargs = ["predict_epsilon"] + dtype: jnp.dtype + @property def has_state(self): return True @@ -129,43 +117,46 @@ def __init__( beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, **kwargs, ): message = ( "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" - " FlaxDDIMScheduler.from_pretrained(, prediction_type='epsilon')`." + f" {self.__class__.__name__}.from_pretrained(, prediction_type='epsilon')`." ) predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs) if predict_epsilon is not None: self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") - if beta_schedule == "linear": - self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas + self.dtype = dtype - # HACK for now - clean up later (PVP) - self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. 
- self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0]) + final_alpha_cumprod = ( + jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] + ) # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return DDIMSchedulerState.create( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) def scale_model_input( self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None @@ -181,21 +172,6 @@ def scale_model_input( """ return sample - def create_state(self): - return DDIMSchedulerState.create( - num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod - ) - - def _get_variance(self, timestep, prev_timestep, alphas_cumprod): - alpha_prod_t = alphas_cumprod[timestep] - alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod) - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - - return variance - def set_timesteps( self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = () ) -> DDIMSchedulerState: @@ -208,15 +184,27 @@ def set_timesteps( num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - offset = self.config.steps_offset - step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] - timesteps = timesteps + offset + # rounding to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + ) + + def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep): + alpha_prod_t = state.common.alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where( + prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps) + return variance def step( self, @@ -224,6 +212,7 @@ def step( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, + eta: float = 0.0, return_dict: bool = True, ) -> Union[FlaxDDIMSchedulerOutput, Tuple]: """ @@ -259,17 +248,15 @@ def step( # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" - # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function - eta = 0.0 - # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps - alphas_cumprod = state.alphas_cumprod + alphas_cumprod = state.common.alphas_cumprod + final_alpha_cumprod = state.final_alpha_cumprod # 2. 
compute alphas, betas alpha_prod_t = alphas_cumprod[timestep] - alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod) + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod) beta_prod_t = 1 - alpha_prod_t @@ -291,7 +278,7 @@ def step( # 4. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep, alphas_cumprod) + variance = self._get_variance(state, timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf @@ -307,20 +294,12 @@ def step( def add_noise( self, + state: DDIMSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples + return add_noise_common(state.common, original_samples, noise, timesteps) def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 84186546ed9e..7cca0b963c38 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -14,62 +14,36 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim -import math from dataclasses import dataclass from typing import Optional, Tuple, Union import flax +import jax import jax.numpy as jnp -from jax import random -from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ..configuration_utils import ConfigMixin, register_to_config from ..utils import deprecate from .scheduling_utils_flax import ( _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + CommonSchedulerState, FlaxSchedulerMixin, FlaxSchedulerOutput, - broadcast_to_shape_from_left, + add_noise_common, ) -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
- - Returns: - betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return jnp.array(betas, dtype=jnp.float32) - - @flax.struct.dataclass class DDPMSchedulerState: + common: CommonSchedulerState + # setable values + init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None @classmethod - def create(cls, num_train_timesteps: int): - return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): + return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) @dataclass @@ -106,11 +80,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. `v-prediction` is not supported for this scheduler. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. """ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _deprecated_kwargs = ["predict_epsilon"] + dtype: jnp.dtype + @property def has_state(self): return True @@ -126,35 +104,47 @@ def __init__( variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, **kwargs, ): message = ( "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" - " FlaxDDPMScheduler.from_pretrained(, prediction_type='epsilon')`." + f" {self.__class__.__name__}.from_pretrained(, prediction_type='epsilon')`." ) predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs) if predict_epsilon is not None: self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") - if trained_betas is not None: - self.betas = jnp.asarray(trained_betas) - elif beta_schedule == "linear": - self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # standard deviation of the initial noise distribution + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) + + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return DDPMSchedulerState.create( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) - self.alphas = 1.0 - self.betas - self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) - self.one = jnp.array(1.0) + def scale_model_input( + self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + ) -> jnp.ndarray: + """ + Args: + state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance. + sample (`jnp.ndarray`): input sample + timestep (`int`, optional): current timestep - def create_state(self): - return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) + Returns: + `jnp.ndarray`: scaled input sample + """ + return sample def set_timesteps( self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = () @@ -168,20 +158,25 @@ def set_timesteps( num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) - timesteps = jnp.arange( - 0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps - )[::-1] - return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps) - def _get_variance(self, t, predicted_variance=None, variance_type=None): - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # rounding to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + ) + + def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None): + alpha_prod_t = state.common.alphas_cumprod[t] + alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t] if variance_type is None: variance_type = self.config.variance_type @@ -193,15 +188,15 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): elif variance_type == "fixed_small_log": variance = jnp.log(jnp.clip(variance, a_min=1e-20)) elif variance_type == "fixed_large": - variance = self.betas[t] + variance = state.common.betas[t] elif
variance_type == "fixed_large_log": # Glide max_log - variance = jnp.log(self.betas[t]) + variance = jnp.log(state.common.betas[t]) elif variance_type == "learned": return predicted_variance elif variance_type == "learned_range": min_log = variance - max_log = self.betas[t] + max_log = state.common.betas[t] frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log @@ -213,9 +208,8 @@ def step( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - key: random.KeyArray, + key: jax.random.KeyArray = jax.random.PRNGKey(0), return_dict: bool = True, - **kwargs, ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion @@ -227,7 +221,7 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - key (`random.KeyArray`): a PRNG key. + key (`jax.random.KeyArray`): a PRNG key. return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class Returns: @@ -235,16 +229,6 @@ def step( `tuple`. When returning a tuple, the first element is the sample tensor. """ - message = ( - "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" - " FlaxDDPMScheduler.from_pretrained(, prediction_type='epsilon')`." - ) - predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs) - if predict_epsilon is not None: - new_config = dict(self.config) - new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" - self._internal_dict = FrozenDict(new_config) - t = timestep if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: @@ -253,8 +237,8 @@ def step( predicted_variance = None # 1. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + alpha_prod_t = state.common.alphas_cumprod[t] + alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -264,6 +248,8 @@ def step( pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.config.prediction_type == "sample": pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " @@ -276,19 +262,20 @@ def step( # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t - current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t + current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous sample µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample # 6. 
Add noise - variance = 0 - if t > 0: - key = random.split(key, num=1) - noise = random.normal(key=key, shape=model_output.shape) - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + def random_variance(): + split_key = jax.random.split(key, num=1) + noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) + return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise + + variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype)) pred_prev_sample = pred_prev_sample + variance @@ -299,20 +286,12 @@ def step( def add_noise( self, + state: DDPMSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples + return add_noise_common(state.common, original_samples, noise, timesteps) def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 6bf389fa6797..0aa121b59dec 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -14,7 +14,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver -import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -26,57 +25,49 @@ from ..utils import deprecate from .scheduling_utils_flax import ( _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + CommonSchedulerState, FlaxSchedulerMixin, FlaxSchedulerOutput, - broadcast_to_shape_from_left, + add_noise_common, ) -def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
- - Returns: - betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return jnp.array(betas, dtype=jnp.float32) - - @flax.struct.dataclass class DPMSolverMultistepSchedulerState: + common: CommonSchedulerState + alpha_t: jnp.ndarray + sigma_t: jnp.ndarray + lambda_t: jnp.ndarray + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray num_inference_steps: Optional[int] = None - timesteps: Optional[jnp.ndarray] = None # running values model_outputs: Optional[jnp.ndarray] = None - lower_order_nums: Optional[int] = None - step_index: Optional[int] = None - prev_timestep: Optional[int] = None + lower_order_nums: Optional[jnp.int32] = None + prev_timestep: Optional[jnp.int32] = None cur_sample: Optional[jnp.ndarray] = None @classmethod - def create(cls, num_train_timesteps: int): - return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + def create( + cls, + common: CommonSchedulerState, + alpha_t: jnp.ndarray, + sigma_t: jnp.ndarray, + lambda_t: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) @dataclass @@ -145,12 +136,15 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. - + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. """ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _deprecated_kwargs = ["predict_epsilon"] + dtype: jnp.dtype + @property def has_state(self): return True @@ -171,47 +165,47 @@ def __init__( algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, + dtype: jnp.dtype = jnp.float32, **kwargs, ): message = ( "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" - " FlaxDPMSolverMultistepScheduler.from_pretrained(, prediction_type='epsilon')`." + f" {self.__class__.__name__}.from_pretrained(, prediction_type='epsilon')`." ) predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs) if predict_epsilon is not None: self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") - if trained_betas is not None: - self.betas = jnp.asarray(trained_betas) - elif beta_schedule == "linear": - self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + self.dtype = dtype + + def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) - self.alphas = 1.0 - self.betas - self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) # Currently we only support VP-type noise schedule - self.alpha_t = jnp.sqrt(self.alphas_cumprod) - self.sigma_t = jnp.sqrt(1 - self.alphas_cumprod) - self.lambda_t = jnp.log(self.alpha_t) - jnp.log(self.sigma_t) + alpha_t = jnp.sqrt(common.alphas_cumprod) + sigma_t = jnp.sqrt(1 - common.alphas_cumprod) + lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) + + # settings for DPM-Solver + if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]: + raise NotImplementedError(f"{self.config.algorithm_type} does is not implemented for {self.__class__}") + if self.config.solver_type not in ["midpoint", "heun"]: + raise NotImplementedError(f"{self.config.solver_type} does is not implemented for {self.__class__}") # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) - # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++"]: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") - if solver_type not in ["midpoint", "heun"]: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] - def create_state(self): - return DPMSolverMultistepSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) + return DPMSolverMultistepSchedulerState.create( + common=common, + alpha_t=alpha_t, + sigma_t=sigma_t, + lambda_t=lambda_t, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) def set_timesteps( self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple @@ -227,24 +221,32 @@ def set_timesteps( shape (`Tuple`): the shape of the samples to be generated. """ + timesteps = ( jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .astype(jnp.int32) ) + # initial running values + + model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) + lower_order_nums = jnp.int32(0) + prev_timestep = jnp.int32(-1) + cur_sample = jnp.zeros(shape, dtype=self.dtype) + return state.replace( num_inference_steps=num_inference_steps, timesteps=timesteps, - model_outputs=jnp.zeros((self.config.solver_order,) + shape), - lower_order_nums=0, - step_index=0, - prev_timestep=-1, - cur_sample=jnp.zeros(shape), + model_outputs=model_outputs, + lower_order_nums=lower_order_nums, + prev_timestep=prev_timestep, + cur_sample=cur_sample, ) def convert_model_output( self, + state: DPMSolverMultistepSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, @@ -271,12 +273,12 @@ def convert_model_output( # DPM-Solver++ needs to solve an integral of the data prediction model. 
if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( @@ -299,11 +301,11 @@ def convert_model_output( if self.config.prediction_type == "epsilon": return model_output elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] epsilon = alpha_t * model_output + sigma_t * sample return epsilon else: @@ -313,7 +315,12 @@ def convert_model_output( ) def dpm_solver_first_order_update( - self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray + self, + state: DPMSolverMultistepSchedulerState, + model_output: jnp.ndarray, + timestep: int, + prev_timestep: int, + sample: jnp.ndarray, ) -> jnp.ndarray: """ One step for the first-order DPM-Solver (equivalent to DDIM). @@ -332,9 +339,9 @@ def dpm_solver_first_order_update( """ t, s0 = prev_timestep, timestep m0 = model_output - lambda_t, lambda_s = self.lambda_t[t], self.lambda_t[s0] - alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0] + lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0] + alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0] h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 @@ -344,6 +351,7 @@ def dpm_solver_first_order_update( def multistep_dpm_solver_second_order_update( self, + state: DPMSolverMultistepSchedulerState, model_output_list: jnp.ndarray, timestep_list: List[int], prev_timestep: int, @@ -365,9 +373,9 @@ def multistep_dpm_solver_second_order_update( """ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] + alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) @@ -403,6 +411,7 @@ def multistep_dpm_solver_second_order_update( def multistep_dpm_solver_third_order_update( self, + state: DPMSolverMultistepSchedulerState, model_output_list: jnp.ndarray, timestep_list: List[int], prev_timestep: int, @@ -425,13 +434,13 @@ def multistep_dpm_solver_third_order_update( t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], 
timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], + state.lambda_t[t], + state.lambda_t[s0], + state.lambda_t[s1], + state.lambda_t[s2], ) - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] + sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 @@ -482,14 +491,17 @@ def step( `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ - prev_timestep = jax.lax.cond( - state.step_index == len(state.timesteps) - 1, - lambda _: 0, - lambda _: state.timesteps[state.step_index + 1], - (), - ) + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) - model_output = self.convert_model_output(model_output, timestep, sample) + (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + + prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1]) + + model_output = self.convert_model_output(state, model_output, timestep, sample) model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0) model_outputs_new = model_outputs_new.at[-1].set(model_output) @@ -501,16 +513,18 @@ def step( def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: return self.dpm_solver_first_order_update( + state, state.model_outputs[-1], - state.timesteps[state.step_index], + state.timesteps[step_index], state.prev_timestep, state.cur_sample, ) def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]]) + timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]]) return self.multistep_dpm_solver_second_order_update( + state, state.model_outputs, timestep_list, state.prev_timestep, @@ -520,65 +534,67 @@ def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: timestep_list = jnp.array( [ - state.timesteps[state.step_index - 2], - state.timesteps[state.step_index - 1], - state.timesteps[state.step_index], + state.timesteps[step_index - 2], + state.timesteps[step_index - 1], + state.timesteps[step_index], ] ) return self.multistep_dpm_solver_third_order_update( + state, state.model_outputs, timestep_list, state.prev_timestep, state.cur_sample, ) + step_2_output = step_2(state) + step_3_output = step_3(state) + if self.config.solver_order == 2: - return step_2(state) + return step_2_output elif self.config.lower_order_final and len(state.timesteps) < 15: - return jax.lax.cond( + return jax.lax.select( state.lower_order_nums < 2, - step_2, - lambda state: jax.lax.cond( - state.step_index == len(state.timesteps) - 2, - step_2, - step_3, - state, + step_2_output, + jax.lax.select( + step_index == len(state.timesteps) - 2, + step_2_output, + step_3_output, ), - state, ) else: - return jax.lax.cond( + return jax.lax.select( state.lower_order_nums < 2, - step_2, - step_3, - state, + 
step_2_output, + step_3_output, ) + step_1_output = step_1(state) + step_23_output = step_23(state) + if self.config.solver_order == 1: - prev_sample = step_1(state) + prev_sample = step_1_output + elif self.config.lower_order_final and len(state.timesteps) < 15: - prev_sample = jax.lax.cond( + prev_sample = jax.lax.select( state.lower_order_nums < 1, - step_1, - lambda state: jax.lax.cond( - state.step_index == len(state.timesteps) - 1, - step_1, - step_23, - state, + step_1_output, + jax.lax.select( + step_index == len(state.timesteps) - 1, + step_1_output, + step_23_output, ), - state, ) + else: - prev_sample = jax.lax.cond( + prev_sample = jax.lax.select( state.lower_order_nums < 1, - step_1, - step_23, - state, + step_1_output, + step_23_output, ) state = state.replace( lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order), - step_index=(state.step_index + 1), ) if not return_dict: @@ -606,20 +622,12 @@ def scale_model_input( def add_noise( self, + state: DPMSolverMultistepSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples + return add_noise_common(state.common, original_samples, noise, timesteps) def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index c4e612c3cc84..08d41d006ca4 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -233,5 +233,5 @@ def step_correct( return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) - def add_noise(self, original_samples, noise, timesteps): + def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps): raise NotImplementedError() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 5da43be2ada3..fde18f2653d6 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -22,6 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + CommonSchedulerState, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, @@ -30,15 +31,22 @@ @flax.struct.dataclass class LMSDiscreteSchedulerState: + common: CommonSchedulerState + # setable values + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray + sigmas: jnp.ndarray num_inference_steps: Optional[int] = None - timesteps: Optional[jnp.ndarray] = None - sigmas: Optional[jnp.ndarray] = None - derivatives: jnp.ndarray = jnp.array([]) + + # running values + derivatives: Optional[jnp.ndarray] = None @classmethod - def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray): - return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], 
sigmas=sigmas) + def create( + cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + ): + return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) @dataclass @@ -66,10 +74,18 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. """ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + dtype: jnp.dtype + @property def has_state(self): return True @@ -82,24 +98,26 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, ): - if trained_betas is not None: - self.betas = jnp.asarray(trained_betas) - elif beta_schedule == "linear": - self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + self.dtype = dtype - self.alphas = 1.0 - self.betas - self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) - def create_state(self): - self.state = LMSDiscreteSchedulerState.create( - num_train_timesteps=self.config.num_train_timesteps, - sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5, + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + init_noise_sigma = sigmas.max() + + return LMSDiscreteSchedulerState.create( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, ) def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: @@ -118,11 +136,13 @@ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarra `jnp.ndarray`: scaled input sample """ (step_index,) = jnp.where(state.timesteps == timestep, size=1) + step_index = step_index[0] + sigma = state.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) return sample - def get_lms_coefficient(self, state, order, t, current_order): + def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order): """ Compute a linear multistep coefficient. @@ -156,20 +176,28 @@ def set_timesteps( num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. 
""" - timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32) - low_idx = jnp.floor(timesteps).astype(int) - high_idx = jnp.ceil(timesteps).astype(int) + timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + + low_idx = jnp.floor(timesteps).astype(jnp.int32) + high_idx = jnp.ceil(timesteps).astype(jnp.int32) + frac = jnp.mod(timesteps, 1.0) - sigmas = jnp.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + + sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5 sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32) + sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) + + timesteps = timesteps.astype(jnp.int32) + + # initial running values + derivatives = jnp.zeros((0,) + shape, dtype=self.dtype) return state.replace( - num_inference_steps=num_inference_steps, - timesteps=timesteps.astype(int), - derivatives=jnp.array([]), + timesteps=timesteps, sigmas=sigmas, + num_inference_steps=num_inference_steps, + derivatives=derivatives, ) def step( @@ -199,10 +227,23 @@ def step( `tuple`. When returning a tuple, the first element is the sample tensor. """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + sigma = state.sigmas[timestep] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - pred_original_sample = sample - sigma * model_output + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) # 2. Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 298e62de20d1..25c0db934617 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -14,7 +14,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim -import math from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -25,59 +24,45 @@ from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + CommonSchedulerState, FlaxSchedulerMixin, FlaxSchedulerOutput, - broadcast_to_shape_from_left, + add_noise_common, ) -def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
- - Returns: - betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return jnp.array(betas, dtype=jnp.float32) - - @flax.struct.dataclass class PNDMSchedulerState: + common: CommonSchedulerState + final_alpha_cumprod: jnp.ndarray + # setable values - _timesteps: jnp.ndarray + init_noise_sigma: jnp.ndarray + timesteps: jnp.ndarray num_inference_steps: Optional[int] = None prk_timesteps: Optional[jnp.ndarray] = None plms_timesteps: Optional[jnp.ndarray] = None - timesteps: Optional[jnp.ndarray] = None # running values cur_model_output: Optional[jnp.ndarray] = None - counter: int = 0 + counter: Optional[jnp.int32] = None cur_sample: Optional[jnp.ndarray] = None - ets: jnp.ndarray = jnp.array([]) + ets: Optional[jnp.ndarray] = None @classmethod - def create(cls, num_train_timesteps: int): - return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + def create( + cls, + common: CommonSchedulerState, + final_alpha_cumprod: jnp.ndarray, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): + return cls( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) @dataclass @@ -117,10 +102,19 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + the `dtype` used for params and computation. """ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + dtype: jnp.dtype + pndm_order: int + @property def has_state(self): return True @@ -136,35 +130,39 @@ def __init__( skip_prk_steps: bool = False, set_alpha_to_one: bool = False, steps_offset: int = 0, + prediction_type: str = "epsilon", + dtype: jnp.dtype = jnp.float32, ): - if trained_betas is not None: - self.betas = jnp.asarray(trained_betas) - elif beta_schedule == "linear": - self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) - - self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.dtype = dtype # For now we only support F-PNDM, i.e. 
the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. self.pndm_order = 4 + def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState: + if common is None: + common = CommonSchedulerState.create(self) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + final_alpha_cumprod = ( + jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] + ) + # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 + init_noise_sigma = jnp.array(1.0, dtype=self.dtype) - def create_state(self): - return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) + timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] + + return PNDMSchedulerState.create( + common=common, + final_alpha_cumprod=final_alpha_cumprod, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + ) def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: """ @@ -178,42 +176,47 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha shape (`Tuple`): the shape of the samples to be generated. """ - offset = self.config.steps_offset step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # rounding to avoid issues when num_inference_step is power of 3 - _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset - - state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps) + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to # produce better results. 
 
     def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
         """
@@ -178,42 +176,47 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
             shape (`Tuple`):
                 the shape of the samples to be generated.
         """
-        offset = self.config.steps_offset
 
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # rounding to avoid issues when num_inference_step is power of 3
-        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset
-
-        state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)
+        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset
 
         if self.config.skip_prk_steps:
             # for some models like stable diffusion the prk steps can/should be skipped to
             # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
             # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
-            state = state.replace(
-                prk_timesteps=jnp.array([]),
-                plms_timesteps=jnp.concatenate(
-                    [state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]]
-                )[::-1],
-            )
+
+            prk_timesteps = jnp.array([], dtype=jnp.int32)
+            plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1]
+
         else:
-            prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile(
-                jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+            prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile(
+                jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32),
+                self.pndm_order,
             )
-            state = state.replace(
-                prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1],
-                plms_timesteps=state._timesteps[:-3][::-1],
-            )
+
+            prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1]
+            plms_timesteps = _timesteps[:-3][::-1]
+
+        timesteps = jnp.concatenate([prk_timesteps, plms_timesteps])
+
+        # initial running values
+
+        cur_model_output = jnp.zeros(shape, dtype=self.dtype)
+        counter = jnp.int32(0)
+        cur_sample = jnp.zeros(shape, dtype=self.dtype)
+        ets = jnp.zeros((4,) + shape, dtype=self.dtype)
 
         return state.replace(
-            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
-            counter=0,
-            # Reserve space for the state variables
-            cur_model_output=jnp.zeros(shape),
-            cur_sample=jnp.zeros(shape),
-            ets=jnp.zeros((4,) + shape),
+            timesteps=timesteps,
+            num_inference_steps=num_inference_steps,
+            prk_timesteps=prk_timesteps,
+            plms_timesteps=plms_timesteps,
+            cur_model_output=cur_model_output,
+            counter=counter,
+            cur_sample=cur_sample,
+            ets=ets,
        )
 
     def scale_model_input(
@@ -260,19 +263,27 @@ def step(
             `tuple`. When returning a tuple, the first element is the sample tensor.
 
         """
-        if self.config.skip_prk_steps:
-            prev_sample, state = self.step_plms(
-                state=state, model_output=model_output, timestep=timestep, sample=sample
+
+        if state.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
+
+        if self.config.skip_prk_steps:
+            prev_sample, state = self.step_plms(state, model_output, timestep, sample)
         else:
-            prev_sample, state = jax.lax.switch(
-                jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
-                (self.step_prk, self.step_plms),
-                # Args to either branch
-                state,
-                model_output,
-                timestep,
-                sample,
+            prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample)
+            plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample)
+
+            cond = state.counter < len(state.prk_timesteps)
+
+            prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample)
+
+            state = state.replace(
+                cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output),
+                ets=jax.lax.select(cond, prk_state.ets, plms_state.ets),
+                cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample),
+                counter=jax.lax.select(cond, prk_state.counter, plms_state.counter),
             )
 
         if not return_dict:
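Because `step` now branches with `jax.lax.select` on arrays instead of Python control flow over closures, the whole denoising loop can be traced and jitted with the state threaded through each iteration. A sketch of such a loop (illustrative; `unet_apply` is a hypothetical, jittable model call, not part of this diff):

import jax

def sample(scheduler, state, unet_apply, latents, num_inference_steps):
    # Prepare per-run timesteps and running buffers in the state.
    state = scheduler.set_timesteps(state, num_inference_steps, shape=latents.shape)
    latents = latents * state.init_noise_sigma

    def body(step, args):
        latents, state = args
        t = state.timesteps[step]
        noise_pred = unet_apply(latents, t)  # hypothetical: predicts the noise residual
        # return_dict=False yields (prev_sample, new_state)
        latents, state = scheduler.step(state, noise_pred, t, latents, return_dict=False)
        return latents, state

    # All branching inside step() is data-dependent, so fori_loop works under jit.
    latents, state = jax.lax.fori_loop(0, len(state.timesteps), body, (latents, state))
    return latents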
""" + if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" @@ -315,37 +327,34 @@ def step_prk( prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] - def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): - return ( - state.replace( - cur_model_output=state.cur_model_output + 1 / 6 * model_output, - ets=state.ets.at[ets_at].set(model_output), - cur_sample=sample, - ), - model_output, - ) - - def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): - return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output - - def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): - return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output - - def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int): - model_output = state.cur_model_output + 1 / 6 * model_output - return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output + model_output = jax.lax.select( + (state.counter % 4) != 3, + model_output, # remainder 0, 1, 2 + state.cur_model_output + 1 / 6 * model_output, # remainder 3 + ) - state, model_output = jax.lax.switch( - state.counter % 4, - (remainder_0, remainder_1, remainder_2, remainder_3), - # Args to either branch - state, - model_output, - state.counter // 4, + state = state.replace( + cur_model_output=jax.lax.select_n( + state.counter % 4, + state.cur_model_output + 1 / 6 * model_output, # remainder 0 + state.cur_model_output + 1 / 3 * model_output, # remainder 1 + state.cur_model_output + 1 / 3 * model_output, # remainder 2 + jnp.zeros_like(state.cur_model_output), # remainder 3 + ), + ets=jax.lax.select( + (state.counter % 4) == 0, + state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0 + state.ets, # remainder 1, 2, 3 + ), + cur_sample=jax.lax.select( + (state.counter % 4) == 0, + sample, # remainder 0 + state.cur_sample, # remainder 1, 2, 3 + ), ) cur_sample = state.cur_sample - prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) return (prev_sample, state) @@ -374,18 +383,13 @@ def step_plms( `tuple`. When returning a tuple, the first element is the sample tensor. """ + if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - if not self.config.skip_prk_steps and len(state.ets) < 3: - raise ValueError( - f"{self.__class__} can only be run AFTER scheduler has been run " - "in 'prk' mode for at least 12 iterations " - "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " - "for more information." 
@@ -374,18 +383,13 @@ def step_plms(
             `tuple`. When returning a tuple, the first element is the sample tensor.
 
         """
+
         if state.num_inference_steps is None:
             raise ValueError(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
 
-        if not self.config.skip_prk_steps and len(state.ets) < 3:
-            raise ValueError(
-                f"{self.__class__} can only be run AFTER scheduler has been run "
-                "in 'prk' mode for at least 12 iterations "
-                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
-                "for more information."
-            )
+        # NOTE: There is no way to check in the jitted runtime if the prk mode was ran before
 
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
         prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
@@ -417,64 +421,39 @@ def step_plms(
         # else:
         #   model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
 
-        def counter_0(state: PNDMSchedulerState):
-            ets = state.ets.at[0].set(model_output)
-            return state.replace(
-                ets=ets,
-                cur_sample=sample,
-                cur_model_output=jnp.array(model_output, dtype=jnp.float32),
-            )
-
-        def counter_1(state: PNDMSchedulerState):
-            return state.replace(
-                cur_model_output=(model_output + state.ets[0]) / 2,
-            )
-
-        def counter_2(state: PNDMSchedulerState):
-            ets = state.ets.at[1].set(model_output)
-            return state.replace(
-                ets=ets,
-                cur_model_output=(3 * ets[1] - ets[0]) / 2,
-                cur_sample=sample,
-            )
-
-        def counter_3(state: PNDMSchedulerState):
-            ets = state.ets.at[2].set(model_output)
-            return state.replace(
-                ets=ets,
-                cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
-                cur_sample=sample,
-            )
-
-        def counter_other(state: PNDMSchedulerState):
-            ets = state.ets.at[3].set(model_output)
-            next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
-
-            ets = ets.at[0].set(ets[1])
-            ets = ets.at[1].set(ets[2])
-            ets = ets.at[2].set(ets[3])
-
-            return state.replace(
-                ets=ets,
-                cur_model_output=next_model_output,
-                cur_sample=sample,
-            )
+        state = state.replace(
+            ets=jax.lax.select(
+                state.counter != 1,
+                state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output),  # counter != 1
+                state.ets,  # counter 1
+            ),
+            cur_sample=jax.lax.select(
+                state.counter != 1,
+                sample,  # counter != 1
+                state.cur_sample,  # counter 1
+            ),
+        )
 
-        counter = jnp.clip(state.counter, 0, 4)
-        state = jax.lax.switch(
-            counter,
-            [counter_0, counter_1, counter_2, counter_3, counter_other],
-            state,
+        state = state.replace(
+            cur_model_output=jax.lax.select_n(
+                jnp.clip(state.counter, 0, 4),
+                model_output,  # counter 0
+                (model_output + state.ets[-1]) / 2,  # counter 1
+                (3 * state.ets[-1] - state.ets[-2]) / 2,  # counter 2
+                (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12,  # counter 3
+                (1 / 24)
+                * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]),  # counter >= 4
+            ),
        )
 
         sample = state.cur_sample
         model_output = state.cur_model_output
-        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+        prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
         return (prev_sample, state)
 
-    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
+    def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
         # Note that x_t needs to be added to both sides of the equation
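The per-counter Python branches that previously went through `jax.lax.switch` over closures are replaced by `jax.lax.select_n`, which evaluates every candidate array and picks one by integer index, keeping the state a plain pytree under tracing. A minimal illustration of that pattern with toy values (not scheduler code):

import jax
import jax.numpy as jnp

counter = jnp.int32(2)
x = jnp.arange(4.0)

# select_n computes all branches and returns the one at index `counter`;
# clipping mirrors how counters >= 4 all map to the highest-order formula.
out = jax.lax.select_n(
    jnp.clip(counter, 0, 3),
    x * 1.0,  # index 0
    x * 2.0,  # index 1
    x * 3.0,  # index 2
    x * 4.0,  # index 3
)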
@@ -487,11 +466,20 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # sample -> x_t
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+        alpha_prod_t = state.common.alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(
+            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
+        )
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
+        if self.config.prediction_type == "v_prediction":
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        elif self.config.prediction_type != "epsilon":
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
+            )
+
         # corresponds to (α_(t−δ) - α_t) divided by
         # denominator of x_t in formula (9) and plus 1
         # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
@@ -512,20 +500,12 @@ def _get_prev_sample(sample, timestep, prev_timestep, model_output):
     def add_noise(
         self,
+        state: PNDMSchedulerState,
         original_samples: jnp.ndarray,
         noise: jnp.ndarray,
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
-
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
-
-        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-        return noisy_samples
+        return add_noise_common(state.common, original_samples, noise, timesteps)
 
     def __len__(self):
         return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index 5dc28c25d9d6..7ad8ad7286db 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
+import math
 import os
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
+import flax
 import jax.numpy as jnp
 
 from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
@@ -50,6 +52,7 @@ class FlaxSchedulerMixin:
     """
 
     config_name = SCHEDULER_CONFIG_NAME
+    ignore_for_config = ["dtype"]
     _compatibles = []
     has_compatibles = True
 
@@ -167,3 +170,90 @@ def _get_compatibles(cls):
 def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
     assert len(shape) >= x.ndim
     return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray:
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+
+    Returns:
+        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return jnp.array(betas, dtype=dtype)
+
+
+@flax.struct.dataclass
+class CommonSchedulerState:
+    alphas: jnp.ndarray
+    betas: jnp.ndarray
+    alphas_cumprod: jnp.ndarray
+
+    @classmethod
+    def create(cls, scheduler):
+        config = scheduler.config
+
+        if config.trained_betas is not None:
+            betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype)
+        elif config.beta_schedule == "linear":
+            betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype)
+        elif config.beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            betas = (
+                jnp.linspace(
+                    config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype
+                )
+                ** 2
+            )
+        elif config.beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype)
+        else:
+            raise NotImplementedError(
+                f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
+            )
+
+        alphas = 1.0 - betas
+
+        alphas_cumprod = jnp.cumprod(alphas, axis=0)
+
+        return cls(
+            alphas=alphas,
+            betas=betas,
+            alphas_cumprod=alphas_cumprod,
+        )
+
+
+def add_noise_common(
+    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
+):
+    alphas_cumprod = state.alphas_cumprod
+
+    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+    sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
+
+    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+    sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
+
+    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+    return noisy_samples
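The beta/alpha schedule construction that each Flax scheduler used to duplicate in `__init__` now lives in `CommonSchedulerState.create`, and `add_noise_common` implements the shared forward-diffusion step sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise. A hedged sketch of how a scheduler's `add_noise` resolves through these helpers (shapes and values are illustrative only):

import jax
import jax.numpy as jnp
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler(beta_schedule="scaled_linear", skip_prk_steps=True)
state = scheduler.create_state()  # state.common holds betas, alphas, alphas_cumprod

latents = jnp.zeros((4, 4, 64, 64))
noise = jax.random.normal(jax.random.PRNGKey(0), latents.shape)
timesteps = jnp.array([10, 200, 500, 999])

# delegates to add_noise_common(state.common, ...), broadcasting the per-sample
# alphas_cumprod[timesteps] factors from the left over the latent shape
noisy = scheduler.add_noise(state, latents, noise, timesteps)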
diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py
index 40227c9ac44d..1da75f051d50 100644
--- a/tests/test_scheduler_flax.py
+++ b/tests/test_scheduler_flax.py
@@ -296,10 +296,11 @@ def test_variance(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
+        state = scheduler.create_state()
 
-        assert jnp.sum(jnp.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
-        assert jnp.sum(jnp.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
-        assert jnp.sum(jnp.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5
 
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
@@ -577,12 +578,12 @@ def test_variance(self):
         scheduler = scheduler_class(**scheduler_config)
         state = scheduler.create_state()
 
-        assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
-        assert jnp.sum(jnp.abs(scheduler._get_variance(420, 400, state.alphas_cumprod) - 0.14771)) < 1e-5
-        assert jnp.sum(jnp.abs(scheduler._get_variance(980, 960, state.alphas_cumprod) - 0.32460)) < 1e-5
-        assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
-        assert jnp.sum(jnp.abs(scheduler._get_variance(487, 486, state.alphas_cumprod) - 0.00979)) < 1e-5
-        assert jnp.sum(jnp.abs(scheduler._get_variance(999, 998, state.alphas_cumprod) - 0.02)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
+        assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5
 
     def test_full_loop_no_noise(self):
         sample = self.full_loop()