From 66580c44de350f813c3049096a429cdb78103254 Mon Sep 17 00:00:00 2001 From: Jonathan Whitaker Date: Thu, 22 Sep 2022 10:09:40 +0000 Subject: [PATCH 1/6] Adding pred_original_sample to SchedulerOutput of DDPMScheduler, DDIMScheduler, LMSDiscreteScheduler, KarrasVeScheduler step methods so we can access the predicted denoised outputs --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- src/diffusers/schedulers/scheduling_karras_ve.py | 8 ++++++-- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- src/diffusers/schedulers/scheduling_utils.py | 6 +++++- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a5369b1603c6..fe6df9c1a7d5 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -261,7 +261,7 @@ def step( if not return_dict: return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample) + return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d008b84da6e7..f8f114240b78 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -242,7 +242,7 @@ def step( if not return_dict: return (pred_prev_sample,) - return SchedulerOutput(prev_sample=pred_prev_sample) + return SchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index caf7625fb683..07bef58426f7 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput): denoising loop. derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Derivative of predicted original image sample (x_0). + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor derivative: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None class KarrasVeScheduler(SchedulerMixin, ConfigMixin): @@ -170,7 +174,7 @@ def step( if not return_dict: return (sample_prev, derivative) - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample) def step_correct( self, @@ -205,7 +209,7 @@ def step_correct( if not return_dict: return (sample_prev, derivative) - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample) def add_noise(self, original_samples, noise, timesteps): raise NotImplementedError() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 5857ae70a856..caf97cf3ddca 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -175,7 +175,7 @@ def step( if not return_dict: return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample) + return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index f2bcd73acf32..a17f06b8a04b 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Union +from typing import Union, Optional import numpy as np import torch @@ -32,9 +32,13 @@ class SchedulerOutput(BaseOutput): prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None class SchedulerMixin: From 79a6a9a22d9fe6a23fdd58c854cf7f15e8adb49a Mon Sep 17 00:00:00 2001 From: Jonathan Whitaker Date: Thu, 22 Sep 2022 12:01:49 +0000 Subject: [PATCH 2/6] Gave DDPMScheduler, DDIMScheduler and LMSDiscreteScheduler their own output dataclasses so the default SchedulerOutput in scheduling_utils does not need pred_original_sample as an optional extra --- src/diffusers/schedulers/scheduling_ddim.py | 32 +++++++++++++++---- src/diffusers/schedulers/scheduling_ddpm.py | 32 +++++++++++++++---- .../schedulers/scheduling_karras_ve.py | 12 ++++--- .../schedulers/scheduling_lms_discrete.py | 32 +++++++++++++++---- src/diffusers/schedulers/scheduling_utils.py | 6 +--- 5 files changed, 87 insertions(+), 27 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index fe6df9c1a7d5..2e320620c027 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -23,7 +23,27 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin +from dataclasses import dataclass +from ..utils import BaseOutput + + +@dataclass +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -179,7 +199,7 @@ def step( use_clipped_model_output: bool = False, generator=None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -192,11 +212,11 @@ def step( eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): TODO generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -261,7 +281,7 @@ def step( if not return_dict: return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index f8f114240b78..22543e27c0e7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -21,7 +21,27 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin +from dataclasses import dataclass +from ..utils import BaseOutput + + +@dataclass +class DDPMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -177,7 +197,7 @@ def step( predict_epsilon=True, generator=None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[DDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -190,11 +210,11 @@ def step( predict_epsilon (`bool`): optional flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -242,7 +262,7 @@ def step( if not return_dict: return (pred_prev_sample,) - return SchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 07bef58426f7..6e66bed400f4 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -157,7 +157,7 @@ def step( sigma_hat (`float`): TODO sigma_prev (`float`): TODO sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). Returns: @@ -174,7 +174,9 @@ def step( if not return_dict: return (sample_prev, derivative) - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample) + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) def step_correct( self, @@ -196,7 +198,7 @@ def step_correct( sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO derivative (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class Returns: prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO @@ -209,7 +211,9 @@ def step_correct( if not return_dict: return (sample_prev, derivative) - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample) + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) def add_noise(self, original_samples, noise, timesteps): raise NotImplementedError() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index caf97cf3ddca..38fd6fa14c32 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -20,7 +20,27 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerMixin +from dataclasses import dataclass +from ..utils import BaseOutput + + +@dataclass +class DDPMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -133,7 +153,7 @@ def step( sample: Union[torch.FloatTensor, np.ndarray], order: int = 4, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[DDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -144,11 +164,11 @@ def step( sample (`torch.FloatTensor` or `np.ndarray`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -175,7 +195,7 @@ def step( if not return_dict: return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + return DDPMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index a17f06b8a04b..f2bcd73acf32 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Union, Optional +from typing import Union import numpy as np import torch @@ -32,13 +32,9 @@ class SchedulerOutput(BaseOutput): prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None class SchedulerMixin: From efd4be2c023533ceeeb1b7d4acc8998c85383151 Mon Sep 17 00:00:00 2001 From: Jonathan Whitaker Date: Thu, 22 Sep 2022 12:04:57 +0000 Subject: [PATCH 3/6] Reordered library imports to follow standard --- src/diffusers/schedulers/scheduling_ddim.py | 4 ++-- src/diffusers/schedulers/scheduling_ddpm.py | 4 ++-- src/diffusers/schedulers/scheduling_lms_discrete.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 2e320620c027..7a4741687c7b 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -18,14 +18,14 @@ import math import warnings from typing import Optional, Tuple, Union +from dataclasses import dataclass import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin -from dataclasses import dataclass from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin @dataclass diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 22543e27c0e7..73fd327e40c3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -16,14 +16,14 @@ import math from typing import Optional, Tuple, Union +from dataclasses import dataclass import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin -from dataclasses import dataclass from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin @dataclass diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 38fd6fa14c32..7bd8d0183c86 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional, Tuple, Union +from dataclasses import dataclass import numpy as np import torch @@ -20,9 +21,8 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin -from dataclasses import dataclass from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin @dataclass From 0b42775fa99b740f9586247791bc460c1327f7a8 Mon Sep 17 00:00:00 2001 From: Jonathan Whitaker Date: Thu, 22 Sep 2022 12:07:40 +0000 Subject: [PATCH 4/6] didnt get import order quite right apparently --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 7a4741687c7b..0613ffd41d0e 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -17,8 +17,8 @@ import math import warnings -from typing import Optional, Tuple, Union from dataclasses import dataclass +from typing import Optional, Tuple, Union import numpy as np import torch diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 73fd327e40c3..440b880385d4 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -15,8 +15,8 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -from typing import Optional, Tuple, Union from dataclasses import dataclass +from typing import Optional, Tuple, Union import numpy as np import torch diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 7bd8d0183c86..5934dcd885c7 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union from dataclasses import dataclass +from typing import Optional, Tuple, Union import numpy as np import torch From 7f28bc92c3ae0013507bbd6669188e2b3bf87049 Mon Sep 17 00:00:00 2001 From: Jonathan Whitaker Date: Thu, 22 Sep 2022 12:12:12 +0000 Subject: [PATCH 5/6] Forgot to change name of LMSDiscreteSchedulerOutput --- src/diffusers/schedulers/scheduling_lms_discrete.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 5934dcd885c7..697a3bd105eb 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -26,7 +26,7 @@ @dataclass -class DDPMSchedulerOutput(BaseOutput): +class LMSDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. @@ -153,7 +153,7 @@ def step( sample: Union[torch.FloatTensor, np.ndarray], order: int = 4, return_dict: bool = True, - ) -> Union[DDPMSchedulerOutput, Tuple]: + ) -> Union[LMSDiscreteSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -164,11 +164,11 @@ def step( sample (`torch.FloatTensor` or `np.ndarray`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. - return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class + return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -195,7 +195,7 @@ def step( if not return_dict: return (prev_sample,) - return DDPMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, From 2f88b10414ae3a5b31d00fbcad921c297c56cbe2 Mon Sep 17 00:00:00 2001 From: Jonathan Whitaker Date: Thu, 22 Sep 2022 12:22:56 +0000 Subject: [PATCH 6/6] Aha, needed some extra libs for make style to fully work --- src/diffusers/schedulers/scheduling_lms_discrete.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 697a3bd105eb..1dd6dbda1e19 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -168,8 +168,8 @@ def step( Returns: [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. """ sigma = self.sigmas[timestep]