diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 3ca0fd913e93..07767189e91b 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -504,7 +504,9 @@ def main(): noise = torch.randn(latents.shape).to(latents.device) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long() + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index fe4e9b0d2716..f6affe8a1400 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -130,7 +130,7 @@ def transforms(examples): bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( - 0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device ).long() # Add noise to the clean images according to the noise magnitude at each timestep diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3804d2e53f5f..a45cb2af5ae1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -64,6 +64,13 @@ if is_flax_available(): from .modeling_flax_utils import FlaxModelMixin - from .schedulers import FlaxPNDMScheduler + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxScoreSdeVeScheduler, + ) else: from .utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 4f2d25dfb168..bc79a52a0924 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -386,7 +386,7 @@ def from_pretrained( raise ValueError from e except (UnicodeDecodeError, ValueError): raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. 
") - # make sure all arrays are stored as jnp.arrays + # make sure all arrays are stored as jnp.ndarray # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: # https://github.com/google/flax/issues/1261 state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 604e2b54cc17..6ff8e03057a8 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -80,7 +80,7 @@ def __call__( sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) # correction step - for _ in range(self.scheduler.correct_steps): + for _ in range(self.scheduler.config.correct_steps): model_output = self.unet(sample, sigma_t).sample sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index c53101a9bb4a..495f30d9fabd 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -28,7 +28,12 @@ from ..utils.dummy_pt_objects import * # noqa F403 if is_flax_available(): + from .scheduling_ddim_flax import FlaxDDIMScheduler + from .scheduling_ddpm_flax import FlaxDDPMScheduler + from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler + from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler + from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler else: from ..utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 36717a8cfec5..85ba67523172 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -113,7 +113,7 @@ def __init__( # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this paratemer simply to one or + # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] @@ -195,7 +195,7 @@ def step( # - pred_original_sample -> f_theta(x_t, t) or x_0 # - std_dev_t -> sigma_t # - eta -> η - # - pred_sample_direction -> "direction pointingc to x_t" + # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py new file mode 100644 index 000000000000..dd3c2ac85d2c --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -0,0 +1,274 @@ +# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=jnp.float32) + + +@flax.struct.dataclass +class DDIMSchedulerState: + # setable values + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create(cls, num_train_timesteps: int): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + + +@dataclass +class FlaxSchedulerOutput(SchedulerOutput): + state: DDIMSchedulerState + + +class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 
+ clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one (`bool`, default `True`): + if alpha for final step is 1 or the final alpha of the "non-previous" one. + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + ): + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def set_timesteps( + self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0 + ) -> DDIMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DDIMSchedulerState`): + the `FlaxDDIMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + offset (`int`): + optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. + """ + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + timesteps = timesteps + offset + + return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps) + + def step( + self, + state: DDIMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: random.KeyArray, + eta: float = 0.0, + use_clipped_model_output: bool = False, + return_dict: bool = True, + ) -> Union[FlaxSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + key (`random.KeyArray`): a PRNG key. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = jnp.clip(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. 
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + key = random.split(key, num=1) + noise = random.normal(key=key, shape=model_output.shape) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 43a73b031691..fac75bc43eff 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -148,7 +148,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): if variance_type is None: variance_type = self.config.variance_type - # hacks - were probs added for training stability + # hacks - were probably added for training stability if variance_type == "fixed_small": variance = self.clip(variance, min_value=1e-20) # for rl-diffuser https://arxiv.org/abs/2205.09991 @@ -187,7 +187,6 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): current instance of sample being created by diffusion process. - eta (`float`): weight of noise for added noise in diffusion step. predict_epsilon (`bool`): optional flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py new file mode 100644 index 000000000000..f686a2a32234 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -0,0 +1,277 @@ +# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=jnp.float32) + + +@flax.struct.dataclass +class DDPMSchedulerState: + # setable values + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create(cls, num_train_timesteps: int): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + + +@dataclass +class FlaxSchedulerOutput(SchedulerOutput): + state: DDPMSchedulerState + + +class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. 
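+
+    A minimal usage sketch (illustrative only): the names `model_output`, `t` and `sample` in the commented
+    step call below are placeholders for values produced inside a denoising loop, not part of this class.
+
+        >>> from jax import random
+        >>> from diffusers import FlaxDDPMScheduler
+
+        >>> scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)
+        >>> state = scheduler.set_timesteps(scheduler.state, num_inference_steps=50)
+        >>> state.timesteps.shape
+        (50,)
+        >>> key = random.PRNGKey(0)
+        >>> # a single denoising step then takes the state, the model output, the current timestep,
+        >>> # the running sample and the PRNG key (schematic):
+        >>> # sample, state = scheduler.step(state, model_output, t, sample, key, return_dict=False)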
+ + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + ): + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + elif beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + self.one = jnp.array(1.0) + + self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps) + + self.variance_type = variance_type + + def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DDIMSchedulerState`): + the `FlaxDDPMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + timesteps = jnp.arange( + 0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps + )[::-1] + return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps) + + def _get_variance(self, t, predicted_variance=None, variance_type=None): + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = jnp.clip(variance, a_min=1e-20) + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = jnp.log(jnp.clip(variance, a_min=1e-20)) + elif variance_type == "fixed_large": + variance = self.betas[t] + elif variance_type == "fixed_large_log": + # Glide max_log + variance = jnp.log(self.betas[t]) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = variance + max_log = self.betas[t] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + state: DDPMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: random.KeyArray, + predict_epsilon: bool = True, + return_dict: bool = True, + ) -> Union[FlaxSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + key (`random.KeyArray`): a PRNG key. + predict_epsilon (`bool`): + optional flag to use when model predicts the samples directly instead of the noise, epsilon. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if predict_epsilon: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + else: + pred_original_sample = model_output + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = jnp.clip(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t + current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. 
Add noise + variance = 0 + if t > 0: + key = random.split(key, num=1) + noise = random.normal(key=key, shape=model_output.shape) + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state) + + def add_noise( + self, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 4b4c90d07c5d..9898bb7eda7e 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -105,7 +105,10 @@ def set_timesteps(self, num_inference_steps: int): self.num_inference_steps = num_inference_steps self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() self.schedule = [ - (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1))) + ( + self.config.sigma_max + * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) + ) for i in self.timesteps ] self.schedule = np.array(self.schedule, dtype=np.float32) @@ -121,13 +124,13 @@ def add_noise_to_input( TODO Args: """ - if self.s_min <= sigma <= self.s_max: - gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1) + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) else: gamma = 0 # sample eps ~ N(0, S_noise^2 * I) - eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) + eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) sigma_hat = sigma + gamma * sigma sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py new file mode 100644 index 000000000000..fe71b3fd0eab --- /dev/null +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -0,0 +1,228 @@ +# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@flax.struct.dataclass +class KarrasVeSchedulerState: + # setable values + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + schedule: Optional[jnp.ndarray] = None # sigma(t_i) + + @classmethod + def create(cls): + return cls() + + +@dataclass +class FlaxKarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Derivate of predicted original image sample (x_0). + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + """ + + prev_sample: jnp.ndarray + derivative: jnp.ndarray + state: KarrasVeSchedulerState + + +class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + """ + + @register_to_config + def __init__( + self, + sigma_min: float = 0.02, + sigma_max: float = 100, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + ): + self.state = KarrasVeSchedulerState.create() + + def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState: + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. 
+ + Args: + state (`KarrasVeSchedulerState`): + the `FlaxKarrasVeScheduler` state data class. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + """ + timesteps = jnp.arange(0, num_inference_steps)[::-1].copy() + schedule = [ + ( + self.config.sigma_max + * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) + ) + for i in timesteps + ] + + return state.replace( + num_inference_steps=num_inference_steps, + schedule=jnp.array(schedule, dtype=jnp.float32), + timesteps=timesteps, + ) + + def add_noise_to_input( + self, + state: KarrasVeSchedulerState, + sample: jnp.ndarray, + sigma: float, + key: random.KeyArray, + ) -> Tuple[jnp.ndarray, float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: + """ + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + key = random.split(key, num=1) + eps = self.config.s_noise * random.normal(key=key, shape=sample.shape) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + state: KarrasVeSchedulerState, + model_output: jnp.ndarray, + sigma_hat: float, + sigma_prev: float, + sample_hat: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxKarrasVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion + chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + if not return_dict: + return (sample_prev, derivative, state) + + return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) + + def step_correct( + self, + state: KarrasVeSchedulerState, + model_output: jnp.ndarray, + sigma_hat: float, + sigma_prev: float, + sample_hat: jnp.ndarray, + sample_prev: jnp.ndarray, + derivative: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxKarrasVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. TODO complete description + + Args: + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. 
+ sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO + derivative (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO + + """ + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + + if not return_dict: + return (sample_prev, derivative, state) + + return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) + + def add_noise(self, original_samples, noise, timesteps): + raise NotImplementedError() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 19272d627145..5857ae70a856 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -113,7 +113,7 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) low_idx = np.floor(self.timesteps).astype(int) high_idx = np.ceil(self.timesteps).astype(int) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py new file mode 100644 index 000000000000..1431bdacf54c --- /dev/null +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -0,0 +1,207 @@ +# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +@flax.struct.dataclass +class LMSDiscreteSchedulerState: + # setable values + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + derivatives: jnp.ndarray = jnp.array([]) + + @classmethod + def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], sigmas=sigmas) + + +@dataclass +class FlaxSchedulerOutput(SchedulerOutput): + state: LMSDiscreteSchedulerState + + +class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Linear Multistep Scheduler for discrete beta schedules. 
Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + ): + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + + self.state = LMSDiscreteSchedulerState.create( + num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + ) + + def get_lms_coefficient(self, state, order, t, current_order): + """ + Compute a linear multistep coefficient. + + Args: + order (TODO): + t (TODO): + current_order (TODO): + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState: + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`LMSDiscreteSchedulerState`): + the `FlaxLMSDiscreteScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. 
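+
+        A short sketch of the intended call pattern (the argument values are illustrative):
+
+            >>> scheduler = FlaxLMSDiscreteScheduler()
+            >>> state = scheduler.set_timesteps(scheduler.state, num_inference_steps=50)
+            >>> state.sigmas.shape
+            (51,)
+
+        Note that a new, updated state is returned; the state that was passed in is left unchanged.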
+ """ + timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32) + + low_idx = jnp.floor(timesteps).astype(int) + high_idx = jnp.ceil(timesteps).astype(int) + frac = jnp.mod(timesteps, 1.0) + sigmas = jnp.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32) + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + derivatives=jnp.array([]), + sigmas=sigmas, + ) + + def step( + self, + state: LMSDiscreteSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + order: int = 4, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + sigma = state.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + state = state.replace(derivatives=state.derivatives.append(derivative)) + if len(state.derivatives) > order: + state = state.replace(derivatives=state.derivatives.pop(0)) + + # 3. Compute linear multistep coefficients + order = min(timestep + 1, order) + lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)] + + # 4. Compute previous sample based on the derivatives path + prev_sample = sample + sum( + coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives)) + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + state: LMSDiscreteSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sigmas = self.match_shape(state.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index bfe31aca1715..6eeb15b93923 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -108,8 +108,6 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - self.one = np.array(1.0) - # For now we only support F-PNDM, i.e. 
the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 53b0126c51c0..8444d6680401 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -25,7 +25,7 @@ from .scheduling_utils import SchedulerMixin, SchedulerOutput -def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999): +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -40,7 +40,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999): prevent singularities. Returns: - betas (`jnp.array`): the betas used by the scheduler to step the model outputs + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs """ def alpha_bar(time_step): @@ -56,36 +56,23 @@ def alpha_bar(time_step): @flax.struct.dataclass class PNDMSchedulerState: - betas: jnp.array - # setable values - _timesteps: jnp.array + _timesteps: jnp.ndarray num_inference_steps: Optional[int] = None _offset: int = 0 - prk_timesteps: Optional[jnp.array] = None - plms_timesteps: Optional[jnp.array] = None - timesteps: Optional[jnp.array] = None + prk_timesteps: Optional[jnp.ndarray] = None + plms_timesteps: Optional[jnp.ndarray] = None + timesteps: Optional[jnp.ndarray] = None # running values cur_model_output: Optional[jnp.ndarray] = None counter: int = 0 cur_sample: Optional[jnp.ndarray] = None - ets: jnp.array = jnp.array([]) - - @property - def alphas(self) -> jnp.array: - return 1.0 - self.betas - - @property - def alphas_cumprod(self) -> jnp.array: - return jnp.cumprod(self.alphas, axis=0) + ets: jnp.ndarray = jnp.array([]) @classmethod - def create(cls, betas: jnp.array, num_train_timesteps: int): - return cls( - betas=betas, - _timesteps=jnp.arange(0, num_train_timesteps)[::-1], - ) + def create(cls, num_train_timesteps: int): + return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1]) @dataclass @@ -112,7 +99,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): + trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required @@ -126,28 +113,31 @@ def __init__( beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", - trained_betas: Optional[jnp.array] = None, + trained_betas: Optional[jnp.ndarray] = None, skip_prk_steps: bool = False, ): if trained_betas is not None: - betas = jnp.asarray(trained_betas) + self.betas = jnp.asarray(trained_betas) if beta_schedule == "linear": - betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
-            betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
+            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            betas = betas_for_alpha_bar(num_train_timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4
 
-        self.state = PNDMSchedulerState.create(betas=betas, num_train_timesteps=num_train_timesteps)
+        self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
 
     def set_timesteps(
         self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
@@ -157,7 +147,7 @@ def set_timesteps(
 
         Args:
             state (`PNDMSchedulerState`):
-                the PNDMScheduler state data class instance.
+                the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
             offset (`int`):
@@ -165,7 +155,7 @@ def set_timesteps(
         """
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
+        # rounding to avoid issues when num_inference_step is power of 3
         _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
         _timesteps = _timesteps + offset
 
@@ -212,7 +202,7 @@ def step(
         This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
 
         Args:
-            state (`PNDMSchedulerState`): the PNDMScheduler state data class instance.
+            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
             model_output (`jnp.ndarray`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
@@ -246,7 +236,7 @@ def step_prk(
         solution to the differential equation.
 
         Args:
-            state (`PNDMSchedulerState`): the PNDMScheduler state data class instance.
+            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
             model_output (`jnp.ndarray`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
@@ -268,24 +258,24 @@ def step_prk(
         timestep = state.prk_timesteps[state.counter // 4 * 4]
 
         if state.counter % 4 == 0:
-            state.replace(
+            state = state.replace(
                 cur_model_output=state.cur_model_output + 1 / 6 * model_output,
                 ets=state.ets.append(model_output),
                 cur_sample=sample,
             )
         elif (self.counter - 1) % 4 == 0:
-            state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
+            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
         elif (self.counter - 2) % 4 == 0:
-            state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
+            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
         elif (self.counter - 3) % 4 == 0:
             model_output = state.cur_model_output + 1 / 6 * model_output
-            state.replace(cur_model_output=0)
+            state = state.replace(cur_model_output=0)
 
         # cur_sample should not be `None`
         cur_sample = state.cur_sample if state.cur_sample is not None else sample
 
         prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
-        state.replace(counter=state.counter + 1)
+        state = state.replace(counter=state.counter + 1)
 
         if not return_dict:
             return (prev_sample, state)
@@ -305,7 +295,7 @@ def step_plms(
         times to approximate the solution.
 
         Args:
-            state (`PNDMSchedulerState`): the PNDMScheduler state data class instance.
+            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
             model_output (`jnp.ndarray`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
@@ -333,18 +323,18 @@ def step_plms(
         prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
 
         if state.counter != 1:
-            state.replace(ets=state.ets.append(model_output))
+            state = state.replace(ets=state.ets.append(model_output))
         else:
             prev_timestep = timestep
             timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
 
         if len(state.ets) == 1 and state.counter == 0:
             model_output = model_output
-            state.replace(cur_sample=sample)
+            state = state.replace(cur_sample=sample)
         elif len(state.ets) == 1 and state.counter == 1:
             model_output = (model_output + state.ets[-1]) / 2
             sample = state.cur_sample
-            state.replace(cur_sample=None)
+            state = state.replace(cur_sample=None)
         elif len(state.ets) == 2:
             model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
         elif len(state.ets) == 3:
@@ -355,7 +345,7 @@ def step_plms(
         )
 
         prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
-        state.replace(counter=state.counter + 1)
+        state = state.replace(counter=state.counter + 1)
 
         if not return_dict:
             return (prev_sample, state)
@@ -375,8 +365,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state)
         # sample -> x_t
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = state.alphas_cumprod[timestep + 1 - state._offset]
-        alpha_prod_t_prev = state.alphas_cumprod[timestep_prev + 1 - state._offset]
+        alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
@@ -400,14 +390,13 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state)
 
     def add_noise(
         self,
-        state: PNDMSchedulerState,
         original_samples: jnp.ndarray,
         noise: jnp.ndarray,
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
-        sqrt_alpha_prod = state.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
-        sqrt_one_minus_alpha_prod = (1 - state.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
 
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
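An illustrative sketch, not part of the patch: how the refactored FlaxPNDMScheduler is meant to be driven after this change. The schedule constants (`betas`, `alphas`, `alphas_cumprod`) now live on the scheduler object, the immutable `PNDMSchedulerState` is threaded explicitly through `set_timesteps`/`step`, and `add_noise` no longer takes a state argument. `unet_apply`, `params`, `latents`, and the shapes are hypothetical placeholders.

import jax.numpy as jnp
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler()
state = scheduler.set_timesteps(scheduler.state, num_inference_steps=50)

# training-style noising needs no state after this change:
# noisy = scheduler.add_noise(clean_latents, noise, timesteps)

for t in state.timesteps:
    model_output = unet_apply(params, latents, t)  # hypothetical model call
    latents, state = scheduler.step(state, model_output, int(t), latents, return_dict=False)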
diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py
index a880b5aa30ab..4af8f4fdad7d 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve.py
@@ -55,6 +55,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
     [`~ConfigMixin.from_config`] functions.
 
     Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
         snr (`float`):
             coefficient weighting the step from the model_output sample (from the network) to the random noise.
         sigma_min (`float`):
diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
new file mode 100644
index 000000000000..e5860706aa2e
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -0,0 +1,260 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import flax
+import jax.numpy as jnp
+from jax import random
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+@flax.struct.dataclass
+class ScoreSdeVeSchedulerState:
+    # setable values
+    timesteps: Optional[jnp.ndarray] = None
+    discrete_sigmas: Optional[jnp.ndarray] = None
+    sigmas: Optional[jnp.ndarray] = None
+
+    @classmethod
+    def create(cls):
+        return cls()
+
+
+@dataclass
+class FlaxSdeVeOutput(SchedulerOutput):
+    """
+    Output class for the ScoreSdeVeScheduler's step function output.
+
+    Args:
+        state (`ScoreSdeVeSchedulerState`):
+        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
+            Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+    """
+
+    state: ScoreSdeVeSchedulerState
+    prev_sample: jnp.ndarray
+    prev_sample_mean: Optional[jnp.ndarray] = None
+
+
+class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
+    """
+    The variance exploding stochastic differential equation (SDE) scheduler.
+
+    For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        snr (`float`):
+            coefficient weighting the step from the model_output sample (from the network) to the random noise.
+        sigma_min (`float`):
+            initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+            distribution of the data.
+        sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+            epsilon.
+        correct_steps (`int`): number of correction steps performed on a produced sample.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 2000,
+        snr: float = 0.15,
+        sigma_min: float = 0.01,
+        sigma_max: float = 1348.0,
+        sampling_eps: float = 1e-5,
+        correct_steps: int = 1,
+    ):
+        state = ScoreSdeVeSchedulerState.create()
+
+        self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+
+    def set_timesteps(
+        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None
+    ) -> ScoreSdeVeSchedulerState:
+        """
+        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+        """
+        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+
+        timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
+        return state.replace(timesteps=timesteps)
+
+    def set_sigmas(
+        self,
+        state: ScoreSdeVeSchedulerState,
+        num_inference_steps: int,
+        sigma_min: float = None,
+        sigma_max: float = None,
+        sampling_eps: float = None,
+    ) -> ScoreSdeVeSchedulerState:
+        """
+        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+        The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+        Args:
+            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            sigma_min (`float`, optional):
+                initial noise scale value (overrides value given at Scheduler instantiation).
+            sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
+            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+        """
+        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
+        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
+        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+        if state.timesteps is None:
+            state = self.set_timesteps(state, num_inference_steps, sampling_eps)
+
+        discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
+        sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps])
+
+        return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)
+
+    def get_adjacent_sigma(self, state, timesteps, t):
+        return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])
+
+    def step_pred(
+        self,
+        state: ScoreSdeVeSchedulerState,
+        model_output: jnp.ndarray,
+        timestep: int,
+        sample: jnp.ndarray,
+        key: random.KeyArray,
+        return_dict: bool = True,
+    ) -> Union[FlaxSdeVeOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+            model_output (`jnp.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`jnp.ndarray`):
+                current instance of sample being created by diffusion process.
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+
+        """
+        if state.timesteps is None:
+            raise ValueError(
+                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        timestep = timestep * jnp.ones(
+            sample.shape[0],
+        )
+        timesteps = (timestep * (len(state.timesteps) - 1)).long()
+
+        sigma = state.discrete_sigmas[timesteps]
+        adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
+        drift = jnp.zeros_like(sample)
+        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
+        # also equation 47 shows the analog from SDE models to ancestral sampling methods
+        drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+
+        # equation 6: sample noise for the diffusion term of
+        key = random.split(key, num=1)
+        noise = random.normal(key=key, shape=sample.shape)
+        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
+        # TODO is the variable diffusion the correct scaling term for the noise?
+        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+
+        if not return_dict:
+            return (prev_sample, prev_sample_mean, state)
+
+        return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)
+    def step_correct(
+        self,
+        state: ScoreSdeVeSchedulerState,
+        model_output: jnp.ndarray,
+        sample: jnp.ndarray,
+        key: random.KeyArray,
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+        after making the prediction for the previous timestep.
+
+        Args:
+            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
+            model_output (`jnp.ndarray`): direct output from learned diffusion model.
+            sample (`jnp.ndarray`):
+                current instance of sample being created by diffusion process.
+            generator: random number generator.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+
+        """
+        if state.timesteps is None:
+            raise ValueError(
+                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+        # sample noise for correction
+        key = random.split(key, num=1)
+        noise = random.normal(key=key, shape=sample.shape)
+
+        # compute step size from the model_output, the noise, and the snr
+        grad_norm = jnp.linalg.norm(model_output)
+        noise_norm = jnp.linalg.norm(noise)
+        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
+        step_size = step_size * jnp.ones(sample.shape[0])
+
+        # compute corrected sample: model_output term and noise term
+        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+
+        if not return_dict:
+            return (prev_sample, state)
+
+        return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)
+
+    def __len__(self):
+        return self.config.num_train_timesteps
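An illustrative sketch, not part of the patch: a minimal predictor-corrector sampling loop for the new FlaxScoreSdeVeScheduler, assuming the functional-state pattern shown above. A fresh PRNG key is split for every stochastic call, `correct_steps` corrector updates precede each predictor update, and `score_model_apply`, `params`, and the initial `sample` are hypothetical placeholders.

import jax
import jax.numpy as jnp
from diffusers import FlaxScoreSdeVeScheduler

scheduler = FlaxScoreSdeVeScheduler()
state = scheduler.set_timesteps(scheduler.state, num_inference_steps=1000)
state = scheduler.set_sigmas(state, num_inference_steps=1000)

key = jax.random.PRNGKey(0)
for i, t in enumerate(state.timesteps):
    sigma_t = state.sigmas[i] * jnp.ones(sample.shape[0])

    # corrector: Langevin-style updates at the current noise level
    for _ in range(scheduler.config.correct_steps):
        model_output = score_model_apply(params, sample, sigma_t)  # hypothetical model call
        key, step_key = jax.random.split(key)
        sample = scheduler.step_correct(state, model_output, sample, key=step_key).prev_sample

    # predictor: reverse-diffusion step to the previous noise level
    model_output = score_model_apply(params, sample, sigma_t)  # hypothetical model call
    key, step_key = jax.random.split(key)
    output = scheduler.step_pred(state, model_output, t, sample, key=step_key)
    sample, sample_mean = output.prev_sample, output.prev_sample_mean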
diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py
index 981dc5586ad9..d246f41a72d2 100644
--- a/src/diffusers/utils/dummy_flax_objects.py
+++ b/src/diffusers/utils/dummy_flax_objects.py
@@ -11,8 +11,43 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
+class FlaxDDIMScheduler(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxDDPMScheduler(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxKarrasVeScheduler(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxPNDMScheduler(metaclass=DummyObject):
     _backends = ["flax"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
+
+
+class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index d4aef24a989a..6306be28024a 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -814,7 +814,7 @@ def test_full_loop_no_noise(self):
         for i, t in enumerate(scheduler.timesteps):
             sigma_t = scheduler.sigmas[i]
 
-            for _ in range(scheduler.correct_steps):
+            for _ in range(scheduler.config.correct_steps):
                 with torch.no_grad():
                     model_output = model(sample, sigma_t)
                 sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample
diff --git a/tests/test_training.py b/tests/test_training.py
index 27caf03365b6..519c5ab9e716 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -52,7 +52,7 @@ def test_training_step_equality(self):
             tensor_format="pt",
         )
 
-        assert ddpm_scheduler.num_train_timesteps == ddim_scheduler.num_train_timesteps
+        assert ddpm_scheduler.config.num_train_timesteps == ddim_scheduler.config.num_train_timesteps
 
         # shared batches for DDPM and DDIM
         set_seed(0)
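An illustrative note, not part of the patch: the training-script and test updates in this patch all follow the same convention. Arguments passed to a scheduler's `__init__` and captured by `@register_to_config` are read back through `scheduler.config` (e.g. `scheduler.config.num_train_timesteps`) rather than as plain instance attributes, which is why the example scripts and tests above switch to `config`-based access.

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
assert scheduler.config.num_train_timesteps == 1000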