Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import math
from typing import List, Optional, Tuple, Union

import torchsde
import numpy as np
import torch

Expand All @@ -26,6 +26,57 @@
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    One ``torchsde.BrownianTree`` is built per seed. When ``seed`` is a
    sequence, it must hold exactly one seed per item along the first axis of
    ``x`` and calls return a stacked tensor with one entry per tree; otherwise
    a single tree is built and calls return that tree's sample directly.

    Args:
        x (Tensor): Reference tensor; the default ``w0`` is
            ``torch.zeros_like(x)`` and ``x.shape[0]`` is the expected number
            of seeds when a sequence of seeds is supplied.
        t0: One end of the time interval (either order is accepted).
        t1: The other end of the time interval.
        seed (int or sequence of int, optional): Entropy for the tree(s). If
            ``None``, a single seed is drawn with ``torch.randint``.
        **kwargs: Forwarded to ``torchsde.BrownianTree``. ``w0`` may be given
            here to override the initial value.

    Raises:
        ValueError: If ``seed`` has a length that differs from ``x.shape[0]``.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        # Pop (rather than get) so that ``w0`` is not forwarded a second time
        # through **kwargs, which would collide with the positional argument.
        w0 = kwargs.pop('w0', None)
        if w0 is None:
            w0 = torch.zeros_like(x)
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        # Only the len() probe is guarded: anything without a length is
        # treated as a single scalar seed.
        try:
            num_seeds = len(seed)
        except TypeError:
            seed = [seed]
            self.batched = False
        else:
            if num_seeds != x.shape[0]:
                raise ValueError(
                    f"Expected {x.shape[0]} seeds (one per batch item), got {num_seeds}."
                )
            self.batched = True
            # Each per-item tree starts from a single item's initial value.
            w0 = w0[0]
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return ``(low, high, sign)``; ``sign`` is 1 if ``a < b`` else -1."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Return the tree increment(s) between ``t0`` and ``t1``.

        The endpoints are put in ascending order before querying the trees;
        the result is multiplied by the product of the construction-time sign
        and the query-time sign.
        """
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """Noise sampler that draws its samples from a BatchedBrownianTree.

    Args:
        x (Tensor): Reference tensor handed to the underlying tree.
        sigma_min (float): Lower bound of the sigma interval.
        sigma_max (float): Upper bound of the sigma interval.
        seed (int or List[int]): Random seed. Passing a list of seeds makes
            the underlying tree use one BrownianTree per batch item, each
            with its own seed.
        transform (callable): Maps a sigma value to the sampler's internal
            timestep. Defaults to the identity.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        # The tree is built over the transformed interval, not raw sigmas.
        lower = self.transform(torch.as_tensor(sigma_min))
        upper = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, lower, upper, seed)

    def __call__(self, sigma, sigma_next):
        """Return the tree increment between the two sigmas, scaled by 1/sqrt(|dt|)."""
        start = self.transform(torch.as_tensor(sigma))
        stop = self.transform(torch.as_tensor(sigma_next))
        increment = self.tree(start, stop)
        # Divide by the square root of the absolute (transformed) interval length.
        return increment / (stop - start).abs().sqrt()

if is_scipy_available():
import scipy.stats

Expand Down Expand Up @@ -207,6 +258,7 @@ def __init__(
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
do_brownian_noise: bool = False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
Expand Down Expand Up @@ -252,7 +304,7 @@ def __init__(
self.betas = rescale_zero_terminal_snr(self.betas)

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0, dtype=torch.float32)

if rescale_betas_zero_snr:
# Close to 0 without being 0 so first sigma is not inf
Expand Down Expand Up @@ -294,6 +346,7 @@ def __init__(
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.do_brownian_noise = do_brownian_noise
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
Expand Down Expand Up @@ -1042,6 +1095,7 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)


# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final
Expand All @@ -1059,10 +1113,22 @@ def step(

# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:

if self.do_brownian_noise:
sigma_min, sigma_max = self.sigmas[self.sigmas > 0].min().item(), self.sigmas.max().item()
sampler = BrownianTreeNoiseSampler(x=sample,
sigma_max=sigma_max,
sigma_min=sigma_min,
seed=generator.initial_seed() if generator else None)

if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None and not self.do_brownian_noise:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and self.do_brownian_noise:
if self.sigmas[self.step_index + 1] == 0:
return SchedulerOutput(prev_sample=model_output)
noise = sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1])
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
else:
Expand Down