diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index fdfe5ec200e3..7eb9c2a0a0a9 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -5,14 +5,37 @@
 import numpy as np
 import torch
+import diffusers
 import PIL
 from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
-from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
+from diffusers.utils import deprecate, logging
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+try:
+    from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.Resampling.BILINEAR,
+            "bilinear": PIL.Image.Resampling.BILINEAR,
+            "bicubic": PIL.Image.Resampling.BICUBIC,
+            "lanczos": PIL.Image.Resampling.LANCZOS,
+            "nearest": PIL.Image.Resampling.NEAREST,
+        }
+    else:
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+            "nearest": PIL.Image.NEAREST,
+        }
+# ------------------------------------------------------------------------------
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

-    def __init__(
-        self,
-        vae: AutoencoderKL,
-        text_encoder: CLIPTextModel,
-        tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
-        scheduler: SchedulerMixin,
-        safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
-        requires_safety_checker: bool = True,
-    ):
-        super().__init__(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            requires_safety_checker=requires_safety_checker,
-        )
+    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+        def __init__(
+            self,
+            vae: AutoencoderKL,
+            text_encoder: CLIPTextModel,
+            tokenizer: CLIPTokenizer,
+            unet: UNet2DConditionModel,
+            scheduler: SchedulerMixin,
+            safety_checker: StableDiffusionSafetyChecker,
+            feature_extractor: CLIPFeatureExtractor,
+            requires_safety_checker: bool = True,
+        ):
+            super().__init__(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+                requires_safety_checker=requires_safety_checker,
+            )
+            self.__init__additional__()
+
+    else:
+
+        def __init__(
+            self,
+            vae: AutoencoderKL,
+            text_encoder: CLIPTextModel,
+            tokenizer: CLIPTokenizer,
+            unet: UNet2DConditionModel,
+            scheduler: SchedulerMixin,
+            safety_checker: StableDiffusionSafetyChecker,
+            feature_extractor: CLIPFeatureExtractor,
+        ):
+            super().__init__(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+            )
+            self.__init__additional__()
+
+    def __init__additional__(self):
+        if not hasattr(self, "vae_scale_factor"):
+            setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device

     def _encode_prompt(
         self,
@@ -752,37 +823,33 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if mask is not None:
-                    # masking
-                    init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if i % callback_steps == 0:
-                        if callback is not None:
-                            callback(i, t, latents)
-                        if is_cancelled_callback is not None and is_cancelled_callback():
-                            return None
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None

         # 9. Post-processing
         image = self.decode_latents(latents)
diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py
index 4abdf782072d..22403132d44d 100644
--- a/examples/community/lpw_stable_diffusion_onnx.py
+++ b/examples/community/lpw_stable_diffusion_onnx.py
@@ -5,14 +5,55 @@
 import numpy as np
 import torch
+import diffusers
 import PIL
 from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
-from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from diffusers.onnx_utils import OnnxRuntimeModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
+from diffusers.utils import deprecate, logging
+from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+try:
+    from diffusers.onnx_utils import ORT_TO_NP_TYPE
+except ImportError:
+    ORT_TO_NP_TYPE = {
+        "tensor(bool)": np.bool_,
+        "tensor(int8)": np.int8,
+        "tensor(uint8)": np.uint8,
+        "tensor(int16)": np.int16,
+        "tensor(uint16)": np.uint16,
+        "tensor(int32)": np.int32,
+        "tensor(uint32)": np.uint32,
+        "tensor(int64)": np.int64,
+        "tensor(uint64)": np.uint64,
+        "tensor(float16)": np.float16,
+        "tensor(float)": np.float32,
+        "tensor(double)": np.float64,
+    }
+
+try:
+    from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.Resampling.BILINEAR,
+            "bilinear": PIL.Image.Resampling.BILINEAR,
+            "bicubic": PIL.Image.Resampling.BICUBIC,
+            "lanczos": PIL.Image.Resampling.LANCZOS,
+            "nearest": PIL.Image.Resampling.NEAREST,
+        }
+    else:
+        PIL_INTERPOLATION = {
+            "linear": PIL.Image.LINEAR,
+            "bilinear": PIL.Image.BILINEAR,
+            "bicubic": PIL.Image.BICUBIC,
+            "lanczos": PIL.Image.LANCZOS,
+            "nearest": PIL.Image.NEAREST,
+        }
+# ------------------------------------------------------------------------------
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 re_attention = re.compile(
@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
     """
+    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+        def __init__(
+            self,
+            vae_encoder: OnnxRuntimeModel,
+            vae_decoder: OnnxRuntimeModel,
+            text_encoder: OnnxRuntimeModel,
+            tokenizer: CLIPTokenizer,
+            unet: OnnxRuntimeModel,
+            scheduler: SchedulerMixin,
+            safety_checker: OnnxRuntimeModel,
+            feature_extractor: CLIPFeatureExtractor,
+            requires_safety_checker: bool = True,
+        ):
+            super().__init__(
+                vae_encoder=vae_encoder,
+                vae_decoder=vae_decoder,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+                requires_safety_checker=requires_safety_checker,
+            )
+            self.__init__additional__()

-    def __init__(
-        self,
-        vae_encoder: OnnxRuntimeModel,
-        vae_decoder: OnnxRuntimeModel,
-        text_encoder: OnnxRuntimeModel,
-        tokenizer: CLIPTokenizer,
-        unet: OnnxRuntimeModel,
-        scheduler: SchedulerMixin,
-        safety_checker: OnnxRuntimeModel,
-        feature_extractor: CLIPFeatureExtractor,
-        requires_safety_checker: bool = True,
-    ):
-        super().__init__(
-            vae_encoder=vae_encoder,
-            vae_decoder=vae_decoder,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            requires_safety_checker=requires_safety_checker,
-        )
+    else:
+
+        def __init__(
+            self,
+            vae_encoder: OnnxRuntimeModel,
+            vae_decoder: OnnxRuntimeModel,
+            text_encoder: OnnxRuntimeModel,
+            tokenizer: CLIPTokenizer,
+            unet: OnnxRuntimeModel,
+            scheduler: SchedulerMixin,
+            safety_checker: OnnxRuntimeModel,
+            feature_extractor: CLIPFeatureExtractor,
+        ):
+            super().__init__(
+                vae_encoder=vae_encoder,
+                vae_decoder=vae_decoder,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=safety_checker,
+                feature_extractor=feature_extractor,
+            )
+            self.__init__additional__()
+
+    def __init__additional__(self):
         self.unet_in_channels = 4
         self.vae_scale_factor = 8
@@ -741,49 +811,47 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
-                latent_model_input = latent_model_input.numpy()
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    sample=latent_model_input,
-                    timestep=np.array([t], dtype=timestep_dtype),
-                    encoder_hidden_states=text_embeddings,
-                )
-                noise_pred = noise_pred[0]
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.numpy()
+
+            # predict the noise residual
+            noise_pred = self.unet(
+                sample=latent_model_input,
+                timestep=np.array([t], dtype=timestep_dtype),
+                encoder_hidden_states=text_embeddings,
+            )
+            noise_pred = noise_pred[0]

-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            scheduler_output = self.scheduler.step(
+                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+            )
+            latents = scheduler_output.prev_sample.numpy()
+
+            if mask is not None:
+                # masking
+                init_latents_proper = self.scheduler.add_noise(
+                    torch.from_numpy(init_latents_orig),
+                    torch.from_numpy(noise),
+                    t,
+                ).numpy()
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+            # call the callback, if provided
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None

-                # compute the previous noisy sample x_t -> x_t-1
-                scheduler_output = self.scheduler.step(
-                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
-                )
-                latents = scheduler_output.prev_sample.numpy()
-
-                if mask is not None:
-                    # masking
-                    init_latents_proper = self.scheduler.add_noise(
-                        torch.from_numpy(init_latents_orig),
-                        torch.from_numpy(noise),
-                        t,
-                    ).numpy()
-                    latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if i % callback_steps == 0:
-                        if callback is not None:
-                            callback(i, t, latents)
-                        if is_cancelled_callback is not None and is_cancelled_callback():
-                            return None

         # 9. Post-processing
         image = self.decode_latents(latents)
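The version-gated `__init__` definitions in both files work because a class body is ordinary code executed once at import time, so either signature can be bound as the method depending on the installed diffusers version (the pre-0.9.0 branch simply drops the `requires_safety_checker` argument that older releases do not accept). A self-contained sketch of the pattern, with `DEP_VERSION` as a stand-in for `diffusers.__version__`:

```python
from packaging import version

DEP_VERSION = "0.9.0"  # stand-in for diffusers.__version__


class Pipeline:
    # the class body runs once at definition time, so we can choose which
    # __init__ to define based on the installed dependency version
    if version.parse(version.parse(DEP_VERSION).base_version) >= version.parse("0.9.0"):

        def __init__(self, requires_safety_checker: bool = True):
            self.requires_safety_checker = requires_safety_checker

    else:

        def __init__(self):
            self.requires_safety_checker = True


print(Pipeline().requires_safety_checker)  # True under either branch
```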
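For reference, a hedged usage sketch of the torch pipeline after this change. Loading via `custom_pipeline="lpw_stable_diffusion"` is the standard diffusers mechanism for community pipelines; the model id and the `cancelled` flag are illustrative placeholders. Returning `True` from `is_cancelled_callback` makes `__call__` return `None` mid-loop, as in the diff above:

```python
import torch
from diffusers import DiffusionPipeline

# illustrative model id; any Stable Diffusion checkpoint should work
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

cancelled = False  # e.g. flipped from another thread by a UI "stop" button

result = pipe(
    prompt="a (beautiful:1.2) sunset over the mountains",  # lpw weighting syntax
    num_inference_steps=30,
    is_cancelled_callback=lambda: cancelled,
)
if result is None:
    print("generation was cancelled")
else:
    result.images[0].save("out.png")
```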