Skip to content

Commit f196e28

Browse files
[Sigmas] Keep sigmas on CPU
1 parent 236eaa2 commit f196e28

15 files changed

+30
-30
lines changed

src/diffusers/schedulers/scheduling_consistency_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
self.custom_timesteps = False
9999
self.is_scale_input_called = False
100100
self._step_index = None
101-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
101+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
102102

103103
def index_for_timestep(self, timestep, schedule_timesteps=None):
104104
if schedule_timesteps is None:
@@ -231,7 +231,7 @@ def set_timesteps(
231231
self.timesteps = torch.from_numpy(timesteps).to(device=device)
232232

233233
self._step_index = None
234-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
234+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
235235

236236
# Modified _convert_to_karras implementation that takes in ramp as argument
237237
def _convert_to_karras(self, ramp):

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def __init__(
188188
self.model_outputs = [None] * solver_order
189189
self.lower_order_nums = 0
190190
self._step_index = None
191-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
191+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
192192

193193
@property
194194
def step_index(self):
@@ -256,7 +256,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
256256

257257
# add an index counter for schedulers that allow duplicated timesteps
258258
self._step_index = None
259-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
259+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
260260

261261
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
262262
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def __init__(
218218
self.model_outputs = [None] * solver_order
219219
self.lower_order_nums = 0
220220
self._step_index = None
221-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
221+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
222222

223223
@property
224224
def step_index(self):
@@ -295,7 +295,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
295295

296296
# add an index counter for schedulers that allow duplicated timesteps
297297
self._step_index = None
298-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
298+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
299299

300300
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
301301
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def __init__(
213213
self.model_outputs = [None] * solver_order
214214
self.lower_order_nums = 0
215215
self._step_index = None
216-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
216+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
217217
self.use_karras_sigmas = use_karras_sigmas
218218

219219
@property
@@ -294,7 +294,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
294294

295295
# add an index counter for schedulers that allow duplicated timesteps
296296
self._step_index = None
297-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
297+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
298298

299299
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
300300
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __init__(
200200
self.noise_sampler = None
201201
self.noise_sampler_seed = noise_sampler_seed
202202
self._step_index = None
203-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
203+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
204204

205205
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
206206
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -350,7 +350,7 @@ def set_timesteps(
350350
self.mid_point_sigma = None
351351

352352
self._step_index = None
353-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
353+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
354354
self.noise_sampler = None
355355

356356
# for exp beta schedules, such as the one for `pipeline_shap_e.py`

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __init__(
201201
self.sample = None
202202
self.order_list = self.get_order_list(num_train_timesteps)
203203
self._step_index = None
204-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
204+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
205205

206206
def get_order_list(self, num_inference_steps: int) -> List[int]:
207207
"""
@@ -293,7 +293,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
293293

294294
# add an index counter for schedulers that allow duplicated timesteps
295295
self._step_index = None
296-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
296+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
297297

298298
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
299299
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def __init__(
168168
self.is_scale_input_called = False
169169

170170
self._step_index = None
171-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
171+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
172172

173173
@property
174174
def init_noise_sigma(self):
@@ -252,7 +252,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
252252

253253
self.timesteps = torch.from_numpy(timesteps).to(device=device)
254254
self._step_index = None
255-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
255+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
256256

257257
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
258258
def _init_step_index(self, timestep):

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
self.use_karras_sigmas = use_karras_sigmas
240240

241241
self._step_index = None
242-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
242+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
243243

244244
@property
245245
def init_noise_sigma(self):
@@ -344,7 +344,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
344344

345345
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
346346
self._step_index = None
347-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
347+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
348348

349349
def _sigma_to_t(self, sigma, log_sigmas):
350350
# get log sigma

src/diffusers/schedulers/scheduling_heun_discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150
self.use_karras_sigmas = use_karras_sigmas
151151

152152
self._step_index = None
153-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
153+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
154154

155155
def index_for_timestep(self, timestep, schedule_timesteps=None):
156156
if schedule_timesteps is None:
@@ -272,7 +272,7 @@ def set_timesteps(
272272
self.dt = None
273273

274274
self._step_index = None
275-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
275+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
276276

277277
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
278278
# for exp beta schedules, such as the one for `pipeline_shap_e.py`

src/diffusers/schedulers/scheduling_ipndm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
# running values
5757
self.ets = []
5858
self._step_index = None
59-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
59+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
6060

6161
@property
6262
def step_index(self):
@@ -91,7 +91,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
9191

9292
self.ets = []
9393
self._step_index = None
94-
self.sigmas.to('cpu') # to avoid too much CPU/GPU communication
94+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
9595

9696
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
9797
def _init_step_index(self, timestep):

0 commit comments

Comments
 (0)