From c90be7a4e29b4d9eb5dea51d0f04c29e4493009a Mon Sep 17 00:00:00 2001
From: chenjunsong
Date: Wed, 29 Nov 2023 15:13:47 +0800
Subject: [PATCH 1/2] adapt PixArtAlphaPipeline for pixart-lcm model

---
 .../pixart_alpha/pipeline_pixart_alpha.py          | 51 ++++++++++++++++++-
 1 file changed, 49 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 32e36aaddc53..84129d558807 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -134,6 +134,51 @@
 }
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
 class PixArtAlphaPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using PixArt-Alpha.
@@ -625,6 +670,7 @@ def __call__(
         prompt: Union[str, List[str]] = None,
         negative_prompt: str = "",
         num_inference_steps: int = 20,
+        original_inference_steps: int = None,
         timesteps: List[int] = None,
         guidance_scale: float = 4.5,
         num_images_per_prompt: Optional[int] = 1,
@@ -783,8 +829,9 @@ def __call__(
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps
+        )
 
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels

From a1641a3e0571fa55456b300e9f012f2d97a7cf65 Mon Sep 17 00:00:00 2001
From: chenjunsong
Date: Thu, 30 Nov 2023 01:04:03 +0800
Subject: [PATCH 2/2] remove original_inference_steps from __call__

---
 .../pipelines/pixart_alpha/pipeline_pixart_alpha.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 84129d558807..090b66915dd0 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -670,7 +670,6 @@ def __call__(
         prompt: Union[str, List[str]] = None,
         negative_prompt: str = "",
         num_inference_steps: int = 20,
-        original_inference_steps: int = None,
         timesteps: List[int] = None,
         guidance_scale: float = 4.5,
         num_images_per_prompt: Optional[int] = 1,
@@ -829,9 +828,7 @@ def __call__(
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps
-        )
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
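
Taken together, the two patches route timestep preparation through retrieve_timesteps, so a scheduler such as LCMScheduler can drive the schedule when PixArt-Alpha is paired with a PixArt-LCM checkpoint. Below is a minimal usage sketch, not part of the patches themselves; the checkpoint name, step count, and guidance setting are illustrative assumptions.

    import torch
    from diffusers import LCMScheduler, PixArtAlphaPipeline

    # Assumed checkpoint name for a PixArt-LCM model; substitute the one you use.
    pipe = PixArtAlphaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-LCM-XL-2-1024-MS", torch_dtype=torch.float16
    )
    # Swap in the LCM scheduler; its config supplies the schedule that
    # retrieve_timesteps pulls out via scheduler.set_timesteps.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")

    # LCM-style sampling typically uses very few steps and no classifier-free guidance.
    image = pipe(
        "A small cactus with a happy face in the Sahara desert",
        num_inference_steps=4,
        guidance_scale=0.0,
    ).images[0]
    image.save("pixart_lcm.png")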