From bb1232800bd96f5253d31da692f12da30886b91c Mon Sep 17 00:00:00 2001
From: William Berman
Date: Thu, 6 Apr 2023 21:45:03 -0700
Subject: [PATCH] ddpm custom timesteps

add custom timesteps test

add custom timesteps descending order check

docs

timesteps -> custom_timesteps

can only pass one of num_inference_steps and timesteps
---
 src/diffusers/schedulers/scheduling_ddpm.py | 79 ++++++++++++++++-----
 tests/schedulers/test_scheduler_ddpm.py     | 56 +++++++++++++++
 2 files changed, 119 insertions(+), 16 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index eaaf497f9c1d..2bc34bb8b444 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -162,6 +162,7 @@ def __init__(
         self.init_noise_sigma = 1.0
 
         # setable values
+        self.custom_timesteps = False
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
@@ -191,31 +192,62 @@ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] =
         """
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
         Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
+            num_inference_steps (`Optional[int]`):
+                the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+                `timesteps` must be `None`.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved.
+            timesteps (`List[int]`, optional):
+                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+                must be `None`.
+
         """
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+
+        if timesteps is not None:
+            for i in range(1, len(timesteps)):
+                if timesteps[i] >= timesteps[i - 1]:
+                    raise ValueError("`custom_timesteps` must be in descending order.")
+
+            if timesteps[0] >= self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`timesteps` must start before `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps}."
+                )
+
+            timesteps = np.array(timesteps, dtype=np.int64)
+            self.custom_timesteps = True
+        else:
+            if num_inference_steps > self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                    f" maximal {self.config.num_train_timesteps} timesteps."
+                )
 
-        if num_inference_steps > self.config.num_train_timesteps:
-            raise ValueError(
-                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
-                f" maximal {self.config.num_train_timesteps} timesteps."
-            )
+            self.num_inference_steps = num_inference_steps
 
-        self.num_inference_steps = num_inference_steps
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            self.custom_timesteps = False
 
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = t - self.config.num_train_timesteps // num_inference_steps
+        prev_t = self.previous_timestep(t)
+
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
@@ -314,8 +346,8 @@ def step(
         """
         t = timestep
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        prev_t = self.previous_timestep(t)
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
@@ -428,3 +460,18 @@ def get_velocity(
 
     def __len__(self):
         return self.config.num_train_timesteps
+
+    def previous_timestep(self, timestep):
+        if self.custom_timesteps:
+            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+            if index == self.timesteps.shape[0] - 1:
+                prev_t = torch.tensor(-1)
+            else:
+                prev_t = self.timesteps[index + 1]
+        else:
+            num_inference_steps = (
+                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+            )
+            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        return prev_t
diff --git a/tests/schedulers/test_scheduler_ddpm.py b/tests/schedulers/test_scheduler_ddpm.py
index b55a39ee2e79..c44ded43e67e 100644
--- a/tests/schedulers/test_scheduler_ddpm.py
+++ b/tests/schedulers/test_scheduler_ddpm.py
@@ -129,3 +129,59 @@ def test_full_loop_with_v_prediction(self):
 
         assert abs(result_sum.item() - 202.0296) < 1e-2
         assert abs(result_mean.item() - 0.2631) < 1e-3
+
+    def test_custom_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 1, 0]
+
+        scheduler.set_timesteps(timesteps=timesteps)
+
+        scheduler_timesteps = scheduler.timesteps
+
+        for i, timestep in enumerate(scheduler_timesteps):
+            if i == len(timesteps) - 1:
+                expected_prev_t = -1
+            else:
+                expected_prev_t = timesteps[i + 1]
+
+            prev_t = scheduler.previous_timestep(timestep)
+            prev_t = prev_t.item()
+
+            self.assertEqual(prev_t, expected_prev_t)
+
+    def test_custom_timesteps_increasing_order(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 51, 0]
+
+        with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
+            scheduler.set_timesteps(timesteps=timesteps)
+
+    def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 1, 0]
+        num_inference_steps = len(timesteps)
+
+        with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
+
+    def test_custom_timesteps_too_large(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [scheduler.config.num_train_timesteps]
+
+        with self.assertRaises(
+            ValueError,
+            msg=f"`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}",
+        ):
+            scheduler.set_timesteps(timesteps=timesteps)
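
Usage sketch (illustrative, not part of the patch): the new `timesteps` argument can drive a bare denoising loop roughly as below, assuming a default-configured `DDPMScheduler` and random tensors standing in for a real UNet; the specific timestep list and tensor shapes are arbitrary choices for illustration.

    import torch
    from diffusers import DDPMScheduler

    # Scheduler with the default 1000 training timesteps.
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Instead of an evenly spaced schedule derived from num_inference_steps, pass an
    # explicit, strictly descending list of training timesteps (arbitrary spacing).
    scheduler.set_timesteps(timesteps=[999, 750, 500, 250, 100, 50, 10, 0])

    sample = torch.randn(1, 3, 32, 32)  # stand-in for a noisy sample
    for t in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # stand-in for a UNet noise prediction
        # step() resolves the previous timestep via previous_timestep(), so the custom
        # spacing is respected when computing the posterior mean and variance.
        sample = scheduler.step(model_output, t, sample).prev_sample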