Skip to content

Commit 2e1278f

Browse files
committed
Unify offset configuration in DDIM and PNDM schedulers
1 parent 25a51b6 commit 2e1278f

File tree

6 files changed

+54
-45
lines changed

6 files changed

+54
-45
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,7 @@ def __call__(
217217
latents = latents.to(self.device)
218218

219219
# set timesteps
220-
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
221-
extra_set_kwargs = {}
222-
if accepts_offset:
223-
extra_set_kwargs["offset"] = 1
224-
225-
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
220+
self.scheduler.set_timesteps(num_inference_steps)
226221

227222
# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
228223
if isinstance(self.scheduler, LMSDiscreteScheduler):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,7 @@ def __call__(
169169
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
170170

171171
# set timesteps
172-
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
173-
extra_set_kwargs = {}
174-
offset = 0
175-
if accepts_offset:
176-
offset = 1
177-
extra_set_kwargs["offset"] = 1
178-
179-
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
172+
self.scheduler.set_timesteps(num_inference_steps)
180173

181174
if not isinstance(init_image, torch.FloatTensor):
182175
init_image = preprocess(init_image)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,7 @@ def __call__(
192192
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
193193

194194
# set timesteps
195-
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
196-
extra_set_kwargs = {}
197-
offset = 0
198-
if accepts_offset:
199-
offset = 1
200-
extra_set_kwargs["offset"] = 1
201-
202-
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
195+
self.scheduler.set_timesteps(num_inference_steps)
203196

204197
# preprocess image
205198
init_image = preprocess_image(init_image).to(self.device)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,7 @@ def __call__(
100100
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
101101

102102
# set timesteps
103-
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
104-
extra_set_kwargs = {}
105-
if accepts_offset:
106-
extra_set_kwargs["offset"] = 1
107-
108-
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
103+
self.scheduler.set_timesteps(num_inference_steps)
109104

110105
# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
111106
if isinstance(self.scheduler, LMSDiscreteScheduler):

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import math
1919
from typing import Optional, Tuple, Union
20+
import warnings
2021

2122
import numpy as np
2223
import torch
@@ -78,7 +79,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
7879
clip_sample (`bool`, default `True`):
7980
option to clip predicted sample between -1 and 1 for numerical stability.
8081
set_alpha_to_one (`bool`, default `True`):
81-
if alpha for final step is 1 or the final alpha of the "non-previous" one.
82+
each diffusion step uses the value of alphas product at that step and at the previous one.
83+
For the final step there is no previous alpha. When this option is `True` the previous alpha
84+
product is fixed to `1`, otherwise it uses the value of alpha at step 0.
85+
steps_offset (`int`, default `0`):
86+
an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`,
87+
to make the last step use step 0 for the previous alpha product.
8288
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
8389
8490
"""
@@ -94,6 +100,7 @@ def __init__(
94100
timestep_values: Optional[np.ndarray] = None,
95101
clip_sample: bool = True,
96102
set_alpha_to_one: bool = True,
103+
steps_offset: int = 0,
97104
tensor_format: str = "pt",
98105
):
99106
if trained_betas is not None:
@@ -112,10 +119,6 @@ def __init__(
112119
self.alphas = 1.0 - self.betas
113120
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
114121

115-
# At every step in ddim, we are looking into the previous alphas_cumprod
116-
# For the final step, there is no previous alphas_cumprod because we are already at 0
117-
# `set_alpha_to_one` decides whether we set this paratemer simply to one or
118-
# whether we use the final alpha of the "non-previous" one.
119122
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
120123

121124
# setable values
@@ -135,15 +138,25 @@ def _get_variance(self, timestep, prev_timestep):
135138

136139
return variance
137140

138-
def set_timesteps(self, num_inference_steps: int, offset: int = 0):
141+
def set_timesteps(self, num_inference_steps: int, **kwargs):
139142
"""
140143
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
141144
142145
Args:
143146
num_inference_steps (`int`):
144147
the number of diffusion steps used when generating samples with a pre-trained model.
145-
offset (`int`): TODO
146148
"""
149+
150+
offset = self.config.steps_offset
151+
152+
if "offset" in kwargs:
153+
warnings.warn(
154+
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
155+
" Please pass `steps_offset` to `__init__` instead."
156+
)
157+
158+
offset = kwargs["offset"]
159+
147160
self.num_inference_steps = num_inference_steps
148161
self.timesteps = np.arange(
149162
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import math
1818
from typing import Optional, Tuple, Union
19+
import warnings
1920

2021
import numpy as np
2122
import torch
@@ -73,10 +74,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
7374
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
7475
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
7576
trained_betas (`np.ndarray`, optional): TODO
76-
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
7777
skip_prk_steps (`bool`):
7878
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
7979
before plms steps; defaults to `False`.
80+
set_alpha_to_one (`bool`, default `True`):
81+
each diffusion step uses the value of alphas product at that step and at the previous one.
82+
For the final step there is no previous alpha. When this option is `True` the previous alpha
83+
product is fixed to `1`, otherwise it uses the value of alpha at step 0.
84+
steps_offset (`int`, default `0`):
85+
an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`,
86+
to make the last step use step 0 for the previous alpha product.
87+
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
8088
8189
"""
8290

@@ -88,8 +96,10 @@ def __init__(
8896
beta_end: float = 0.02,
8997
beta_schedule: str = "linear",
9098
trained_betas: Optional[np.ndarray] = None,
91-
tensor_format: str = "pt",
9299
skip_prk_steps: bool = False,
100+
set_alpha_to_one: bool = True,
101+
steps_offset: int = 0,
102+
tensor_format: str = "pt",
93103
):
94104
if trained_betas is not None:
95105
self.betas = np.asarray(trained_betas)
@@ -107,6 +117,8 @@ def __init__(
107117
self.alphas = 1.0 - self.betas
108118
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
109119

120+
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
121+
110122
self.one = np.array(1.0)
111123

112124
# For now we only support F-PNDM, i.e. the runge-kutta method
@@ -123,29 +135,37 @@ def __init__(
123135
# setable values
124136
self.num_inference_steps = None
125137
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
126-
self._offset = 0
127138
self.prk_timesteps = None
128139
self.plms_timesteps = None
129140
self.timesteps = None
130141

131142
self.tensor_format = tensor_format
132143
self.set_format(tensor_format=tensor_format)
133144

134-
def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
145+
def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
135146
"""
136147
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
137148
138149
Args:
139150
num_inference_steps (`int`):
140151
the number of diffusion steps used when generating samples with a pre-trained model.
141-
offset (`int`): TODO
142152
"""
153+
154+
offset = self.config.steps_offset
155+
156+
if "offset" in kwargs:
157+
warnings.warn(
158+
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
159+
" Please pass `steps_offset` to `__init__` instead."
160+
)
161+
162+
offset = kwargs["offset"]
163+
143164
self.num_inference_steps = num_inference_steps
144165
self._timesteps = list(
145166
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
146167
)
147-
self._offset = offset
148-
self._timesteps = np.array([t + self._offset for t in self._timesteps])
168+
self._timesteps = np.array(self._timesteps) + offset
149169

150170
if self.config.skip_prk_steps:
151171
# for some models like stable diffusion the prk steps can/should be skipped to
@@ -322,7 +342,7 @@ def step_plms(
322342

323343
return SchedulerOutput(prev_sample=prev_sample)
324344

325-
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
345+
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
326346
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
327347
# this function computes x_(t−δ) using the formula of (9)
328348
# Note that x_t needs to be added to both sides of the equation
@@ -335,8 +355,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
335355
# sample -> x_t
336356
# model_output -> e_θ(x_t, t)
337357
# prev_sample -> x_(t−δ)
338-
alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
339-
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
358+
alpha_prod_t = self.alphas_cumprod[timestep]
359+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
340360
beta_prod_t = 1 - alpha_prod_t
341361
beta_prod_t_prev = 1 - alpha_prod_t_prev
342362

0 commit comments

Comments
 (0)