diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 56fb903c7106..e912ad5244be 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -1,6 +1,6 @@ import inspect import re -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np import PIL @@ -8,23 +8,32 @@ from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from diffusers import DiffusionPipeline -from diffusers.configuration_utils import FrozenDict -from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin +import diffusers +from diffusers import SchedulerMixin, StableDiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - PIL_INTERPOLATION, - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, -) - - +from diffusers.utils import logging + + +try: + from diffusers.utils import PIL_INTERPOLATION +except ImportError: + if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } + else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } # ------------------------------------------------------------------------------ logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -135,7 +144,7 @@ def multiply_range(start_position, multiplier): return res -def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int): +def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): r""" Tokenize a list of prompts and return its tokens with weights of each token. @@ -196,7 +205,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos def get_unweighted_text_embeddings( - pipe: DiffusionPipeline, + pipe: StableDiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True, @@ -236,7 +245,7 @@ def get_unweighted_text_embeddings( def get_weighted_text_embeddings( - pipe: DiffusionPipeline, + pipe: StableDiffusionPipeline, prompt: Union[str, List[str]], uncond_prompt: Optional[Union[str, List[str]]] = None, max_embeddings_multiples: Optional[int] = 3, @@ -252,7 +261,7 @@ def get_weighted_text_embeddings( Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. Args: - pipe (`DiffusionPipeline`): + pipe (`StableDiffusionPipeline`): Pipe to provide access to the tokenizer and the text encoder. prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. 
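Note on the `PIL_INTERPOLATION` fallback above: it reproduces the mapping for Pillow releases that predate the `Image.Resampling` enum (introduced in Pillow 9.1.0). A minimal sketch, not part of the diff, of how the shim is exercised by the preprocessing helpers further down (the image size is illustrative):

    import PIL.Image

    img = PIL.Image.new("RGB", (517, 389))
    w, h = (x - x % 32 for x in img.size)  # snap both sides down to a multiple of 32, as preprocess_image does
    img = img.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    print(img.size)  # (512, 384)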
@@ -340,7 +349,7 @@ def get_weighted_text_embeddings( pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) if uncond_prompt is not None: uncond_embeddings = get_unweighted_text_embeddings( pipe, @@ -348,7 +357,7 @@ def get_weighted_text_embeddings( pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) # assign weights to the prompts and normalize in the sense of mean # TODO: should we normalize by chunk or in a whole (current implementation)? @@ -368,50 +377,30 @@ def get_weighted_text_embeddings( return text_embeddings, None -def preprocess_image(image, batch_size): +def preprocess_image(image): w, h = image.size - w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 - image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) + image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) return 2.0 * image - 1.0 -def preprocess_mask(mask, batch_size, scale_factor=8): - if not isinstance(mask, torch.FloatTensor): - mask = mask.convert("L") - w, h = mask.size - w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 - mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = np.vstack([mask[None]] * batch_size) - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - else: - valid_mask_channel_sizes = [1, 3] - # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W) - if mask.shape[3] in valid_mask_channel_sizes: - mask = mask.permute(0, 3, 1, 2) - elif mask.shape[1] not in valid_mask_channel_sizes: - raise ValueError( - f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension," - f" but received mask of shape {tuple(mask.shape)}" - ) - # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape - mask = mask.mean(dim=1, keepdim=True) - h, w = mask.shape[-2:] - h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8 - mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) - return mask +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None] # add batch dimension
+ mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask -class StableDiffusionLongPromptWeightingPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin -): +class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing weights in the prompt. @@ -440,196 +429,66 @@ class StableDiffusionLongPromptWeightingPipeline( Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker, ) + self.__init__additional__() - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) + else: - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config( - requires_safety_checker=requires_safety_checker, - ) - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. 
- """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. - - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in - several steps. This is useful to save a large amount of memory and to allow the processing of larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
- """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + self.__init__additional__() - # We'll offload the last model manually. - self.final_offload_hook = hook + def __init__additional__(self): + if not hasattr(self, "vae_scale_factor"): + setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if not hasattr(self.unet, "_hf_hook"): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -646,10 +505,8 @@ def _encode_prompt( device, num_images_per_prompt, do_classifier_free_guidance, - negative_prompt=None, - max_embeddings_multiples=3, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt, + max_embeddings_multiples, ): r""" Encodes the prompt into text encoder hidden states. @@ -669,71 +526,47 @@ def _encode_prompt( max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if negative_prompt_embeds is None: - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - if prompt_embeds is None or negative_prompt_embeds is None: - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) - - prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." ) - if prompt_embeds is None: - prompt_embeds = prompt_embeds1 - if negative_prompt_embeds is None: - negative_prompt_embeds = negative_prompt_embeds1 - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + ) + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: - bs_embed, seq_len, _ = negative_prompt_embeds.shape - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + bs_embed, seq_len, _ = uncond_embeddings.shape + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - return prompt_embeds + return text_embeddings - def check_inputs( - self, - prompt, - height, - width, - strength, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + def check_inputs(self, prompt, height, width, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ 
-742,42 +575,17 @@ def check_inputs( f" {type(callback_steps)}." ) - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - def get_timesteps(self, num_inference_steps, strength, device, is_text2img): if is_text2img: return self.scheduler.timesteps.to(device), num_inference_steps else: # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(device) return timesteps, num_inference_steps - t_start def run_safety_checker(self, image, device, dtype): @@ -791,7 +599,7 @@ def run_safety_checker(self, image, device, dtype): return image, has_nsfw_concept def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents + latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 @@ -815,51 +623,43 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def prepare_latents( - self, - image, - timestep, - num_images_per_prompt, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): + def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): if image is None: - batch_size = batch_size * num_images_per_prompt - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) + shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents, None, None else: - image = image.to(device=self.device, dtype=dtype) init_latent_dist = self.vae.encode(image).latent_dist init_latents = init_latent_dist.sample(generator=generator) - init_latents = self.vae.config.scaling_factor * init_latents - - # Expand init_latents for batch_size and num_images_per_prompt - init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents = 0.18215 * init_latents + init_latents = torch.cat([init_latents] * batch_size, dim=0) init_latents_orig = init_latents + shape = init_latents.shape # add noise to latents using the timesteps - noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype) - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents + if device.type == "mps": + noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) return latents, init_latents_orig, noise @torch.no_grad() @@ -875,19 +675,15 @@ def __call__( guidance_scale: float = 7.5, strength: float = 0.8, num_images_per_prompt: Optional[int] = 1, - add_predicted_noise: Optional[bool] = False, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -927,26 +723,16 @@ def __call__( `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - add_predicted_noise (`bool`, *optional*, defaults to True): - Use predicted noise instead of random noise when constructing noisy versions of the original image in - the reverse diffusion process eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -964,10 +750,6 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: `None` if cancelled by `is_cancelled_callback`, @@ -982,18 +764,10 @@ def __call__( width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) + self.check_inputs(prompt, height, width, strength, callback_steps) # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - + batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -1001,28 +775,26 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - prompt_embeds = self._encode_prompt( + text_embeddings = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, max_embeddings_multiples, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, ) - dtype = prompt_embeds.dtype + dtype = text_embeddings.dtype # 4. 
Preprocess image and mask if isinstance(image, PIL.Image.Image): - image = preprocess_image(image, batch_size) + image = preprocess_image(image) if image is not None: image = image.to(device=self.device, dtype=dtype) if isinstance(mask_image, PIL.Image.Image): - mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) if mask_image is not None: mask = mask_image.to(device=self.device, dtype=dtype) - mask = torch.cat([mask] * num_images_per_prompt) + mask = torch.cat([mask] * batch_size * num_images_per_prompt) else: mask = None @@ -1035,9 +807,7 @@ def __call__( latents, init_latents_orig, noise = self.prepare_latents( image, latent_timestep, - num_images_per_prompt, - batch_size, - self.unet.config.in_channels, + batch_size * num_images_per_prompt, height, width, dtype, @@ -1050,70 +820,43 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - if add_predicted_noise: - init_latents_proper = self.scheduler.add_noise( - init_latents_orig, noise_pred_uncond, torch.tensor([t]) - ) - else: - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 11. 
Convert to PIL + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 11. Convert to PIL + if output_type == "pil": image = self.numpy_to_pil(image) - else: - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() if not return_dict: return image, has_nsfw_concept @@ -1130,17 +873,14 @@ def text2img( guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function for text-to-image generation. @@ -1168,20 +908,13 @@ def text2img( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -1199,13 +932,7 @@ def text2img( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - Returns: - `None` if cancelled by `is_cancelled_callback`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a @@ -1223,15 +950,12 @@ def text2img( eta=eta, generator=generator, latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, max_embeddings_multiples=max_embeddings_multiples, output_type=output_type, return_dict=return_dict, callback=callback, is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, - cross_attention_kwargs=cross_attention_kwargs, ) def img2img( @@ -1244,16 +968,13 @@ def img2img( guidance_scale: Optional[float] = 7.5, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function for image-to-image generation. @@ -1286,16 +1007,9 @@ def img2img( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. 
- negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -1313,13 +1027,8 @@ def img2img( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - Returns: - `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" @@ -1335,15 +1044,12 @@ def img2img( num_images_per_prompt=num_images_per_prompt, eta=eta, generator=generator, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, max_embeddings_multiples=max_embeddings_multiples, output_type=output_type, return_dict=return_dict, callback=callback, is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, - cross_attention_kwargs=cross_attention_kwargs, ) def inpaint( @@ -1356,18 +1062,14 @@ def inpaint( num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, num_images_per_prompt: Optional[int] = 1, - add_predicted_noise: Optional[bool] = False, eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, + generator: Optional[torch.Generator] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function for inpaint. @@ -1401,22 +1103,12 @@ def inpaint( usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - add_predicted_noise (`bool`, *optional*, defaults to True): - Use predicted noise instead of random noise when constructing noisy versions of the original image in - the reverse diffusion process eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -1434,13 +1126,8 @@ def inpaint( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - Returns: - `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" @@ -1455,16 +1142,12 @@ def inpaint( guidance_scale=guidance_scale, strength=strength, num_images_per_prompt=num_images_per_prompt, - add_predicted_noise=add_predicted_noise, eta=eta, generator=generator, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, max_embeddings_multiples=max_embeddings_multiples, output_type=output_type, return_dict=return_dict, callback=callback, is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, - cross_attention_kwargs=cross_attention_kwargs, )
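For reference, a minimal usage sketch of the simplified pipeline (not part of the diff; the checkpoint id, prompts, and seed are illustrative). The file is loaded as a community pipeline, and long, weighted prompts are handled by the parser and chunked encoding defined above:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="lpw_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")

    # Prompts longer than the 77-token CLIP window are split into chunks
    # (up to max_embeddings_multiples) before encoding; (token:weight) raises
    # and [token] lowers a token's weight.
    image = pipe.text2img(
        "a (photorealistic:1.3) mountain lake at sunrise, [blurry], highly detailed",
        negative_prompt="lowres, bad anatomy",
        max_embeddings_multiples=3,
        num_inference_steps=30,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]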