-
Notifications
You must be signed in to change notification settings - Fork 6.6k
[don't merge]Img2Img: timestep mismatch in scheduler with duplicated first timesteps #5746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -215,13 +215,21 @@ def __init__( | |
| self.model_outputs = [None] * solver_order | ||
| self.lower_order_nums = 0 | ||
| self._step_index = None | ||
| self._step_index_init = None | ||
|
|
||
| @property | ||
| def step_index(self): | ||
| """ | ||
| The index counter for current timestep. It will increae 1 after each scheduler step. | ||
| """ | ||
| return self._step_index | ||
|
|
||
| @property | ||
| def step_index_init(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add a setter method here as well (https://stackoverflow.com/questions/2627002/whats-the-pythonic-way-to-use-getters-and-setters) that can be set from all pipelines if necessary
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok but can i ask why do we need both a setter method and this
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah ok, I think we had a bit of a misunderstanding here 😅 We don't need / we shouldn't add it if it doesn't have to be used. I was under the impression that we have to use a setter method. But it seems like we don't need it after all! |
||
| """ | ||
| the first step_index for denoising loop. | ||
| """ | ||
| return self._step_index_init | ||
|
|
||
| def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): | ||
| """ | ||
|
|
@@ -760,23 +768,28 @@ def multistep_dpm_solver_third_order_update( | |
| return x_t | ||
|
|
||
| def _init_step_index(self, timestep): | ||
| if isinstance(timestep, torch.Tensor): | ||
| timestep = timestep.to(self.timesteps.device) | ||
|
|
||
| index_candidates = (self.timesteps == timestep).nonzero() | ||
|
|
||
| if len(index_candidates) == 0: | ||
| step_index = len(self.timesteps) - 1 | ||
| # The sigma index that is taken for the **very** first `step` | ||
| # is always the second index (or the last index if there is only 1) | ||
| # This way we can ensure we don't accidentally skip a sigma in | ||
| # case we start in the middle of the denoising schedule (e.g. for image-to-image) | ||
| elif len(index_candidates) > 1: | ||
| step_index = index_candidates[1].item() | ||
| else: | ||
| step_index = index_candidates[0].item() | ||
|
|
||
| if self.step_index_init is None: | ||
| if isinstance(timestep, torch.Tensor): | ||
| timestep = timestep.to(self.timesteps.device) | ||
|
|
||
| index_candidates = (self.timesteps == timestep).nonzero() | ||
|
|
||
| if len(index_candidates) == 0: | ||
| step_index = len(self.timesteps) - 1 | ||
| # The sigma index that is taken for the **very** first `step` | ||
| # is always the second index (or the last index if there is only 1) | ||
| # This way we can ensure we don't accidentally skip a sigma in | ||
| # case we start in the middle of the denoising schedule (e.g. for image-to-image) | ||
| elif len(index_candidates) > 1: | ||
| step_index = index_candidates[1].item() | ||
| else: | ||
| step_index = index_candidates[0].item() | ||
|
|
||
| self._step_index = step_index | ||
| self._step_index_init = step_index | ||
| self._step_index = step_index | ||
| else: | ||
| self._step_index = self.step_index_init | ||
|
|
||
| def step( | ||
| self, | ||
|
|
@@ -884,8 +897,10 @@ def add_noise( | |
| else: | ||
| schedule_timesteps = self.timesteps.to(original_samples.device) | ||
| timesteps = timesteps.to(original_samples.device) | ||
|
|
||
| step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||
| if self.step_index_init is None: | ||
| step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||
| else: | ||
| step_indices = [self.step_index_init] * timesteps.shape[0] | ||
|
|
||
| sigma = sigmas[step_indices].flatten() | ||
| while len(sigma.shape) < len(original_samples.shape): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be a very general function call (e.g. one that every scheduler has).
We could maybe call it
self.scheduler.set_begin_index(...)and then make sure that every scheduler has such a function.