From 35fd62a1611b5e4457af7b77466aee253de11ed9 Mon Sep 17 00:00:00 2001 From: njindal Date: Sat, 8 Apr 2023 20:55:51 +0530 Subject: [PATCH 01/15] [2064]: Add stochastic sampler --- src/diffusers/__init__.py | 1 + src/diffusers/schedulers/__init__.py | 1 + .../schedulers/scheduling_dpmsolver_sde.py | 450 ++++++++++++++++++ src/diffusers/schedulers/scheduling_utils.py | 1 + 4 files changed, 453 insertions(+) create mode 100644 src/diffusers/schedulers/scheduling_dpmsolver_sde.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f8ac91c0eb95..519bd0975750 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -76,6 +76,7 @@ DDPMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler, + DPMSolverSDEScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index e5d5bb40633f..d14c0715db08 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -27,6 +27,7 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler + from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py new file mode 100644 index 000000000000..ee9e1b82652b --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -0,0 +1,450 @@ +# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
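With the export added above, the new scheduler becomes importable from the package root. As a minimal sketch (not part of the patch) of the usual diffusers scheduler-swap workflow: the checkpoint id below is only an example, and torchsde must be installed for the sampler to run.

import torch
from diffusers import StableDiffusionPipeline, DPMSolverSDEScheduler

# Load any Stable Diffusion checkpoint; the id here is illustrative.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Swap in the stochastic (DPM SDE) sampler, reusing the pipeline's scheduler config.
pipe.scheduler = DPMSolverSDEScheduler.from_config(pipe.scheduler.config)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
image.save("astronaut.png")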
+ +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torchsde + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=33, **kwargs): + seed = 33 + # print('seed is', seed) + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get("w0", torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each + with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): + """ + Implements Stochastic Sampler (Algorithm 2) from Karras et al. (2022). 
Based on the original k-diffusion + implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/41b4cb6df0506694a7776af31349acf082bf6091/k_diffusion/sampling.py#L543 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + use_karras_sigmas: Optional[bool] = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
+ self.betas = ( + torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + self.use_karras_sigmas = use_karras_sigmas + self.noise_sampler = None + + def index_for_timestep(self, timestep): + indices = (self.timesteps == timestep).nonzero() + if self.state_in_first_order: + pos = -1 + else: + pos = 0 + return indices[pos].item() + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + ) -> torch.FloatTensor: + """ + Args: + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + step_index = self.index_for_timestep(timestep) + + sigma = self.sigmas[step_index] + sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma + sample = sample / ((sigma_input ** 2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + timesteps = torch.from_numpy(timesteps) + second_order_timesteps = torch.from_numpy(second_order_timesteps) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + timesteps[1::2] = second_order_timesteps + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(device, dtype=torch.float32) + else: + self.timesteps = timesteps.to(device=device) + + # empty first order variables + self.sample = None + self.pred_original_sample = None + self.mid_point_sigma = None + + def _second_order_timesteps(self, sigmas, log_sigmas): + def sigma_fn(_t): + return np.exp(-_t) + + def t_fn(_sigma): + return -np.log(_sigma) + + r = 0.5 + timesteps = np.zeros_like(sigmas[:-1]) + for i in range(len(sigmas) - 1): + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + s = t + h * r + timesteps[i] = self._sigma_to_t(sigma_fn(s), log_sigmas) + return timesteps + + # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + @property + def state_in_first_order(self): + return self.sample is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + s_noise: float = 1.0, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep + (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + step_index = self.index_for_timestep(timestep) + + if self.noise_sampler is None: + min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() + self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma) + + def sigma_fn(_t): + return _t.neg().exp() + + def t_fn(_sigma): + return _sigma.log().neg() + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_next = self.sigmas[step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[step_index - 1] + sigma_next = self.sigmas[step_index] + + r = 0.5 + t, t_next = t_fn(sigma), t_fn(sigma_next) + h = t_next - t + s = t + h * r + fac = 1 / (2 * r) + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma if self.state_in_first_order else sigma_fn(s) + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma if self.state_in_first_order else sigma_fn(s) + pred_original_sample = model_output * (-sigma_input / (sigma_input ** 2 + 1) ** 0.5) + ( + sample / (sigma_input ** 2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if sigma_next == 0: + derivative = (sample - pred_original_sample) / sigma + dt = sigma_next - sigma + prev_sample = sample + derivative * dt + else: + if self.state_in_first_order: + sigma_from = sigma_fn(t) + sigma_to = sigma_fn(s) + sigma_up = min(sigma_to, (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + ancestral_t = t_fn(sigma_down) + prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( + t - ancestral_t + ).expm1() * pred_original_sample + prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * sigma_up + + # store for 2nd order step + self.sample = sample + self.pred_original_sample = pred_original_sample + self.mid_point_sigma = sigma_fn(s) + else: + # 2nd order + sigma_from = sigma_fn(t) + sigma_to = sigma_fn(t_next) + sigma_up = min(sigma_to, (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + + ancestral_t_next = t_fn(sigma_down) + denoised_d = (1 - fac) * self.pred_original_sample + fac * pred_original_sample + prev_sample = (sigma_fn(ancestral_t_next) / sigma_fn(t)) * self.sample - ( + t - ancestral_t_next + ).expm1() * denoised_d + prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up + + # free for 
"first order mode" + self.sample = None + self.pred_original_sample = None + self.mid_point_sigma = None + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t) for t in timesteps] + + sigma = self.sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index a4121f75d850..54812e6014a0 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -43,6 +43,7 @@ class KarrasDiffusionSchedulers(Enum): KDPM2AncestralDiscreteScheduler = 11 DEISMultistepScheduler = 12 UniPCMultistepScheduler = 13 + DPMSolverSDEScheduler = 6 @dataclass From 680ee43789594f09b983a5e444a09f842627d32c Mon Sep 17 00:00:00 2001 From: njindal Date: Sat, 8 Apr 2023 20:56:45 +0530 Subject: [PATCH 02/15] [2064]: Add stochastic sampler --- src/diffusers/schedulers/scheduling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 54812e6014a0..0f95beb022ac 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -43,7 +43,7 @@ class KarrasDiffusionSchedulers(Enum): KDPM2AncestralDiscreteScheduler = 11 DEISMultistepScheduler = 12 UniPCMultistepScheduler = 13 - DPMSolverSDEScheduler = 6 + DPMSolverSDEScheduler = 14 @dataclass From 5f35f307686a20d87817473be7e67f9b4dc9ba8f Mon Sep 17 00:00:00 2001 From: njindal Date: Sat, 8 Apr 2023 21:04:38 +0530 Subject: [PATCH 03/15] [2064]: Add stochastic sampler --- src/diffusers/schedulers/scheduling_dpmsolver_sde.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index ee9e1b82652b..eec5b3e9320c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -27,8 +27,6 @@ class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" def __init__(self, x, t0, t1, seed=33, **kwargs): - seed = 33 - # print('seed is', seed) t0, t1, self.sign = self.sort(t0, t1) w0 = kwargs.get("w0", torch.zeros_like(x)) if seed is None: From 259cc0bddd72d1d3648cf5a24cb6fed093ee5ee7 Mon Sep 17 00:00:00 2001 From: njindal Date: Sat, 8 Apr 2023 21:07:06 +0530 Subject: [PATCH 04/15] [2064]: Add stochastic sampler --- .../schedulers/scheduling_dpmsolver_sde.py | 72 
+++++++++---------- src/diffusers/utils/dummy_pt_objects.py | 15 ++++ 2 files changed, 51 insertions(+), 36 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index eec5b3e9320c..af9a8184ec21 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -30,7 +30,7 @@ def __init__(self, x, t0, t1, seed=33, **kwargs): t0, t1, self.sign = self.sort(t0, t1) w0 = kwargs.get("w0", torch.zeros_like(x)) if seed is None: - seed = torch.randint(0, 2 ** 63 - 1, []).item() + seed = torch.randint(0, 2**63 - 1, []).item() self.batched = True try: assert len(seed) == x.shape[0] @@ -140,14 +140,14 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.00085, # sensible defaults - beta_end: float = 0.012, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - use_karras_sigmas: Optional[bool] = False, + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + use_karras_sigmas: Optional[bool] = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -156,7 +156,7 @@ def __init__( elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( - torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2 + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule @@ -181,9 +181,9 @@ def index_for_timestep(self, timestep): return indices[pos].item() def scale_model_input( - self, - sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], ) -> torch.FloatTensor: """ Args: @@ -197,14 +197,14 @@ def scale_model_input( sigma = self.sigmas[step_index] sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma - sample = sample / ((sigma_input ** 2 + 1) ** 0.5) + sample = sample / ((sigma_input**2 + 1) ** 0.5) return sample def set_timesteps( - self, - num_inference_steps: int, - device: Union[str, torch.device] = None, - num_train_timesteps: Optional[int] = None, + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, ): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 
@@ -313,12 +313,12 @@ def state_in_first_order(self): return self.sample is None def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: Union[float, torch.FloatTensor], - sample: Union[torch.FloatTensor, np.ndarray], - return_dict: bool = True, - s_noise: float = 1.0, + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + s_noise: float = 1.0, ) -> Union[SchedulerOutput, Tuple]: """ Args: @@ -365,8 +365,8 @@ def t_fn(_sigma): pred_original_sample = sample - sigma_input * model_output elif self.config.prediction_type == "v_prediction": sigma_input = sigma if self.state_in_first_order else sigma_fn(s) - pred_original_sample = model_output * (-sigma_input / (sigma_input ** 2 + 1) ** 0.5) + ( - sample / (sigma_input ** 2 + 1) + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) ) elif self.config.prediction_type == "sample": raise NotImplementedError("prediction_type not implemented yet: sample") @@ -383,11 +383,11 @@ def t_fn(_sigma): if self.state_in_first_order: sigma_from = sigma_fn(t) sigma_to = sigma_fn(s) - sigma_up = min(sigma_to, (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 ancestral_t = t_fn(sigma_down) prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( - t - ancestral_t + t - ancestral_t ).expm1() * pred_original_sample prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * sigma_up @@ -399,13 +399,13 @@ def t_fn(_sigma): # 2nd order sigma_from = sigma_fn(t) sigma_to = sigma_fn(t_next) - sigma_up = min(sigma_to, (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 ancestral_t_next = t_fn(sigma_down) denoised_d = (1 - fac) * self.pred_original_sample + fac * pred_original_sample prev_sample = (sigma_fn(ancestral_t_next) / sigma_fn(t)) * self.sample - ( - t - ancestral_t_next + t - ancestral_t_next ).expm1() * denoised_d prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up @@ -420,10 +420,10 @@ def t_fn(_sigma): return SchedulerOutput(prev_sample=prev_sample) def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 014e193aa32a..5718f857db8c 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -465,6 +465,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DPMSolverSDEScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + 
requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DPMSolverSinglestepScheduler(metaclass=DummyObject): _backends = ["torch"] From 9565cf4590dfb2c25691eb92d2bd9984d560180b Mon Sep 17 00:00:00 2001 From: njindal Date: Sun, 9 Apr 2023 00:18:04 +0530 Subject: [PATCH 05/15] [2064]: Add stochastic sampler --- .../schedulers/scheduling_dpmsolver_sde.py | 81 ++++++++++--------- 1 file changed, 42 insertions(+), 39 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index af9a8184ec21..a612db9742fc 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -261,13 +261,12 @@ def sigma_fn(_t): def t_fn(_sigma): return -np.log(_sigma) - r = 0.5 - timesteps = np.zeros_like(sigmas[:-1]) - for i in range(len(sigmas) - 1): - t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) - h = t_next - t - s = t + h * r - timesteps[i] = self._sigma_to_t(sigma_fn(s), log_sigmas) + midpoint_ratio = 0.5 + t = t_fn(sigmas) + delta_time = np.diff(t) + t_proposed = t[:-1] + delta_time * midpoint_ratio + sig_proposed = sigma_fn(t_proposed) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed]) return timesteps # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t @@ -324,10 +323,11 @@ def step( Args: Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep - (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + model_output (Union[torch.FloatTensor, np.ndarray]): Direct output from learned diffusion model. + timestep (Union[float, torch.FloatTensor]): Current discrete timestep in the diffusion chain. + sample (Union[torch.FloatTensor, np.ndarray]): Current instance of sample being created by diffusion process. + return_dict (bool, optional): Option for returning tuple rather than SchedulerOutput class. Defaults to True. + s_noise (float, optional): Scaling factor for the noise added to the sample. Defaults to 1.0. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. 
When @@ -335,14 +335,16 @@ def step( """ step_index = self.index_for_timestep(timestep) + # Create a noise sampler if it hasn't been created yet if self.noise_sampler is None: min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma) - def sigma_fn(_t): + def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor: return _t.neg().exp() - def t_fn(_sigma): + # Define functions to compute sigma and t from each other + def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: return _sigma.log().neg() if self.state_in_first_order: @@ -353,18 +355,19 @@ def t_fn(_sigma): sigma = self.sigmas[step_index - 1] sigma_next = self.sigmas[step_index] - r = 0.5 + # Set the midpoint and step size for the current step + midpoint_ratio = 0.5 t, t_next = t_fn(sigma), t_fn(sigma_next) - h = t_next - t - s = t + h * r - fac = 1 / (2 * r) + delta_time = t_next - t + t_proposed = t + delta_time * midpoint_ratio + fac = 1 / (2 * midpoint_ratio) # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": - sigma_input = sigma if self.state_in_first_order else sigma_fn(s) + sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) pred_original_sample = sample - sigma_input * model_output elif self.config.prediction_type == "v_prediction": - sigma_input = sigma if self.state_in_first_order else sigma_fn(s) + sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( sample / (sigma_input**2 + 1) ) @@ -380,34 +383,34 @@ def t_fn(_sigma): dt = sigma_next - sigma prev_sample = sample + derivative * dt else: + + def get_ancestral_step(sigma_from, sigma_to): + up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + down = (sigma_to**2 - up**2) ** 0.5 + return down, up, t_fn(down) + + def get_prev_sample(input_latent, predicted_original_sample, anc_t, _t, _t_next, _sigma_up): + previous_sample = (sigma_fn(anc_t) / sigma_fn(_t)) * input_latent - ( + _t - anc_t + ).expm1() * predicted_original_sample + previous_sample = ( + previous_sample + self.noise_sampler(sigma_fn(_t), sigma_fn(_t_next)) * s_noise * _sigma_up + ) + return previous_sample + if self.state_in_first_order: - sigma_from = sigma_fn(t) - sigma_to = sigma_fn(s) - sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - ancestral_t = t_fn(sigma_down) - prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( - t - ancestral_t - ).expm1() * pred_original_sample - prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * sigma_up + sigma_down, sigma_up, ancestral_t = get_ancestral_step(sigma_fn(t), sigma_fn(t_proposed)) + prev_sample = get_prev_sample(sample, pred_original_sample, ancestral_t, t, t_proposed, sigma_up) # store for 2nd order step self.sample = sample self.pred_original_sample = pred_original_sample - self.mid_point_sigma = sigma_fn(s) + self.mid_point_sigma = sigma_fn(t_proposed) else: # 2nd order - sigma_from = sigma_fn(t) - sigma_to = sigma_fn(t_next) - sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - - ancestral_t_next = t_fn(sigma_down) + sigma_down, sigma_up, ancestral_t = get_ancestral_step(sigma_fn(t), 
sigma_fn(t_next)) denoised_d = (1 - fac) * self.pred_original_sample + fac * pred_original_sample - prev_sample = (sigma_fn(ancestral_t_next) / sigma_fn(t)) * self.sample - ( - t - ancestral_t_next - ).expm1() * denoised_d - prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up + prev_sample = get_prev_sample(self.sample, denoised_d, ancestral_t, t, t_next, sigma_up) # free for "first order mode" self.sample = None From dfe3928c4a77c3f9aa2c6700ffd0eaaa41163926 Mon Sep 17 00:00:00 2001 From: njindal Date: Sun, 9 Apr 2023 00:19:13 +0530 Subject: [PATCH 06/15] [2064]: Add stochastic sampler --- src/diffusers/schedulers/scheduling_dpmsolver_sde.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index a612db9742fc..70d238893508 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -340,10 +340,10 @@ def step( min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma) + # Define functions to compute sigma and t from each other def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor: return _t.neg().exp() - # Define functions to compute sigma and t from each other def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: return _sigma.log().neg() @@ -351,7 +351,7 @@ def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: sigma = self.sigmas[step_index] sigma_next = self.sigmas[step_index + 1] else: - # 2nd order / Heun's method + # 2nd order sigma = self.sigmas[step_index - 1] sigma_next = self.sigmas[step_index] From 15afd1ee31e5779c25f359636c381989901d6848 Mon Sep 17 00:00:00 2001 From: njindal Date: Sun, 9 Apr 2023 08:28:01 +0530 Subject: [PATCH 07/15] [2064]: Add stochastic sampler --- .../schedulers/scheduling_dpmsolver_sde.py | 44 +++++++------------ 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 70d238893508..0e62adf5a4ae 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -26,7 +26,7 @@ class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" - def __init__(self, x, t0, t1, seed=33, **kwargs): + def __init__(self, x, t0, t1, seed=None, **kwargs): t0, t1, self.sign = self.sort(t0, t1) w0 = kwargs.get("w0", torch.zeros_like(x)) if seed is None: @@ -251,7 +251,6 @@ def set_timesteps( # empty first order variables self.sample = None - self.pred_original_sample = None self.mid_point_sigma = None def _second_order_timesteps(self, sigmas, log_sigmas): @@ -360,7 +359,6 @@ def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: t, t_next = t_fn(sigma), t_fn(sigma_next) delta_time = t_next - t t_proposed = t + delta_time * midpoint_ratio - fac = 1 / (2 * midpoint_ratio) # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": @@ -383,38 +381,28 @@ def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: dt = sigma_next - sigma prev_sample = sample + derivative * dt else: - - def get_ancestral_step(sigma_from, sigma_to): - up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) - down = (sigma_to**2 - up**2) ** 0.5 - return down, up, t_fn(down) - - def get_prev_sample(input_latent, predicted_original_sample, anc_t, _t, _t_next, _sigma_up): - previous_sample = (sigma_fn(anc_t) / sigma_fn(_t)) * input_latent - ( - _t - anc_t - ).expm1() * predicted_original_sample - previous_sample = ( - previous_sample + self.noise_sampler(sigma_fn(_t), sigma_fn(_t_next)) * s_noise * _sigma_up - ) - return previous_sample - if self.state_in_first_order: - sigma_down, sigma_up, ancestral_t = get_ancestral_step(sigma_fn(t), sigma_fn(t_proposed)) - prev_sample = get_prev_sample(sample, pred_original_sample, ancestral_t, t, t_proposed, sigma_up) + t_next = t_proposed + else: + sample = self.sample + + sigma_from = sigma_fn(t) + sigma_to = sigma_fn(t_next) + sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + ancestral_t = t_fn(sigma_down) + prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( + t - ancestral_t + ).expm1() * pred_original_sample + prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up + if self.state_in_first_order: # store for 2nd order step self.sample = sample - self.pred_original_sample = pred_original_sample - self.mid_point_sigma = sigma_fn(t_proposed) + self.mid_point_sigma = sigma_fn(t_next) else: - # 2nd order - sigma_down, sigma_up, ancestral_t = get_ancestral_step(sigma_fn(t), sigma_fn(t_next)) - denoised_d = (1 - fac) * self.pred_original_sample + fac * pred_original_sample - prev_sample = get_prev_sample(self.sample, denoised_d, ancestral_t, t, t_next, sigma_up) - # free for "first order mode" self.sample = None - self.pred_original_sample = None self.mid_point_sigma = None if not return_dict: From f3d8c7c3189e3d03243f1beea0887bcc7feda371 Mon Sep 17 00:00:00 2001 From: njindal Date: Wed, 12 Apr 2023 14:54:51 +0530 Subject: [PATCH 08/15] Review comments --- src/diffusers/schedulers/scheduling_dpmsolver_sde.py | 2 -- src/diffusers/schedulers/scheduling_heun_discrete.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 0e62adf5a4ae..1c115a3ef2d4 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -123,8 +123,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. 
prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index f7f1467fc53a..baed49dc20a2 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -70,8 +70,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 From e55832d6fbea6b06cdde8324066b0427875a712f Mon Sep 17 00:00:00 2001 From: njindal Date: Wed, 26 Apr 2023 18:03:06 +0530 Subject: [PATCH 09/15] [Review comment]: Add is_torchsde_available() --- src/diffusers/__init__.py | 9 ++++++++- src/diffusers/schedulers/__init__.py | 17 +++++++++++++++-- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/dummy_pt_objects.py | 15 --------------- .../utils/dummy_torch_and_torchsde_objects.py | 17 +++++++++++++++++ src/diffusers/utils/import_utils.py | 17 +++++++++++++++++ src/diffusers/utils/testing_utils.py | 8 ++++++++ 7 files changed, 66 insertions(+), 18 deletions(-) create mode 100644 src/diffusers/utils/dummy_torch_and_torchsde_objects.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 603b304ef6c1..18326d5d6a79 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -12,6 +12,7 @@ is_onnx_available, is_scipy_available, is_torch_available, + is_torchsde_available, is_transformers_available, is_transformers_version, is_unidecode_available, @@ -76,7 +77,6 @@ DDPMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler, - DPMSolverSDEScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, @@ -103,6 +103,13 @@ else: from .schedulers import LMSDiscreteScheduler +try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 +else: + from .schedulers import DPMSolverSDEScheduler try: if not (is_torch_available() and is_transformers_available()): diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index d14c0715db08..c4b62c722257 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. 
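The guarded import introduced in this commit mirrors how the other optional backends are handled. Sketched below, as an assumption rather than something taken from the patch, is what that fallback means for a user who does not have torchsde installed; only the `requires_backends` mechanism comes from this change, and the error wording is paraphrased.

# With torchsde absent, the guarded import falls back to the dummy object,
# so the import itself still succeeds.
from diffusers import DPMSolverSDEScheduler

try:
    DPMSolverSDEScheduler()  # DummyObject.__init__ calls requires_backends(...)
except ImportError as err:
    # The message points the user to `pip install torchsde` (wording paraphrased).
    print(err)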
-from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available +from ..utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_scipy_available, + is_torch_available, + is_torchsde_available, +) try: @@ -27,7 +33,6 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler - from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler @@ -73,3 +78,11 @@ from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .scheduling_lms_discrete import LMSDiscreteScheduler + +try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 +else: + from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 1b8eca050c9e..f3e4c9d1d0ec 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -70,6 +70,7 @@ is_tf_available, is_torch_available, is_torch_version, + is_torchsde_available, is_transformers_available, is_transformers_version, is_unidecode_available, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5718f857db8c..014e193aa32a 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -465,21 +465,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class DPMSolverSDEScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class DPMSolverSinglestepScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_torchsde_objects.py b/src/diffusers/utils/dummy_torch_and_torchsde_objects.py new file mode 100644 index 000000000000..a81bbb316f32 --- /dev/null +++ b/src/diffusers/utils/dummy_torch_and_torchsde_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+from ..utils import DummyObject, requires_backends + + +class DPMSolverSDEScheduler(metaclass=DummyObject): + _backends = ["torch", "torchsde"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "torchsde"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 2d90cb9747a7..88b2d98c5bf8 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -287,6 +287,13 @@ except importlib_metadata.PackageNotFoundError: _bs4_available = False +_torchsde_available = importlib.util.find_spec("torchsde") +try: + _torchsde_version = importlib_metadata.version("torchsde") + logger.debug(f"Successfully imported torchsde version {_torchsde_version}") +except importlib_metadata.PackageNotFoundError: + _torchsde_available = False + def is_torch_available(): return _torch_available @@ -372,6 +379,10 @@ def is_bs4_available(): return _bs4_available +def is_torchsde_available(): + return _torchsde_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -475,6 +486,11 @@ def is_bs4_available(): that match your environment. Please note that you may need to restart your runtime after installation. """ +# docstyle-ignore +TORCHSDE_IMPORT_ERROR = """ +{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` +""" + BACKENDS_MAPPING = OrderedDict( [ @@ -495,6 +511,7 @@ def is_bs4_available(): ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), ("compel", (_compel_available, COMPEL_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (_torchsde_available, TORCHSDE_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index d8fed5dec1c8..4ad7d97b4462 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -26,6 +26,7 @@ is_opencv_available, is_torch_available, is_torch_version, + is_torchsde_available, ) from .logging import get_logger @@ -216,6 +217,13 @@ def require_note_seq(test_case): return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case) +def require_torchsde(test_case): + """ + Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed. 
+ """ + return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case) + + def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray: if isinstance(arry, str): # local_path = "/home/patrick_huggingface_co/" From 01c64b809d5475a4b343b9f62e38aeb81f5ea973 Mon Sep 17 00:00:00 2001 From: njindal Date: Wed, 26 Apr 2023 19:51:29 +0530 Subject: [PATCH 10/15] [Review comment]: Test and docs --- docs/source/en/api/schedulers/dpm_sde.mdx | 23 +++ .../schedulers/scheduling_dpmsolver_sde.py | 26 +++- tests/schedulers/test_scheduler_dpm_sde.py | 141 ++++++++++++++++++ 3 files changed, 182 insertions(+), 8 deletions(-) create mode 100644 docs/source/en/api/schedulers/dpm_sde.mdx create mode 100644 tests/schedulers/test_scheduler_dpm_sde.py diff --git a/docs/source/en/api/schedulers/dpm_sde.mdx b/docs/source/en/api/schedulers/dpm_sde.mdx new file mode 100644 index 000000000000..4e94e6c4050e --- /dev/null +++ b/docs/source/en/api/schedulers/dpm_sde.mdx @@ -0,0 +1,23 @@ + + +# Heun scheduler inspired by Karras et. al paper + +## Overview + +Implements Stochastic Sampler from [Karras et. al](https://arxiv.org/abs/2206.00364). +Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library: + +All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/) + +## DPMSolverSDEScheduler +[[autodoc]] DPMSolverSDEScheduler \ No newline at end of file diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 1c115a3ef2d4..ae9229981152 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -131,6 +131,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + noise_sampler_seed (`int`, *optional*, defaults to `None`): + The random seed to use for the noise sampler. If `None`, a random seed will be generated. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -146,6 +148,7 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -169,9 +172,15 @@ def __init__( self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self.use_karras_sigmas = use_karras_sigmas self.noise_sampler = None + self.noise_sampler_seed = noise_sampler_seed + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() - def index_for_timestep(self, timestep): - indices = (self.timesteps == timestep).nonzero() if self.state_in_first_order: pos = -1 else: @@ -335,7 +344,7 @@ def step( # Create a noise sampler if it hasn't been created yet if self.noise_sampler is None: min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() - self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma) + self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed) # Define functions to compute sigma and t from each other def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor: @@ -408,6 +417,7 @@ def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: return SchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -415,18 +425,18 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - self.timesteps = self.timesteps.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [self.index_for_timestep(t) for t in timesteps] + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py new file mode 100644 index 000000000000..01b909337fc3 --- /dev/null +++ b/tests/schedulers/test_scheduler_dpm_sde.py @@ -0,0 +1,141 @@ +import torch + +from diffusers import DPMSolverSDEScheduler +from diffusers.utils import torch_device +from diffusers.utils.testing_utils import require_torchsde + +from .test_schedulers import SchedulerCommonTest + + +@require_torchsde +class DPMSolverSDESchedulerTest(SchedulerCommonTest): + 
scheduler_classes = (DPMSolverSDEScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + # ?? + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "noise_sampler_seed": 0, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 167.47821044921875) < 1e-2 + assert abs(result_mean.item() - 0.2178705964565277) < 1e-3 + + def test_full_loop_with_v_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 124.77149200439453) < 1e-2 + assert abs(result_mean.item() - 0.16226289014816284) < 1e-3 + + def test_full_loop_device(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 167.46957397460938) < 1e-2 + assert abs(result_mean.item() - 0.21805934607982635) < 1e-3 + + def test_full_loop_device_karras_sigmas(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True) + + 
scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 176.66974135742188) < 1e-2 + assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 From a45722c16d69b96b27b459df83220cb6915ef29d Mon Sep 17 00:00:00 2001 From: njindal Date: Wed, 26 Apr 2023 19:53:43 +0530 Subject: [PATCH 11/15] [Review comment] --- src/diffusers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 18326d5d6a79..17047eece7f7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -107,7 +107,7 @@ if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 + from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: from .schedulers import DPMSolverSDEScheduler From 8cbe43e0917a0a101de07f10b85f1fc72ad13ba4 Mon Sep 17 00:00:00 2001 From: njindal Date: Wed, 26 Apr 2023 20:01:32 +0530 Subject: [PATCH 12/15] [Review comment] --- tests/schedulers/test_scheduler_dpm_sde.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py index 01b909337fc3..4354b93e9cf2 100644 --- a/tests/schedulers/test_scheduler_dpm_sde.py +++ b/tests/schedulers/test_scheduler_dpm_sde.py @@ -63,7 +63,7 @@ def test_full_loop_no_noise(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 167.47821044921875) < 1e-2 + assert abs(result_sum.item() - 162.52383422851562) < 1e-2 assert abs(result_mean.item() - 0.2178705964565277) < 1e-3 def test_full_loop_with_v_prediction(self): @@ -88,7 +88,7 @@ def test_full_loop_with_v_prediction(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 124.77149200439453) < 1e-2 + assert abs(result_sum.item() - 119.8487548828125) < 1e-2 assert abs(result_mean.item() - 0.16226289014816284) < 1e-3 def test_full_loop_device(self): @@ -112,7 +112,7 @@ def test_full_loop_device(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 167.46957397460938) < 1e-2 + assert abs(result_sum.item() - 162.52383422851562) < 1e-2 assert abs(result_mean.item() - 0.21805934607982635) < 1e-3 def test_full_loop_device_karras_sigmas(self): @@ -137,5 +137,5 @@ def test_full_loop_device_karras_sigmas(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 176.66974135742188) < 1e-2 + assert abs(result_sum.item() - 170.3135223388672) < 1e-2 assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 From f8e364693e4db8cb7df5b0996a114c61b666d19c Mon Sep 17 00:00:00 2001 From: njindal Date: Wed, 26 Apr 2023 20:10:34 +0530 Subject: [PATCH 13/15] [Review comment] --- tests/schedulers/test_scheduler_dpm_sde.py | 6 +++--- 1 
file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py index 4354b93e9cf2..c7be5613770f 100644 --- a/tests/schedulers/test_scheduler_dpm_sde.py +++ b/tests/schedulers/test_scheduler_dpm_sde.py @@ -64,7 +64,7 @@ def test_full_loop_no_noise(self): result_mean = torch.mean(torch.abs(sample)) assert abs(result_sum.item() - 162.52383422851562) < 1e-2 - assert abs(result_mean.item() - 0.2178705964565277) < 1e-3 + assert abs(result_mean.item() - 0.211619570851326) < 1e-3 def test_full_loop_with_v_prediction(self): scheduler_class = self.scheduler_classes[0] @@ -89,7 +89,7 @@ def test_full_loop_with_v_prediction(self): result_mean = torch.mean(torch.abs(sample)) assert abs(result_sum.item() - 119.8487548828125) < 1e-2 - assert abs(result_mean.item() - 0.16226289014816284) < 1e-3 + assert abs(result_mean.item() - 0.1560530662536621) < 1e-3 def test_full_loop_device(self): scheduler_class = self.scheduler_classes[0] @@ -113,7 +113,7 @@ def test_full_loop_device(self): result_mean = torch.mean(torch.abs(sample)) assert abs(result_sum.item() - 162.52383422851562) < 1e-2 - assert abs(result_mean.item() - 0.21805934607982635) < 1e-3 + assert abs(result_mean.item() - 0.211619570851326) < 1e-3 def test_full_loop_device_karras_sigmas(self): scheduler_class = self.scheduler_classes[0] From 5ab1c6e5e97686bf32cc8f76f87cec1a03215c5d Mon Sep 17 00:00:00 2001 From: njindal Date: Wed, 26 Apr 2023 20:38:00 +0530 Subject: [PATCH 14/15] [Review comment] --- docs/source/en/_toctree.yml | 2 ++ tests/schedulers/test_scheduler_dpm_sde.py | 32 ++++++++++++++++------ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ccaaff7ca680..35c5fd78a1f6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -266,6 +266,8 @@ title: VP-SDE - local: api/schedulers/vq_diffusion title: VQDiffusionScheduler + - local: api/schedulers/dpm_sde + title: DPMSolverSDEScheduler title: Schedulers - sections: - local: api/experimental/rl diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py index c7be5613770f..eec07c29ab02 100644 --- a/tests/schedulers/test_scheduler_dpm_sde.py +++ b/tests/schedulers/test_scheduler_dpm_sde.py @@ -63,8 +63,12 @@ def test_full_loop_no_noise(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 162.52383422851562) < 1e-2 - assert abs(result_mean.item() - 0.211619570851326) < 1e-3 + if torch_device in ["mps"]: + assert abs(result_sum.item() - 167.47821044921875) < 1e-2 + assert abs(result_mean.item() - 0.2178705964565277) < 1e-3 + else: + assert abs(result_sum.item() - 162.52383422851562) < 1e-2 + assert abs(result_mean.item() - 0.211619570851326) < 1e-3 def test_full_loop_with_v_prediction(self): scheduler_class = self.scheduler_classes[0] @@ -88,8 +92,12 @@ def test_full_loop_with_v_prediction(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 119.8487548828125) < 1e-2 - assert abs(result_mean.item() - 0.1560530662536621) < 1e-3 + if torch_device in ["mps"]: + assert abs(result_sum.item() - 124.77149200439453) < 1e-2 + assert abs(result_mean.item() - 0.16226289014816284) < 1e-3 + else: + assert abs(result_sum.item() - 119.8487548828125) < 1e-2 + assert abs(result_mean.item() - 0.1560530662536621) < 1e-3 def test_full_loop_device(self): 
scheduler_class = self.scheduler_classes[0] @@ -112,8 +120,12 @@ def test_full_loop_device(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 162.52383422851562) < 1e-2 - assert abs(result_mean.item() - 0.211619570851326) < 1e-3 + if torch_device in ["mps"]: + assert abs(result_sum.item() - 167.46957397460938) < 1e-2 + assert abs(result_mean.item() - 0.21805934607982635) < 1e-3 + else: + assert abs(result_sum.item() - 162.52383422851562) < 1e-2 + assert abs(result_mean.item() - 0.211619570851326) < 1e-3 def test_full_loop_device_karras_sigmas(self): scheduler_class = self.scheduler_classes[0] @@ -137,5 +149,9 @@ def test_full_loop_device_karras_sigmas(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 170.3135223388672) < 1e-2 - assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 + if torch_device in ["mps"]: + assert abs(result_sum.item() - 176.66974135742188) < 1e-2 + assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 + else: + assert abs(result_sum.item() - 170.3135223388672) < 1e-2 + assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 From 2e6d5e4cc809b4ccbf02a0d82ab344ed324316ea Mon Sep 17 00:00:00 2001 From: njindal Date: Thu, 27 Apr 2023 08:08:38 +0530 Subject: [PATCH 15/15] [Review comment] --- docs/source/en/api/schedulers/dpm_sde.mdx | 4 ++-- src/diffusers/utils/import_utils.py | 2 +- tests/schedulers/test_scheduler_dpm_sde.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/schedulers/dpm_sde.mdx b/docs/source/en/api/schedulers/dpm_sde.mdx index 4e94e6c4050e..33ec514cef64 100644 --- a/docs/source/en/api/schedulers/dpm_sde.mdx +++ b/docs/source/en/api/schedulers/dpm_sde.mdx @@ -10,11 +10,11 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Heun scheduler inspired by Karras et. al paper +# DPM Stochastic Scheduler inspired by Karras et. al paper ## Overview -Implements Stochastic Sampler from [Karras et. al](https://arxiv.org/abs/2206.00364). +Inspired by Stochastic Sampler from [Karras et. al](https://arxiv.org/abs/2206.00364). Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library: All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 88b2d98c5bf8..4ded0f272462 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -287,7 +287,7 @@ except importlib_metadata.PackageNotFoundError: _bs4_available = False -_torchsde_available = importlib.util.find_spec("torchsde") +_torchsde_available = importlib.util.find_spec("torchsde") is not None try: _torchsde_version = importlib_metadata.version("torchsde") logger.debug(f"Successfully imported torchsde version {_torchsde_version}") diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py index eec07c29ab02..010c4bdb1196 100644 --- a/tests/schedulers/test_scheduler_dpm_sde.py +++ b/tests/schedulers/test_scheduler_dpm_sde.py @@ -13,7 +13,6 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest): num_inference_steps = 10 def get_scheduler_config(self, **kwargs): - # ?? config = { "num_train_timesteps": 1100, "beta_start": 0.0001,
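
For reviewers who want to try the new scheduler end to end, a minimal usage sketch follows. It is illustrative only and not part of the patches above: the checkpoint id and prompt are placeholders, it assumes this branch and torchsde are installed, and it assumes the scheduler can be swapped into a standard Stable Diffusion pipeline via from_config. noise_sampler_seed is the config option introduced in this series for deterministic Brownian-tree noise.

# Illustrative sketch only -- not part of the patches above.
# Assumes diffusers is installed from this branch together with torchsde;
# the checkpoint id and prompt below are placeholders.
import torch

from diffusers import DiffusionPipeline, DPMSolverSDEScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# Reuse the pipeline's scheduler config and set the new noise_sampler_seed
# option so the BrownianTreeNoiseSampler is seeded deterministically.
pipe.scheduler = DPMSolverSDEScheduler.from_config(pipe.scheduler.config, noise_sampler_seed=0)
pipe = pipe.to("cuda")

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("dpm_sde_sample.png")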