From e30f4f344c8acb6f54550c1f8ed804767ae370cd Mon Sep 17 00:00:00 2001 From: Tony Lian <1040424979@qq.com> Date: Fri, 1 Dec 2023 10:44:50 -0800 Subject: [PATCH 1/4] LLMGroundedDiffusionPipeline: inherit from DiffusionPipeline and fix peft --- examples/community/llm_grounded_diffusion.py | 896 +++++++++++++++++-- 1 file changed, 816 insertions(+), 80 deletions(-) diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py index ea2aad1a5674..3dccc03dfb43 100644 --- a/examples/community/llm_grounded_diffusion.py +++ b/examples/community/llm_grounded_diffusion.py @@ -16,6 +16,7 @@ import ast import gc +import inspect import math import warnings from collections.abc import Iterable @@ -23,16 +24,29 @@ import torch import torch.nn.functional as F +from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models.attention import Attention, GatedSelfAttentionDense from diffusers.models.attention_processor import AttnProcessor2_0 -from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines import DiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import logging, replace_example_docstring +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor EXAMPLE_DOC_STRING = """ @@ -96,7 +110,12 @@ # All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)] # Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`. -DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)] +DEFAULT_GUIDANCE_ATTN_KEYS = [ + ("mid", 0, 0, 0), + ("up", 1, 0, 0), + ("up", 1, 1, 0), + ("up", 1, 2, 0), +] def convert_attn_keys(key): @@ -109,13 +128,17 @@ def convert_attn_keys(key): return f"{key[0]}_blocks.{key[1]}.attentions.{key[2]}.transformer_blocks.{key[3]}.attn2.processor" -DEFAULT_GUIDANCE_ATTN_KEYS = [convert_attn_keys(key) for key in DEFAULT_GUIDANCE_ATTN_KEYS] +DEFAULT_GUIDANCE_ATTN_KEYS = [ + convert_attn_keys(key) for key in DEFAULT_GUIDANCE_ATTN_KEYS +] def scale_proportion(obj_box, H, W): # Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with ".5". 
x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H) - box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round((obj_box[3] - obj_box[1]) * H) + box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round( + (obj_box[3] - obj_box[1]) * H + ) x_max, y_max = x_min + box_w, y_min + box_h x_min, y_min = max(x_min, 0), max(y_min, 0) @@ -126,7 +149,15 @@ def scale_proportion(obj_box, H, W): # Adapted from the parent class `AttnProcessor2_0` class AttnProcessorWithHook(AttnProcessor2_0): - def __init__(self, attn_processor_key, hidden_size, cross_attention_dim, hook=None, fast_attn=True, enabled=True): + def __init__( + self, + attn_processor_key, + hidden_size, + cross_attention_dim, + hook=None, + fast_attn=True, + enabled=True, + ): super().__init__() self.attn_processor_key = attn_processor_key self.hidden_size = hidden_size @@ -153,27 +184,38 @@ def __call__( if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape ) if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) - query = attn.to_q(hidden_states, scale=scale) + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -182,41 +224,63 @@ def __call__( query_batch_dim = attn.head_to_batch_dim(query) key_batch_dim = attn.head_to_batch_dim(key) value_batch_dim = attn.head_to_batch_dim(value) - attention_probs = attn.get_attention_scores(query_batch_dim, key_batch_dim, attention_mask) + attention_probs = attn.get_attention_scores( + query_batch_dim, key_batch_dim, attention_mask + ) if self.hook is not None and self.enabled: # Call the hook with query, key, value, and attention maps - self.hook(self.attn_processor_key, query_batch_dim, key_batch_dim, value_batch_dim, attention_probs) + self.hook( + self.attn_processor_key, + query_batch_dim, + key_batch_dim, + value_batch_dim, + attention_probs, + ) if self.fast_attn: - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + value = 
value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) if attention_mask is not None: # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) else: hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states, *args) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) if attn.residual_connection: hidden_states = hidden_states + residual @@ -226,7 +290,7 @@ def __call__( return hidden_states -class LLMGroundedDiffusionPipeline(StableDiffusionPipeline): +class LLMGroundedDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin): r""" Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf. @@ -257,6 +321,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline): Whether a safety checker is needed for this pipeline. """ + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + objects_text = "Objects: " bg_prompt_text = "Background prompt: " bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip() @@ -275,18 +344,109 @@ def __init__( image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): - super().__init__( - vae, - text_encoder, - tokenizer, - unet, - scheduler, + # This is copied from StableDiffusionPipeline, with hook initizations for LMD+. + super().__init__() + + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. 
If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, - requires_safety_checker=requires_safety_checker, ) - + self.vae_scale_factor = 2 ** ( + len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor) + self.register_to_config( + requires_safety_checker=requires_safety_checker) + + # Initialize the attention hooks for LLM-grounded Diffusion self.register_attn_hooks(unet) self._saved_attn = None @@ -345,12 +505,17 @@ def _parse_response_with_negative(cls, text): @classmethod def parse_llm_response(cls, response, canvas_height=512, canvas_width=512): # Infer from spec - gen_boxes, bg_prompt, neg_prompt = cls._parse_response_with_negative(text=response) + gen_boxes, bg_prompt, neg_prompt = cls._parse_response_with_negative( + text=response + ) gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0]) phrases = [name for name, _ in gen_boxes] - boxes = [cls.convert_box(box, height=canvas_height, width=canvas_width) for _, box in gen_boxes] + boxes = [ + cls.convert_box(box, height=canvas_height, width=canvas_width) + for _, box in gen_boxes + ] return phrases, boxes, bg_prompt, neg_prompt @@ -368,10 +533,13 @@ def check_inputs( phrase_indices=None, ): if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" @@ -387,10 +555,15 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) elif prompt is None and phrase_indices is None: - raise ValueError("If the prompt is None, the phrase_indices cannot be None") + raise ValueError( + "If the prompt is None, the phrase_indices cannot be None") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -419,18 +592,25 @@ def register_attn_hooks(self, unet): for name in unet.attn_processors.keys(): # Only obtain the queries and keys from cross-attention - if name.endswith("attn1.processor") or name.endswith("fuser.attn.processor"): + if name.endswith("attn1.processor") or name.endswith( + "fuser.attn.processor" + ): # Keep the same attn_processors for self-attention (no hooks for self-attention) attn_procs[name] = unet.attn_processors[name] continue - cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + cross_attention_dim = ( + None + if name.endswith("attn1.processor") + else unet.config.cross_attention_dim + ) if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + hidden_size = list(reversed(unet.config.block_out_channels))[ + block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] @@ -459,7 +639,9 @@ def enable_attn_hook(self, enabled=True): def get_token_map(self, prompt, padding="do_not_pad", verbose=False): """Get a list of mapping: prompt index to str (prompt in a list of token str)""" - fg_prompt_tokens = self.tokenizer([prompt], padding=padding, max_length=77, return_tensors="np") + fg_prompt_tokens = self.tokenizer( + [prompt], padding=padding, max_length=77, return_tensors="np" + ) input_ids = fg_prompt_tokens["input_ids"][0] token_map = [] @@ -473,7 +655,14 @@ def get_token_map(self, prompt, padding="do_not_pad", verbose=False): return token_map - def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_found=False, verbose=False): + def get_phrase_indices( + self, + prompt, + phrases, + token_map=None, + add_suffix_if_not_found=False, + verbose=False, + ): for obj in phrases: # Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix if obj not in prompt: @@ -481,26 +670,43 @@ def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_ if token_map is None: # We allow using a pre-computed token map. 
- token_map = self.get_token_map(prompt=prompt, padding="do_not_pad", verbose=verbose) + token_map = self.get_token_map( + prompt=prompt, padding="do_not_pad", verbose=verbose + ) token_map_str = " ".join(token_map) phrase_indices = [] for obj in phrases: - phrase_token_map = self.get_token_map(prompt=obj, padding="do_not_pad", verbose=verbose) + phrase_token_map = self.get_token_map( + prompt=obj, padding="do_not_pad", verbose=verbose + ) # Remove and in substr phrase_token_map = phrase_token_map[1:-1] phrase_token_map_len = len(phrase_token_map) phrase_token_map_str = " ".join(phrase_token_map) if verbose: - logger.info("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases) + logger.info( + "Full str:", + token_map_str, + "Substr:", + phrase_token_map_str, + "Phrase:", + phrases, + ) # Count the number of token before substr # The substring comes with a trailing space that needs to be removed by minus one in the index. - obj_first_index = len(token_map_str[: token_map_str.index(phrase_token_map_str) - 1].split(" ")) + obj_first_index = len( + token_map_str[: token_map_str.index(phrase_token_map_str) - 1].split( + " " + ) + ) - obj_position = list(range(obj_first_index, obj_first_index + phrase_token_map_len)) + obj_position = list( + range(obj_first_index, obj_first_index + phrase_token_map_len) + ) phrase_indices.append(obj_position) if add_suffix_if_not_found: @@ -534,7 +740,8 @@ def add_ca_loss_per_attn_map_to_loss( for obj_box in obj_boxes: # x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H) - x_min, y_min, x_max, y_max = scale_proportion(obj_box, H=H, W=W) + x_min, y_min, x_max, y_max = scale_proportion( + obj_box, H=H, W=W) mask[y_min:y_max, x_min:x_max] = 1 for obj_position in phrase_indices[obj_idx]: @@ -554,14 +761,26 @@ def add_ca_loss_per_attn_map_to_loss( # Take the topk over spatial dimension, and then take the sum over heads dim # The mean is over k_fg and k_bg dimension, so we don't need to sum and divide on our own. - obj_loss += (1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1)).sum(dim=0) * fg_weight - obj_loss += ((ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1)).sum(dim=0) * bg_weight + obj_loss += ( + 1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1) + ).sum(dim=0) * fg_weight + obj_loss += ( + (ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1) + ).sum(dim=0) * bg_weight loss += obj_loss / len(phrase_indices[obj_idx]) return loss - def compute_ca_loss(self, saved_attn, bboxes, phrase_indices, guidance_attn_keys, verbose=False, **kwargs): + def compute_ca_loss( + self, + saved_attn, + bboxes, + phrase_indices, + guidance_attn_keys, + verbose=False, + **kwargs, + ): """ The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss. `AttnProcessor` will put attention maps into the `save_attn_to_dict`. 
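As a quick illustration of what the hunk above accumulates, the sketch below reproduces the box-constrained cross-attention energy for a single phrase token outside the pipeline. It is a minimal standalone example and not part of the patch; the tensor shapes, the top-k sizes, and the foreground/background weights are assumptions chosen for clarity rather than the pipeline's actual defaults.

import torch


def box_attention_loss(attn_map, box_mask, k_fg=5, k_bg=10, fg_weight=1.0, bg_weight=4.0):
    # attn_map: (heads, H*W) cross-attention paid to one phrase token.
    # box_mask: (H*W,) binary mask that is 1 inside the object's bounding box.
    # Reward attention inside the box: penalize 1 minus the mean of the top-k values there.
    fg_term = (1 - (attn_map * box_mask).topk(k=k_fg, dim=1).values.mean(dim=1)).sum() * fg_weight
    # Penalize attention leaking outside the box via the mean of the top-k leaked values.
    bg_term = ((attn_map * (1 - box_mask)).topk(k=k_bg, dim=1).values.mean(dim=1)).sum() * bg_weight
    return fg_term + bg_term


# Example: 8 heads over a 16x16 latent grid with a box covering the top-left quadrant.
attn = torch.rand(8, 256).softmax(dim=-1)
mask = torch.zeros(16, 16)
mask[:8, :8] = 1.0
print(box_attention_loss(attn, mask.reshape(-1)))

During guidance, `latent_lmd_guidance` backpropagates a loss of this form through the saved attention maps to the latents, which is why the attention hooks stash the maps in `self._saved_attn` before the loss is computed.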
@@ -610,13 +829,16 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback: Optional[Callable[[ + int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, @@ -671,6 +893,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -726,14 +949,19 @@ def __call__( if prompt is not None and isinstance(prompt, str): batch_size = 1 if phrase_indices is None: - phrase_indices, prompt = self.get_phrase_indices(prompt, phrases, add_suffix_if_not_found=True) + phrase_indices, prompt = self.get_phrase_indices( + prompt, phrases, add_suffix_if_not_found=True + ) elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) if phrase_indices is None: phrase_indices = [] prompt_parsed = [] for prompt_item in prompt: - phrase_indices_parsed_item, prompt_parsed_item = self.get_phrase_indices( + ( + phrase_indices_parsed_item, + prompt_parsed_item, + ) = self.get_phrase_indices( prompt_item, add_suffix_if_not_found=True ) phrase_indices.append(phrase_indices_parsed_item) @@ -768,6 +996,13 @@ def __call__( if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -799,28 +1034,40 @@ def __call__( if n_objs: # prepare batched input to the PositionNet (boxes, phrases, mask) # Get tokens for phrases from pre-trained CLIPTokenizer - tokenizer_inputs = self.tokenizer(phrases, padding=True, return_tensors="pt").to(device) + tokenizer_inputs = self.tokenizer( + phrases, padding=True, return_tensors="pt" + ).to(device) # For the token, we use the same pre-trained text encoder # to obtain its text feature - _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output + _text_embeddings = self.text_encoder( + **tokenizer_inputs).pooler_output # For each entity, described in phrases, is denoted with a bounding box, # we represent the location information as (xmin,ymin,xmax,ymax) - cond_boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) + cond_boxes = torch.zeros( + max_objs, 4, device=device, dtype=self.text_encoder.dtype + ) if n_objs: cond_boxes[:n_objs] = torch.tensor(boxes) text_embeddings = torch.zeros( - max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype + max_objs, + self.unet.config.cross_attention_dim, + device=device, + dtype=self.text_encoder.dtype, ) if n_objs: text_embeddings[:n_objs] = _text_embeddings # Generate a mask for each object that is entity described by phrases - masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + masks = torch.zeros(max_objs, device=device, + dtype=self.text_encoder.dtype) masks[:n_objs] = 1 repeat_batch = batch_size * num_images_per_prompt - cond_boxes = cond_boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone() - text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone() + cond_boxes = cond_boxes.unsqueeze( + 0).expand(repeat_batch, -1, -1).clone() + text_embeddings = ( + text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone() + ) masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone() if do_classifier_free_guidance: repeat_batch = repeat_batch * 2 @@ -836,16 +1083,23 @@ def __call__( "masks": masks, } - num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps)) + num_grounding_steps = int( + gligen_scheduled_sampling_beta * len(timesteps)) self.enable_fuser(True) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None else None + ) + loss_attn = torch.tensor(10000.0) # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Scheduled sampling @@ -869,8 +1123,13 @@ def __call__( ) # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat( + [latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual noise_pred = self.unet( @@ -878,26 +1137,38 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -907,7 +1178,9 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: @@ -916,7 +1189,9 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) @torch.set_grad_enabled(True) def latent_lmd_guidance( @@ -958,13 +1233,17 @@ def latent_lmd_guidance( self.enable_attn_hook(enabled=True) while ( - loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < guidance_timesteps + loss.item() / loss_scale > loss_threshold + and iteration < max_iter + and index < guidance_timesteps ): self._saved_attn = {} latents.requires_grad_(True) latent_model_input = latents - latent_model_input = 
scheduler.scale_model_input(latent_model_input, t) + latent_model_input = scheduler.scale_model_input( + latent_model_input, t + ) unet( latent_model_input, @@ -992,11 +1271,14 @@ def latent_lmd_guidance( # This callback allows visualizations. if guidance_callback is not None: - guidance_callback(self, latents, loss, iteration, index) + guidance_callback( + self, latents, loss, iteration, index) self._saved_attn = None - grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0] + grad_cond = torch.autograd.grad( + loss.requires_grad_(True), [latents] + )[0] latents.requires_grad_(False) @@ -1022,3 +1304,457 @@ def latent_lmd_guidance( self.enable_attn_hook(enabled=False) return latents, loss + + # Below are methods copied from StableDiffusionPipeline + # The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517 + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", + deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat( + [prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask + ) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + output_hidden_states=True, + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states 
from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm( + prompt_embeds + ) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt( + uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=prompt_embeds_dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor( + image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave( + num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) + else: + feature_extractor_input = self.image_processor.numpy_to_pil( + image) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to( + dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", + deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps From 284fc825f0fffb1e22516d1f42a51673a95618f4 Mon Sep 17 00:00:00 2001 From: Tony Lian <1040424979@qq.com> Date: Fri, 1 Dec 2023 10:45:24 -0800 Subject: [PATCH 2/4] Use main in the revision in the examples --- examples/community/README.md | 1 + examples/community/llm_grounded_diffusion.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/community/README.md b/examples/community/README.md index f6044102641f..37b51c8c4139 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -78,6 +78,7 @@ from diffusers import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained( "longlian/lmd_plus", custom_pipeline="llm_grounded_diffusion", + custom_revision="main", variant="fp16", torch_dtype=torch.float16 ) pipe.enable_model_cpu_offload() diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py index 3dccc03dfb43..243a52db3fc8 100644 --- a/examples/community/llm_grounded_diffusion.py +++ b/examples/community/llm_grounded_diffusion.py @@ -58,6 +58,7 @@ >>> pipe = DiffusionPipeline.from_pretrained( ... "longlian/lmd_plus", ... custom_pipeline="llm_grounded_diffusion", + ... custom_revision="main", ... variant="fp16", torch_dtype=torch.float16 ... 
) >>> pipe.enable_model_cpu_offload() From 4a6e40efee7c3c23fd3a44f925f186b66763ce1f Mon Sep 17 00:00:00 2001 From: Tony Lian <1040424979@qq.com> Date: Fri, 1 Dec 2023 10:59:25 -0800 Subject: [PATCH 3/4] Add "Copied from" statements in comments --- examples/community/llm_grounded_diffusion.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py index 243a52db3fc8..b7e943167e61 100644 --- a/examples/community/llm_grounded_diffusion.py +++ b/examples/community/llm_grounded_diffusion.py @@ -1308,6 +1308,8 @@ def latent_lmd_guidance( # Below are methods copied from StableDiffusionPipeline # The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517 + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -1315,6 +1317,7 @@ def enable_vae_slicing(self): """ self.vae.enable_slicing() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to @@ -1322,6 +1325,7 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to @@ -1330,6 +1334,7 @@ def enable_vae_tiling(self): """ self.vae.enable_tiling() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously enabled, this method will go back to @@ -1337,6 +1342,7 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, @@ -1371,6 +1377,7 @@ def _encode_prompt( return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, @@ -1576,6 +1583,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype @@ -1591,6 +1599,7 @@ def encode_image(self, image, device, num_images_per_prompt): uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -1611,6 +1620,7 @@ def run_safety_checker(self, image, device, dtype): ) return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", @@ -1623,6 +1633,7 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -1644,6 +1655,7 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents( self, batch_size, @@ -1678,6 +1690,7 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. 
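The FreeU-related context shown here pairs naturally with a usage note. Below is a hedged loading-and-FreeU sketch that mirrors the README example added in this PR; the specific scaling factors are assumptions based on values commonly reported for Stable Diffusion 1.x backbones, so treat them as a starting point rather than a recommendation from the pipeline itself.

import torch
from diffusers import DiffusionPipeline

# Load the community pipeline as in the README example from this PR.
pipe = DiffusionPipeline.from_pretrained(
    "longlian/lmd_plus",
    custom_pipeline="llm_grounded_diffusion",
    custom_revision="main",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

# FreeU scaling factors: assumed values often quoted for SD 1.x checkpoints; consult the
# official FreeU repository (https://github.com/ChenyangSi/FreeU) for combinations known
# to work well for other base models.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
# ...generate images as usual, then revert with:
pipe.disable_freeu()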
@@ -1700,6 +1713,7 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): raise ValueError("The pipeline must have `unet` for using FreeU.") self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() @@ -1733,14 +1747,17 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32 assert emb.shape == (w.shape[0], embedding_dim) return emb + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale @property def guidance_scale(self): return self._guidance_scale + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale @property def guidance_rescale(self): return self._guidance_rescale + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip @property def clip_skip(self): return self._clip_skip @@ -1748,14 +1765,17 @@ def clip_skip(self): # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs @property def cross_attention_kwargs(self): return self._cross_attention_kwargs + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps @property def num_timesteps(self): return self._num_timesteps From c4d5632e719aad2bdc9d4f8e6f59ae7bbea8bbb9 Mon Sep 17 00:00:00 2001 From: Tony Lian <1040424979@qq.com> Date: Fri, 1 Dec 2023 11:18:32 -0800 Subject: [PATCH 4/4] Fix formatting with ruff --- examples/community/llm_grounded_diffusion.py | 368 +++++-------------- 1 file changed, 100 insertions(+), 268 deletions(-) diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py index b7e943167e61..14f4deabcea7 100644 --- a/examples/community/llm_grounded_diffusion.py +++ b/examples/community/llm_grounded_diffusion.py @@ -129,17 +129,13 @@ def convert_attn_keys(key): return f"{key[0]}_blocks.{key[1]}.attentions.{key[2]}.transformer_blocks.{key[3]}.attn2.processor" -DEFAULT_GUIDANCE_ATTN_KEYS = [ - convert_attn_keys(key) for key in DEFAULT_GUIDANCE_ATTN_KEYS -] +DEFAULT_GUIDANCE_ATTN_KEYS = [convert_attn_keys(key) for key in DEFAULT_GUIDANCE_ATTN_KEYS] def scale_proportion(obj_box, H, W): # Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with ".5". 
x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H) - box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round( - (obj_box[3] - obj_box[1]) * H - ) + box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round((obj_box[3] - obj_box[1]) * H) x_max, y_max = x_min + box_w, y_min + box_h x_min, y_min = max(x_min, 0), max(y_min, 0) @@ -185,25 +181,17 @@ def __call__( if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( - 1, 2 - ) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) args = () if USE_PEFT_BACKEND else (scale,) query = attn.to_q(hidden_states, *args) @@ -211,9 +199,7 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states( - encoder_hidden_states - ) + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) @@ -225,9 +211,7 @@ def __call__( query_batch_dim = attn.head_to_batch_dim(query) key_batch_dim = attn.head_to_batch_dim(key) value_batch_dim = attn.head_to_batch_dim(value) - attention_probs = attn.get_attention_scores( - query_batch_dim, key_batch_dim, attention_mask - ) + attention_probs = attn.get_attention_scores(query_batch_dim, key_batch_dim, attention_mask) if self.hook is not None and self.enabled: # Call the hook with query, key, value, and attention maps @@ -240,20 +224,15 @@ def __call__( ) if self.fast_attn: - query = query.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attention_mask is not None: # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view( - batch_size, attn.heads, -1, attention_mask.shape[-1] - ) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 @@ -265,9 +244,7 @@ def __call__( dropout_p=0.0, is_causal=False, ) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) else: hidden_states = torch.bmm(attention_probs, value) @@ 
-279,9 +256,7 @@ def __call__( hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual @@ -291,7 +266,9 @@ def __call__( return hidden_states -class LLMGroundedDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin): +class LLMGroundedDiffusionPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): r""" Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf. @@ -348,10 +325,7 @@ def __init__( # This is copied from StableDiffusionPipeline, with hook initizations for LMD+. super().__init__() - if ( - hasattr(scheduler.config, "steps_offset") - and scheduler.config.steps_offset != 1 - ): + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -360,17 +334,12 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate( - "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if ( - hasattr(scheduler.config, "clip_sample") - and scheduler.config.clip_sample is True - ): + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -378,9 +347,7 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate( - "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -401,16 +368,10 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr( - unet.config, "_diffusers_version" - ) and version.parse( + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse( - "0.9.0.dev0" - ) - is_unet_sample_size_less_64 = ( - hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - ) + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -423,9 +384,7 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate( - "sample_size<64", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -440,12 +399,9 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** ( - len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor) - self.register_to_config( - requires_safety_checker=requires_safety_checker) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Initialize the attention hooks for LLM-grounded Diffusion self.register_attn_hooks(unet) @@ -506,17 +462,12 @@ def _parse_response_with_negative(cls, text): @classmethod def parse_llm_response(cls, response, canvas_height=512, canvas_width=512): # Infer from spec - gen_boxes, bg_prompt, neg_prompt = cls._parse_response_with_negative( - text=response - ) + gen_boxes, bg_prompt, neg_prompt = cls._parse_response_with_negative(text=response) gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0]) phrases = [name for name, _ in gen_boxes] - boxes = [ - cls.convert_box(box, height=canvas_height, width=canvas_width) - for _, box in gen_boxes - ] + boxes = [cls.convert_box(box, height=canvas_height, width=canvas_width) for _, box in gen_boxes] return phrases, boxes, bg_prompt, neg_prompt @@ -534,13 +485,10 @@ def check_inputs( phrase_indices=None, ): if height % 8 != 0 or width % 8 != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." - ) + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( - callback_steps is not None - and (not isinstance(callback_steps, int) or callback_steps <= 0) + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" @@ -556,15 +504,10 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt is None and phrase_indices is None: - raise ValueError( - "If the prompt is None, the phrase_indices cannot be None") + raise ValueError("If the prompt is None, the phrase_indices cannot be None") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -593,25 +536,18 @@ def register_attn_hooks(self, unet): for name in unet.attn_processors.keys(): # Only obtain the queries and keys from cross-attention - if name.endswith("attn1.processor") or name.endswith( - "fuser.attn.processor" - ): + if name.endswith("attn1.processor") or name.endswith("fuser.attn.processor"): # Keep the same attn_processors for self-attention (no hooks for self-attention) attn_procs[name] = unet.attn_processors[name] continue - cross_attention_dim = ( - None - if name.endswith("attn1.processor") - else unet.config.cross_attention_dim - ) + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[ - block_id] + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] @@ -640,9 +576,7 @@ def enable_attn_hook(self, enabled=True): def get_token_map(self, prompt, padding="do_not_pad", verbose=False): """Get a list of mapping: prompt index to str (prompt in a list of token str)""" - fg_prompt_tokens = self.tokenizer( - [prompt], padding=padding, max_length=77, return_tensors="np" - ) + fg_prompt_tokens = self.tokenizer([prompt], padding=padding, max_length=77, return_tensors="np") input_ids = fg_prompt_tokens["input_ids"][0] token_map = [] @@ -671,17 +605,13 @@ def get_phrase_indices( if token_map is None: # We allow using a pre-computed token map. - token_map = self.get_token_map( - prompt=prompt, padding="do_not_pad", verbose=verbose - ) + token_map = self.get_token_map(prompt=prompt, padding="do_not_pad", verbose=verbose) token_map_str = " ".join(token_map) phrase_indices = [] for obj in phrases: - phrase_token_map = self.get_token_map( - prompt=obj, padding="do_not_pad", verbose=verbose - ) + phrase_token_map = self.get_token_map(prompt=obj, padding="do_not_pad", verbose=verbose) # Remove and in substr phrase_token_map = phrase_token_map[1:-1] phrase_token_map_len = len(phrase_token_map) @@ -699,15 +629,9 @@ def get_phrase_indices( # Count the number of token before substr # The substring comes with a trailing space that needs to be removed by minus one in the index. 
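A simplified sketch of the index bookkeeping performed here, with a hand-written token list standing in for the CLIP tokenizer output; the real code builds token_map with self.get_token_map and strips the first and last special tokens from the phrase's token map (phrase_token_map[1:-1]) before searching.

prompt_tokens = ["<bos>", "a", "gray", "cat", "and", "a", "red", "ball", "<eos>"]
phrase_tokens = ["gray", "cat"]

token_map_str = " ".join(prompt_tokens)
phrase_str = " ".join(phrase_tokens)

# Count how many tokens precede the phrase: slice off everything before the match
# (minus the trailing space) and split on spaces, as in the code below.
first_index = len(token_map_str[: token_map_str.index(phrase_str) - 1].split(" "))
phrase_indices = list(range(first_index, first_index + len(phrase_tokens)))
print(phrase_indices)  # [2, 3] -> positions of "gray" and "cat" in prompt_tokens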
- obj_first_index = len( - token_map_str[: token_map_str.index(phrase_token_map_str) - 1].split( - " " - ) - ) + obj_first_index = len(token_map_str[: token_map_str.index(phrase_token_map_str) - 1].split(" ")) - obj_position = list( - range(obj_first_index, obj_first_index + phrase_token_map_len) - ) + obj_position = list(range(obj_first_index, obj_first_index + phrase_token_map_len)) phrase_indices.append(obj_position) if add_suffix_if_not_found: @@ -741,8 +665,7 @@ def add_ca_loss_per_attn_map_to_loss( for obj_box in obj_boxes: # x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H) - x_min, y_min, x_max, y_max = scale_proportion( - obj_box, H=H, W=W) + x_min, y_min, x_max, y_max = scale_proportion(obj_box, H=H, W=W) mask[y_min:y_max, x_min:x_max] = 1 for obj_position in phrase_indices[obj_idx]: @@ -762,12 +685,8 @@ def add_ca_loss_per_attn_map_to_loss( # Take the topk over spatial dimension, and then take the sum over heads dim # The mean is over k_fg and k_bg dimension, so we don't need to sum and divide on our own. - obj_loss += ( - 1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1) - ).sum(dim=0) * fg_weight - obj_loss += ( - (ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1) - ).sum(dim=0) * bg_weight + obj_loss += (1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1)).sum(dim=0) * fg_weight + obj_loss += ((ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1)).sum(dim=0) * bg_weight loss += obj_loss / len(phrase_indices[obj_idx]) @@ -830,16 +749,14 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, - List[torch.Generator]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[ - int, int, torch.FloatTensor], None]] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, @@ -950,9 +867,7 @@ def __call__( if prompt is not None and isinstance(prompt, str): batch_size = 1 if phrase_indices is None: - phrase_indices, prompt = self.get_phrase_indices( - prompt, phrases, add_suffix_if_not_found=True - ) + phrase_indices, prompt = self.get_phrase_indices(prompt, phrases, add_suffix_if_not_found=True) elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) if phrase_indices is None: @@ -962,9 +877,7 @@ def __call__( ( phrase_indices_parsed_item, prompt_parsed_item, - ) = self.get_phrase_indices( - prompt_item, add_suffix_if_not_found=True - ) + ) = self.get_phrase_indices(prompt_item, add_suffix_if_not_found=True) phrase_indices.append(phrase_indices_parsed_item) prompt_parsed.append(prompt_parsed_item) prompt = prompt_parsed @@ -998,9 +911,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt - ) + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if 
self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) @@ -1035,19 +946,14 @@ def __call__( if n_objs: # prepare batched input to the PositionNet (boxes, phrases, mask) # Get tokens for phrases from pre-trained CLIPTokenizer - tokenizer_inputs = self.tokenizer( - phrases, padding=True, return_tensors="pt" - ).to(device) + tokenizer_inputs = self.tokenizer(phrases, padding=True, return_tensors="pt").to(device) # For the token, we use the same pre-trained text encoder # to obtain its text feature - _text_embeddings = self.text_encoder( - **tokenizer_inputs).pooler_output + _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output # For each entity, described in phrases, is denoted with a bounding box, # we represent the location information as (xmin,ymin,xmax,ymax) - cond_boxes = torch.zeros( - max_objs, 4, device=device, dtype=self.text_encoder.dtype - ) + cond_boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) if n_objs: cond_boxes[:n_objs] = torch.tensor(boxes) text_embeddings = torch.zeros( @@ -1059,16 +965,12 @@ def __call__( if n_objs: text_embeddings[:n_objs] = _text_embeddings # Generate a mask for each object that is entity described by phrases - masks = torch.zeros(max_objs, device=device, - dtype=self.text_encoder.dtype) + masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) masks[:n_objs] = 1 repeat_batch = batch_size * num_images_per_prompt - cond_boxes = cond_boxes.unsqueeze( - 0).expand(repeat_batch, -1, -1).clone() - text_embeddings = ( - text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone() - ) + cond_boxes = cond_boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone() + text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone() masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone() if do_classifier_free_guidance: repeat_batch = repeat_batch * 2 @@ -1084,23 +986,19 @@ def __call__( "masks": masks, } - num_grounding_steps = int( - gligen_scheduled_sampling_beta * len(timesteps)) + num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps)) self.enable_fuser(True) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} if ip_adapter_image is not None else None - ) + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None loss_attn = torch.tensor(10000.0) # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - \ - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Scheduled sampling @@ -1124,13 +1022,8 @@ def __call__( ) # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat( - [latents] * 2) if do_classifier_free_guidance else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( @@ -1144,32 +1037,21 @@ def __call__( # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + - 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode( - latents / self.vae.config.scaling_factor, return_dict=False - )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None @@ -1179,9 +1061,7 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess( - image, output_type=output_type, do_denormalize=do_denormalize - ) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: @@ -1190,9 +1070,7 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) @torch.set_grad_enabled(True) def latent_lmd_guidance( @@ -1234,17 +1112,13 @@ def latent_lmd_guidance( self.enable_attn_hook(enabled=True) while ( - loss.item() / loss_scale > loss_threshold - and iteration < max_iter - and index < guidance_timesteps + loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < guidance_timesteps ): self._saved_attn = {} latents.requires_grad_(True) latent_model_input = latents - latent_model_input = scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) unet( 
latent_model_input, @@ -1272,14 +1146,11 @@ def latent_lmd_guidance( # This callback allows visualizations. if guidance_callback is not None: - guidance_callback( - self, latents, loss, iteration, index) + guidance_callback(self, latents, loss, iteration, index) self._saved_attn = None - grad_cond = torch.autograd.grad( - loss.requires_grad_(True), [latents] - )[0] + grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0] latents.requires_grad_(False) @@ -1356,8 +1227,7 @@ def _encode_prompt( **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", - deprecation_message, standard_warn=False) + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) prompt_embeds_tuple = self.encode_prompt( prompt=prompt, @@ -1372,8 +1242,7 @@ def _encode_prompt( ) # concatenate for backwards comp - prompt_embeds = torch.cat( - [prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) return prompt_embeds @@ -1450,33 +1319,26 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask - ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( @@ -1492,9 +1354,7 @@ def encode_prompt( # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. 
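For context on the clip_skip branch around this point: with output_hidden_states=True the CLIP text encoder returns the embedding output plus one hidden state per layer, the upstream StableDiffusionPipeline selects the state -(clip_skip + 1) from the end, and the final LayerNorm is then applied manually, as the code below does. A small sketch with made-up shapes and a plain LayerNorm standing in for text_model.final_layer_norm:

import torch

num_layers = 12
# Embedding output plus one hidden state per transformer layer (shapes are illustrative).
hidden_states = tuple(torch.randn(1, 77, 768) for _ in range(num_layers + 1))

clip_skip = 2
selected = hidden_states[-(clip_skip + 1)]  # skip the last `clip_skip` layers
final_layer_norm = torch.nn.LayerNorm(768)  # stands in for text_model.final_layer_norm
prompt_embeds = final_layer_norm(selected)
print(prompt_embeds.shape)  # torch.Size([1, 77, 768])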
- prompt_embeds = self.text_encoder.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype @@ -1503,15 +1363,12 @@ def encode_prompt( else: prompt_embeds_dtype = prompt_embeds.dtype - prompt_embeds = prompt_embeds.to( - dtype=prompt_embeds_dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -1536,8 +1393,7 @@ def encode_prompt( # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt( - uncond_tokens, self.tokenizer) + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( @@ -1548,10 +1404,7 @@ def encode_prompt( return_tensors="pt", ) - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None @@ -1566,16 +1419,10 @@ def encode_prompt( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to( - dtype=prompt_embeds_dtype, device=device - ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers @@ -1588,13 +1435,11 @@ def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): - image = self.feature_extractor( - image, return_tensors="pt").pixel_values + image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave( - num_images_per_prompt, dim=0) + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds @@ -1605,26 +1450,19 @@ def run_safety_checker(self, image, device, dtype): has_nsfw_concept = None else: if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess( - image, output_type="pil" - ) + 
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: - feature_extractor_input = self.image_processor.numpy_to_pil( - image) - safety_checker_input = self.feature_extractor( - feature_extractor_input, return_tensors="pt" - ).to(device) + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to( - dtype) + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", - deprecation_message, standard_warn=False) + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] @@ -1640,17 +1478,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -1680,9 +1514,7 @@ def prepare_latents( ) if latents is None: - latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype - ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device)
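The prepare_extra_step_kwargs hunk just above only forwards eta and generator when the scheduler's step() signature accepts them. A small sketch of that check, using stand-in scheduler classes rather than real diffusers schedulers:

import inspect

class SchedulerWithEta:
    def step(self, model_output, timestep, sample, eta=0.0, generator=None):
        ...

class SchedulerWithoutEta:
    def step(self, model_output, timestep, sample):
        ...

def prepare_extra_step_kwargs(scheduler, generator, eta):
    # Mirror of the check above: only pass kwargs the scheduler's step() accepts.
    params = set(inspect.signature(scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if "eta" in params:
        extra_step_kwargs["eta"] = eta
    if "generator" in params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs

print(prepare_extra_step_kwargs(SchedulerWithEta(), generator=None, eta=0.5))     # {'eta': 0.5, 'generator': None}
print(prepare_extra_step_kwargs(SchedulerWithoutEta(), generator=None, eta=0.5))  # {}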
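Earlier in this patch, the reformatted AttnProcessorWithHook keeps two attention paths: a manual one that materializes attention probabilities for the hook via get_attention_scores, and a fast one that calls torch.nn.functional.scaled_dot_product_attention. A sketch with random tensors showing that the two paths agree under default scaling and no mask; the head layout here is simplified compared with head_to_batch_dim in the pipeline:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, q_len, kv_len, head_dim = 2, 4, 16, 77, 64
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Manual path: explicit attention probabilities (what the hook receives).
attention_probs = (query @ key.transpose(-1, -2) / head_dim**0.5).softmax(dim=-1)
manual_out = attention_probs @ value

# Fast path: fused kernel, no materialized probabilities.
fast_out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

print(torch.allclose(manual_out, fast_out, atol=1e-5))  # True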
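The LMD guidance loss reformatted earlier in this patch (add_ca_loss_per_attn_map_to_loss) rewards attention inside each object's box and penalizes it outside, using a per-head top-k mean. A toy sketch that mirrors the two loss terms; the map shape, token index, k values, and weights are all made up:

import torch

H = W = 8
heads, num_text_tokens = 2, 77

# Fake cross-attention map of shape (heads, H*W, text_tokens), normalized over the spatial dimension.
attn_map = torch.rand(heads, H * W, num_text_tokens).softmax(dim=1)

# Binary mask for one object's box, (x_min, y_min, x_max, y_max) = (2, 2, 6, 6) in latent coordinates.
mask = torch.zeros(H, W)
mask[2:6, 2:6] = 1
mask_1d = mask.reshape(1, -1)

obj_position = 5       # hypothetical token index of the phrase word
k_fg, k_bg = 4, 16     # top-k inside / outside the box
fg_weight, bg_weight = 1.0, 1.0

ca_map_obj = attn_map[:, :, obj_position]  # (heads, H*W)

# Encourage high attention inside the box ...
loss = (1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1)).sum(dim=0) * fg_weight
# ... and low attention outside the box.
loss = loss + ((ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1)).sum(dim=0) * bg_weight
print(loss)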