From e2fe0e2d55a03b844e2500373521800880f8d227 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 17:22:20 +0530 Subject: [PATCH 1/6] Add AuraFlowLoraLoaderMixin --- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 331 ++++++++++++++++++ .../transformers/auraflow_transformer_2d.py | 67 +++- .../pipelines/aura_flow/pipeline_aura_flow.py | 3 +- 4 files changed, 398 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 5db13825c9eb..20411b65ade2 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -64,6 +64,7 @@ def text_encoder_attn_modules(text_encoder): "AmusedLoraLoaderMixin", "StableDiffusionLoraLoaderMixin", "SD3LoraLoaderMixin", + "AuraFlowLoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "LoraLoaderMixin", ] @@ -85,6 +86,7 @@ def text_encoder_attn_modules(text_encoder): AmusedLoraLoaderMixin, LoraLoaderMixin, SD3LoraLoaderMixin, + AuraFlowLoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, ) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 73273618956a..2dedcb004623 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1475,6 +1475,337 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t super().unfuse_lora(components=components) +class AuraFlowLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`AuraFlowTransformer2DModel`] + Specific to [`AuraFlowPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) + + + @classmethod + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers): + raise ValueError( + "You must pass `transformer_lora_layers`." + ) + + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components) + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 89d51969aeaa..6fa890b854ce 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn as nn @@ -22,11 +22,13 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_processor import Attention, AuraFlowAttnProcessor2_0 +from ..attention_processor import Attention, AuraFlowAttnProcessor2_0, FusedJointAttnProcessor2_0 from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormZero, FP32LayerNorm +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -232,7 +234,7 @@ def forward( return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). @@ -320,6 +322,46 @@ def __init__( self.gradient_checkpointing = False + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedJointAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @@ -329,6 +371,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -341,7 +384,19 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): if self.training and self.gradient_checkpointing: @@ -416,6 +471,10 @@ def custom_forward(*inputs): shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) ) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 6a86b5ceded9..f50960dcd01f 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -24,6 +24,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...loaders import AuraFlowLoraLoaderMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -104,7 +105,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AuraFlowPipeline(DiffusionPipeline): +class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin): r""" Args: tokenizer (`T5TokenizerFast`): From a21bd03e63d1d7caa1a0cd7c00b351ada07e7b34 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:46:20 +0530 Subject: [PATCH 2/6] Add comments, remove qkv fusion --- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 7 ++- .../transformers/auraflow_transformer_2d.py | 57 +++---------------- .../pipelines/aura_flow/pipeline_aura_flow.py | 2 +- 4 files changed, 16 insertions(+), 52 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 20411b65ade2..9bb5f8fd0abc 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -84,9 +84,9 @@ def text_encoder_attn_modules(text_encoder): from .ip_adapter import IPAdapterMixin from .lora_pipeline import ( AmusedLoraLoaderMixin, + AuraFlowLoraLoaderMixin, LoraLoaderMixin, SD3LoraLoaderMixin, - AuraFlowLoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, ) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 2dedcb004623..07b524cb5c2f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1483,9 +1483,9 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer"] transformer_name = TRANSFORMER_NAME - text_encoder_name = TEXT_ENCODER_NAME @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict @validate_hf_hub_args def lora_state_dict( cls, @@ -1576,6 +1576,7 @@ def lora_state_dict( return state_dict + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -1622,6 +1623,7 @@ def load_lora_weights( @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1700,6 +1702,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # Unsafe code /> @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -1747,6 +1750,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -1790,6 +1794,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 6fa890b854ce..5838248b33fa 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -20,15 +20,14 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import Attention, AuraFlowAttnProcessor2_0, FusedJointAttnProcessor2_0 from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormZero, FP32LayerNorm -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -322,46 +321,6 @@ def __init__( self.gradient_checkpointing = False - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedJointAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @@ -371,7 +330,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -384,18 +343,18 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index f50960dcd01f..26399f00d815 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -18,13 +18,13 @@ from transformers import T5Tokenizer, UMT5EncoderModel from ...image_processor import VaeImageProcessor +from ...loaders import AuraFlowLoraLoaderMixin from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...loaders import AuraFlowLoraLoaderMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name From e8e9571993577a326ea0e412096f869c0f1ef131 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:50:50 +0530 Subject: [PATCH 3/6] Add Tests --- tests/lora/test_lora_layers_af.py | 90 +++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/lora/test_lora_layers_af.py diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py new file mode 100644 index 000000000000..2b050aa74da9 --- /dev/null +++ b/tests/lora/test_lora_layers_af.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest + +from diffusers import ( + FlowMatchEulerDiscreteScheduler, + AuraFlowPipeline, +) +from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device + + +if is_peft_available(): + pass + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = AuraFlowPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_kwargs = {} + transformer_kwargs = { + "sample_size": 64, + "patch_size": 2, + "in_channels": 4, + "num_mmdit_layers": 4, + "num_single_dit_layers": 32, + "attention_head_dim": 256, + "num_attention_heads": 12, + "joint_attention_dim": 2048, + "caption_projection_dim": 3072, + "out_channels": 4, + "pos_embed_max_size": 1024, + } + vae_kwargs = { + "sample_size": 1024, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "layers_per_block": 2, + "latent_channels": 4, + "norm_num_groups": 32, + "use_quant_conv": True, + "use_post_quant_conv": True, + "shift_factor": None, + "scaling_factor": 0.13025, + } + has_three_text_encoders = False + + @require_torch_gpu + def test_af_lora(self): + """ + Test loading the loras that are saved with the diffusers and peft formats. + Related PR: https://github.com/huggingface/diffusers/pull/8584 + """ + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + lora_model_id = "Warlord-K/gorkem-auraflow-lora" + + lora_filename = "pytorch_lora_weights.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.unload_lora_weights() + + lora_filename = "lora_peft_format.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) From 39efb8044a60c9f58e5675dc8a3b51f6661ea75e Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:53:49 +0530 Subject: [PATCH 4/6] Add AuraFlowLoraLoaderMixin to documentation --- docs/source/en/api/loaders/lora.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 2060a1eefd52..ed0d0f1819ad 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -17,6 +17,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model. - [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model. - [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3). +- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow). - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`]. - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. @@ -38,6 +39,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse [[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin +## AuraFlowLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin + ## AmusedLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin From fa32398691cc129b40254700b4b364b1ff8c3eb4 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Mon, 12 Aug 2024 01:44:56 +0530 Subject: [PATCH 5/6] Add Suggested changes --- src/diffusers/loaders/lora_pipeline.py | 3 +- src/diffusers/loaders/peft.py | 1 + .../transformers/auraflow_transformer_2d.py | 2 +- tests/lora/test_lora_layers_af.py | 65 +++++++------------ 4 files changed, 27 insertions(+), 44 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1216ab92fc8e..9f2ecc80eadc 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1958,7 +1958,6 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): transformer_name = TRANSFORMER_NAME @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict @validate_hf_hub_args def lora_state_dict( cls, @@ -2092,7 +2091,6 @@ def load_lora_weights( _pipeline=self, ) - @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): @@ -2263,6 +2261,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict with text_encoder removed from components def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 89d6a28b14dd..f31011151bc4 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -33,6 +33,7 @@ "UNetMotionModel": _maybe_expand_lora_scales, "SD3Transformer2DModel": lambda model_cls, weights: weights, "FluxTransformer2DModel": lambda model_cls, weights: weights, + "AuraFlowTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index dbf87f7ed6ee..02362fd6013a 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -238,7 +238,7 @@ def forward( return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 2b050aa74da9..9615249633cd 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -15,6 +15,8 @@ import sys import unittest +from transformers import AutoTokenizer, T5EncoderModel + from diffusers import ( FlowMatchEulerDiscreteScheduler, AuraFlowPipeline, @@ -31,60 +33,41 @@ @require_peft_backend -class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_kwargs = {} + uses_flow_matching = True transformer_kwargs = { "sample_size": 64, - "patch_size": 2, + "patch_size": 1, "in_channels": 4, - "num_mmdit_layers": 4, - "num_single_dit_layers": 32, - "attention_head_dim": 256, - "num_attention_heads": 12, - "joint_attention_dim": 2048, - "caption_projection_dim": 3072, + "num_mmdit_layers": 1, + "num_single_dit_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "caption_projection_dim": 32, "out_channels": 4, - "pos_embed_max_size": 1024, + "pos_embed_max_size": 32, } + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + vae_kwargs = { - "sample_size": 1024, + "sample_size": 32, "in_channels": 3, "out_channels": 3, - "block_out_channels": [ - 128, - 256, - 512, - 512 - ], - "layers_per_block": 2, + "block_out_channels": (4,), + "layers_per_block": 1, "latent_channels": 4, - "norm_num_groups": 32, - "use_quant_conv": True, - "use_post_quant_conv": True, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, "shift_factor": None, "scaling_factor": 0.13025, } - has_three_text_encoders = False - - @require_torch_gpu - def test_af_lora(self): - """ - Test loading the loras that are saved with the diffusers and peft formats. - Related PR: https://github.com/huggingface/diffusers/pull/8584 - """ - components = self.get_dummy_components() - - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - lora_model_id = "Warlord-K/gorkem-auraflow-lora" - - lora_filename = "pytorch_lora_weights.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.unload_lora_weights() - lora_filename = "lora_peft_format.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + @property + def output_shape(self): + return (1, 64, 64, 3) \ No newline at end of file From 5f56714e4156edb2dc7b57ade61ba92615b7671e Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Mon, 12 Aug 2024 11:06:37 +0530 Subject: [PATCH 6/6] Change attention_kwargs->joint_attention_kwargs --- src/diffusers/loaders/lora_pipeline.py | 2 +- .../models/transformers/auraflow_transformer_2d.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9f2ecc80eadc..d0c71cba557b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1959,6 +1959,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -2261,7 +2262,6 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict with text_encoder removed from components def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 02362fd6013a..e6fa5cc1dc25 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -435,7 +435,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, - attention_kwargs: Optional[Dict[str, Any]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -448,18 +448,18 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks):