Skip to content
32 changes: 26 additions & 6 deletions src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,33 @@

import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
class DDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Expand Down Expand Up @@ -179,7 +199,7 @@ def step(
use_clipped_model_output: bool = False,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Expand All @@ -192,11 +212,11 @@ def step(
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.

"""
Expand Down Expand Up @@ -261,7 +281,7 @@ def step(
if not return_dict:
return (prev_sample,)

return SchedulerOutput(prev_sample=prev_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
Expand Down
32 changes: 26 additions & 6 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,33 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
class DDPMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
Expand Down Expand Up @@ -177,7 +197,7 @@ def step(
predict_epsilon=True,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Expand All @@ -190,11 +210,11 @@ def step(
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.

"""
Expand Down Expand Up @@ -242,7 +262,7 @@ def step(
if not return_dict:
return (pred_prev_sample,)

return SchedulerOutput(prev_sample=pred_prev_sample)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
Expand Down
16 changes: 12 additions & 4 deletions src/diffusers/schedulers/scheduling_karras_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput):
denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivative of predicted original image sample (x_0).
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
derivative: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
Expand Down Expand Up @@ -153,7 +157,7 @@ def step(
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class

KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
Returns:
Expand All @@ -170,7 +174,9 @@ def step(
if not return_dict:
return (sample_prev, derivative)

return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)

def step_correct(
self,
Expand All @@ -192,7 +198,7 @@ def step_correct(
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class

Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
Expand All @@ -205,7 +211,9 @@ def step_correct(
if not return_dict:
return (sample_prev, derivative)

return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)

def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
34 changes: 27 additions & 7 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
Expand All @@ -20,7 +21,26 @@
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
Expand Down Expand Up @@ -133,7 +153,7 @@ def step(
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Expand All @@ -144,12 +164,12 @@ def step(
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
[`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.

"""
sigma = self.sigmas[timestep]
Expand All @@ -175,7 +195,7 @@ def step(
if not return_dict:
return (prev_sample,)

return SchedulerOutput(prev_sample=prev_sample)
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
Expand Down