@@ -3870,6 +3870,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
3870
3870
super ().unfuse_lora (components = components )
3871
3871
3872
3872
3873
+ class HunyuanVideoLoraLoaderMixin (LoraBaseMixin ):
3874
+ r"""
3875
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
3876
+ """
3877
+
3878
+ _lora_loadable_modules = ["transformer" ]
3879
+ transformer_name = TRANSFORMER_NAME
3880
+
3881
+ @classmethod
3882
+ @validate_hf_hub_args
3883
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3884
+ def lora_state_dict (
3885
+ cls ,
3886
+ pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]],
3887
+ ** kwargs ,
3888
+ ):
3889
+ r"""
3890
+ Return state dict for lora weights and the network alphas.
3891
+
3892
+ <Tip warning={true}>
3893
+
3894
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3895
+
3896
+ This function is experimental and might change in the future.
3897
+
3898
+ </Tip>
3899
+
3900
+ Parameters:
3901
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3902
+ Can be either:
3903
+
3904
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3905
+ the Hub.
3906
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3907
+ with [`ModelMixin.save_pretrained`].
3908
+ - A [torch state
3909
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3910
+
3911
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
3912
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3913
+ is not used.
3914
+ force_download (`bool`, *optional*, defaults to `False`):
3915
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3916
+ cached versions if they exist.
3917
+
3918
+ proxies (`Dict[str, str]`, *optional*):
3919
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3920
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3921
+ local_files_only (`bool`, *optional*, defaults to `False`):
3922
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
3923
+ won't be downloaded from the Hub.
3924
+ token (`str` or *bool*, *optional*):
3925
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3926
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
3927
+ revision (`str`, *optional*, defaults to `"main"`):
3928
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3929
+ allowed by Git.
3930
+ subfolder (`str`, *optional*, defaults to `""`):
3931
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
3932
+
3933
+ """
3934
+ # Load the main state dict first which has the LoRA layers for either of
3935
+ # transformer and text encoder or both.
3936
+ cache_dir = kwargs .pop ("cache_dir" , None )
3937
+ force_download = kwargs .pop ("force_download" , False )
3938
+ proxies = kwargs .pop ("proxies" , None )
3939
+ local_files_only = kwargs .pop ("local_files_only" , None )
3940
+ token = kwargs .pop ("token" , None )
3941
+ revision = kwargs .pop ("revision" , None )
3942
+ subfolder = kwargs .pop ("subfolder" , None )
3943
+ weight_name = kwargs .pop ("weight_name" , None )
3944
+ use_safetensors = kwargs .pop ("use_safetensors" , None )
3945
+
3946
+ allow_pickle = False
3947
+ if use_safetensors is None :
3948
+ use_safetensors = True
3949
+ allow_pickle = True
3950
+
3951
+ user_agent = {
3952
+ "file_type" : "attn_procs_weights" ,
3953
+ "framework" : "pytorch" ,
3954
+ }
3955
+
3956
+ state_dict = _fetch_state_dict (
3957
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
3958
+ weight_name = weight_name ,
3959
+ use_safetensors = use_safetensors ,
3960
+ local_files_only = local_files_only ,
3961
+ cache_dir = cache_dir ,
3962
+ force_download = force_download ,
3963
+ proxies = proxies ,
3964
+ token = token ,
3965
+ revision = revision ,
3966
+ subfolder = subfolder ,
3967
+ user_agent = user_agent ,
3968
+ allow_pickle = allow_pickle ,
3969
+ )
3970
+
3971
+ is_dora_scale_present = any ("dora_scale" in k for k in state_dict )
3972
+ if is_dora_scale_present :
3973
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3974
+ logger .warning (warn_msg )
3975
+ state_dict = {k : v for k , v in state_dict .items () if "dora_scale" not in k }
3976
+
3977
+ return state_dict
3978
+
3979
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3980
+ def load_lora_weights (
3981
+ self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], adapter_name = None , ** kwargs
3982
+ ):
3983
+ """
3984
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3985
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3986
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3987
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3988
+ dict is loaded into `self.transformer`.
3989
+
3990
+ Parameters:
3991
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3992
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3993
+ adapter_name (`str`, *optional*):
3994
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3995
+ `default_{i}` where i is the total number of adapters being loaded.
3996
+ low_cpu_mem_usage (`bool`, *optional*):
3997
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3998
+ weights.
3999
+ kwargs (`dict`, *optional*):
4000
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4001
+ """
4002
+ if not USE_PEFT_BACKEND :
4003
+ raise ValueError ("PEFT backend is required for this method." )
4004
+
4005
+ low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
4006
+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
4007
+ raise ValueError (
4008
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4009
+ )
4010
+
4011
+ # if a dict is passed, copy it instead of modifying it inplace
4012
+ if isinstance (pretrained_model_name_or_path_or_dict , dict ):
4013
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict .copy ()
4014
+
4015
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4016
+ state_dict = self .lora_state_dict (pretrained_model_name_or_path_or_dict , ** kwargs )
4017
+
4018
+ is_correct_format = all ("lora" in key for key in state_dict .keys ())
4019
+ if not is_correct_format :
4020
+ raise ValueError ("Invalid LoRA checkpoint." )
4021
+
4022
+ self .load_lora_into_transformer (
4023
+ state_dict ,
4024
+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
4025
+ adapter_name = adapter_name ,
4026
+ _pipeline = self ,
4027
+ low_cpu_mem_usage = low_cpu_mem_usage ,
4028
+ )
4029
+
4030
+ @classmethod
4031
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
4032
+ def load_lora_into_transformer (
4033
+ cls , state_dict , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
4034
+ ):
4035
+ """
4036
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
4037
+
4038
+ Parameters:
4039
+ state_dict (`dict`):
4040
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4041
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4042
+ encoder lora layers.
4043
+ transformer (`HunyuanVideoTransformer3DModel`):
4044
+ The Transformer model to load the LoRA layers into.
4045
+ adapter_name (`str`, *optional*):
4046
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4047
+ `default_{i}` where i is the total number of adapters being loaded.
4048
+ low_cpu_mem_usage (`bool`, *optional*):
4049
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4050
+ weights.
4051
+ """
4052
+ if low_cpu_mem_usage and is_peft_version ("<" , "0.13.0" ):
4053
+ raise ValueError (
4054
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4055
+ )
4056
+
4057
+ # Load the layers corresponding to transformer.
4058
+ logger .info (f"Loading { cls .transformer_name } ." )
4059
+ transformer .load_lora_adapter (
4060
+ state_dict ,
4061
+ network_alphas = None ,
4062
+ adapter_name = adapter_name ,
4063
+ _pipeline = _pipeline ,
4064
+ low_cpu_mem_usage = low_cpu_mem_usage ,
4065
+ )
4066
+
4067
+ @classmethod
4068
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4069
+ def save_lora_weights (
4070
+ cls ,
4071
+ save_directory : Union [str , os .PathLike ],
4072
+ transformer_lora_layers : Dict [str , Union [torch .nn .Module , torch .Tensor ]] = None ,
4073
+ is_main_process : bool = True ,
4074
+ weight_name : str = None ,
4075
+ save_function : Callable = None ,
4076
+ safe_serialization : bool = True ,
4077
+ ):
4078
+ r"""
4079
+ Save the LoRA parameters corresponding to the UNet and text encoder.
4080
+
4081
+ Arguments:
4082
+ save_directory (`str` or `os.PathLike`):
4083
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
4084
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4085
+ State dict of the LoRA layers corresponding to the `transformer`.
4086
+ is_main_process (`bool`, *optional*, defaults to `True`):
4087
+ Whether the process calling this is the main process or not. Useful during distributed training and you
4088
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4089
+ process to avoid race conditions.
4090
+ save_function (`Callable`):
4091
+ The function to use to save the state dictionary. Useful during distributed training when you need to
4092
+ replace `torch.save` with another method. Can be configured with the environment variable
4093
+ `DIFFUSERS_SAVE_MODE`.
4094
+ safe_serialization (`bool`, *optional*, defaults to `True`):
4095
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4096
+ """
4097
+ state_dict = {}
4098
+
4099
+ if not transformer_lora_layers :
4100
+ raise ValueError ("You must pass `transformer_lora_layers`." )
4101
+
4102
+ if transformer_lora_layers :
4103
+ state_dict .update (cls .pack_weights (transformer_lora_layers , cls .transformer_name ))
4104
+
4105
+ # Save the model
4106
+ cls .write_lora_layers (
4107
+ state_dict = state_dict ,
4108
+ save_directory = save_directory ,
4109
+ is_main_process = is_main_process ,
4110
+ weight_name = weight_name ,
4111
+ save_function = save_function ,
4112
+ safe_serialization = safe_serialization ,
4113
+ )
4114
+
4115
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
4116
+ def fuse_lora (
4117
+ self ,
4118
+ components : List [str ] = ["transformer" , "text_encoder" ],
4119
+ lora_scale : float = 1.0 ,
4120
+ safe_fusing : bool = False ,
4121
+ adapter_names : Optional [List [str ]] = None ,
4122
+ ** kwargs ,
4123
+ ):
4124
+ r"""
4125
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4126
+
4127
+ <Tip warning={true}>
4128
+
4129
+ This is an experimental API.
4130
+
4131
+ </Tip>
4132
+
4133
+ Args:
4134
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
4135
+ lora_scale (`float`, defaults to 1.0):
4136
+ Controls how much to influence the outputs with the LoRA parameters.
4137
+ safe_fusing (`bool`, defaults to `False`):
4138
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
4139
+ adapter_names (`List[str]`, *optional*):
4140
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
4141
+
4142
+ Example:
4143
+
4144
+ ```py
4145
+ from diffusers import DiffusionPipeline
4146
+ import torch
4147
+
4148
+ pipeline = DiffusionPipeline.from_pretrained(
4149
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
4150
+ ).to("cuda")
4151
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
4152
+ pipeline.fuse_lora(lora_scale=0.7)
4153
+ ```
4154
+ """
4155
+ super ().fuse_lora (
4156
+ components = components , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
4157
+ )
4158
+
4159
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
4160
+ def unfuse_lora (self , components : List [str ] = ["transformer" , "text_encoder" ], ** kwargs ):
4161
+ r"""
4162
+ Reverses the effect of
4163
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4164
+
4165
+ <Tip warning={true}>
4166
+
4167
+ This is an experimental API.
4168
+
4169
+ </Tip>
4170
+
4171
+ Args:
4172
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4173
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
4174
+ unfuse_text_encoder (`bool`, defaults to `True`):
4175
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
4176
+ LoRA parameters then it won't have any effect.
4177
+ """
4178
+ super ().unfuse_lora (components = components )
4179
+
4180
+
3873
4181
class LoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
3874
4182
def __init__ (self , * args , ** kwargs ):
3875
4183
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
0 commit comments