Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ def get_timesteps(self, num_inference_steps, strength, device):

t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
self.scheduler._step_index_init = t_start * self.scheduler.order
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be a very general function call (e.g. one that every scheduler has).

We could maybe call it self.scheduler.set_begin_index(...) and then make sure that every scheduler has such a function.


return timesteps, num_inference_steps - t_start

Expand Down
51 changes: 33 additions & 18 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,21 @@ def __init__(
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._step_index_init = None

@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index

@property
def step_index_init(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a setter method here as well (https://stackoverflow.com/questions/2627002/whats-the-pythonic-way-to-use-getters-and-setters) that can be set from all pipelines if necessary

Copy link
Collaborator Author

@yiyixuxu yiyixuxu Jan 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok but can i ask why do we need both a setter method and this set_begin_index method? they are doing exactly the same thing, no?
https://github.com/huggingface/diffusers/pull/5746/files#r1463161680

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, I think we had a bit of a misunderstanding here 😅 We don't need / we shouldn't add it if it doesn't have to be used. I was under the impression that we have to use a setter method. But it seems like we don't need it after all!

"""
the first step_index for denoising loop.
"""
return self._step_index_init

def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Expand Down Expand Up @@ -760,23 +768,28 @@ def multistep_dpm_solver_third_order_update(
return x_t

def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

index_candidates = (self.timesteps == timestep).nonzero()

if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()

if self.step_index_init is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

index_candidates = (self.timesteps == timestep).nonzero()

if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()

self._step_index = step_index
self._step_index_init = step_index
self._step_index = step_index
else:
self._step_index = self.step_index_init

def step(
self,
Expand Down Expand Up @@ -884,8 +897,10 @@ def add_noise(
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
if self.step_index_init is None:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
else:
step_indices = [self.step_index_init] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
Expand Down