From 01fa779f4e18044df0d405250fa5a6d2799c5c41 Mon Sep 17 00:00:00 2001 From: Crized-bit Date: Mon, 18 Nov 2024 21:07:58 +0000 Subject: [PATCH] Added noise --- .../scheduling_dpmsolver_multistep.py | 75 ++++++++++++++++++- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 7677e37e9426..ab20887eafe1 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -16,7 +16,7 @@ import math from typing import List, Optional, Tuple, Union - +import torchsde import numpy as np import torch @@ -26,6 +26,57 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + if is_scipy_available(): import scipy.stats @@ -207,6 +258,7 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", + do_brownian_noise: bool = False, thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, @@ -252,7 +304,7 @@ def __init__( self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0, dtype=torch.float32) if rescale_betas_zero_snr: # Close to 0 without being 0 so first sigma is not inf @@ -294,6 +346,7 @@ def __init__( self.lower_order_nums = 0 self._step_index = None self._begin_index = None + self.do_brownian_noise = do_brownian_noise self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property @@ -400,7 +453,8 @@ def set_timesteps( sigmas = np.exp(lambdas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() elif self.config.use_exponential_sigmas: - sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) elif self.config.use_beta_sigmas: sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) @@ -1040,6 +1094,7 @@ def step( if self.step_index is None: self._init_step_index(timestep) + # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( self.config.euler_at_final @@ -1057,10 +1112,22 @@ def step( # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) - if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: + + if self.do_brownian_noise: + sigma_min, sigma_max = self.sigmas[self.sigmas > 0].min().item(), self.sigmas.max().item() + sampler = BrownianTreeNoiseSampler(x=sample, + sigma_max=sigma_max, + sigma_min=sigma_min, + seed=generator.initial_seed() if generator else None) + + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None and not self.do_brownian_noise: noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 ) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and self.do_brownian_noise: + if self.sigmas[self.step_index + 1] == 0: + return SchedulerOutput(prev_sample=model_output) + noise = sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]) elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = variance_noise.to(device=model_output.device, dtype=torch.float32) else: