Skip to content

Commit 1826a1e

Browse files
[LoRA] Support HunyuanVideo (#10254)
* 1217 * 1217 * 1217 * update * reverse * add test * update test * make style * update * make style --------- Co-authored-by: Aryan <[email protected]>
1 parent 0ed09a1 commit 1826a1e

File tree

7 files changed

+600
-15
lines changed

7 files changed

+600
-15
lines changed

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def text_encoder_attn_modules(text_encoder):
7070
"FluxLoraLoaderMixin",
7171
"CogVideoXLoraLoaderMixin",
7272
"Mochi1LoraLoaderMixin",
73+
"HunyuanVideoLoraLoaderMixin",
7374
"SanaLoraLoaderMixin",
7475
]
7576
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
@@ -90,6 +91,7 @@ def text_encoder_attn_modules(text_encoder):
9091
AmusedLoraLoaderMixin,
9192
CogVideoXLoraLoaderMixin,
9293
FluxLoraLoaderMixin,
94+
HunyuanVideoLoraLoaderMixin,
9395
LoraLoaderMixin,
9496
LTXVideoLoraLoaderMixin,
9597
Mochi1LoraLoaderMixin,

src/diffusers/loaders/lora_pipeline.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3870,6 +3870,314 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
38703870
super().unfuse_lora(components=components)
38713871

38723872

3873+
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
3874+
r"""
3875+
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
3876+
"""
3877+
3878+
_lora_loadable_modules = ["transformer"]
3879+
transformer_name = TRANSFORMER_NAME
3880+
3881+
@classmethod
3882+
@validate_hf_hub_args
3883+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
3884+
def lora_state_dict(
3885+
cls,
3886+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
3887+
**kwargs,
3888+
):
3889+
r"""
3890+
Return state dict for lora weights and the network alphas.
3891+
3892+
<Tip warning={true}>
3893+
3894+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
3895+
3896+
This function is experimental and might change in the future.
3897+
3898+
</Tip>
3899+
3900+
Parameters:
3901+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3902+
Can be either:
3903+
3904+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
3905+
the Hub.
3906+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
3907+
with [`ModelMixin.save_pretrained`].
3908+
- A [torch state
3909+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
3910+
3911+
cache_dir (`Union[str, os.PathLike]`, *optional*):
3912+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
3913+
is not used.
3914+
force_download (`bool`, *optional*, defaults to `False`):
3915+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
3916+
cached versions if they exist.
3917+
3918+
proxies (`Dict[str, str]`, *optional*):
3919+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
3920+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
3921+
local_files_only (`bool`, *optional*, defaults to `False`):
3922+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
3923+
won't be downloaded from the Hub.
3924+
token (`str` or *bool*, *optional*):
3925+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
3926+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
3927+
revision (`str`, *optional*, defaults to `"main"`):
3928+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
3929+
allowed by Git.
3930+
subfolder (`str`, *optional*, defaults to `""`):
3931+
The subfolder location of a model file within a larger model repository on the Hub or locally.
3932+
3933+
"""
3934+
# Load the main state dict first which has the LoRA layers for either of
3935+
# transformer and text encoder or both.
3936+
cache_dir = kwargs.pop("cache_dir", None)
3937+
force_download = kwargs.pop("force_download", False)
3938+
proxies = kwargs.pop("proxies", None)
3939+
local_files_only = kwargs.pop("local_files_only", None)
3940+
token = kwargs.pop("token", None)
3941+
revision = kwargs.pop("revision", None)
3942+
subfolder = kwargs.pop("subfolder", None)
3943+
weight_name = kwargs.pop("weight_name", None)
3944+
use_safetensors = kwargs.pop("use_safetensors", None)
3945+
3946+
allow_pickle = False
3947+
if use_safetensors is None:
3948+
use_safetensors = True
3949+
allow_pickle = True
3950+
3951+
user_agent = {
3952+
"file_type": "attn_procs_weights",
3953+
"framework": "pytorch",
3954+
}
3955+
3956+
state_dict = _fetch_state_dict(
3957+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
3958+
weight_name=weight_name,
3959+
use_safetensors=use_safetensors,
3960+
local_files_only=local_files_only,
3961+
cache_dir=cache_dir,
3962+
force_download=force_download,
3963+
proxies=proxies,
3964+
token=token,
3965+
revision=revision,
3966+
subfolder=subfolder,
3967+
user_agent=user_agent,
3968+
allow_pickle=allow_pickle,
3969+
)
3970+
3971+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
3972+
if is_dora_scale_present:
3973+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
3974+
logger.warning(warn_msg)
3975+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
3976+
3977+
return state_dict
3978+
3979+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
3980+
def load_lora_weights(
3981+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
3982+
):
3983+
"""
3984+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
3985+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
3986+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
3987+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
3988+
dict is loaded into `self.transformer`.
3989+
3990+
Parameters:
3991+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
3992+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
3993+
adapter_name (`str`, *optional*):
3994+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
3995+
`default_{i}` where i is the total number of adapters being loaded.
3996+
low_cpu_mem_usage (`bool`, *optional*):
3997+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
3998+
weights.
3999+
kwargs (`dict`, *optional*):
4000+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4001+
"""
4002+
if not USE_PEFT_BACKEND:
4003+
raise ValueError("PEFT backend is required for this method.")
4004+
4005+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
4006+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4007+
raise ValueError(
4008+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4009+
)
4010+
4011+
# if a dict is passed, copy it instead of modifying it inplace
4012+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
4013+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4014+
4015+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4016+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4017+
4018+
is_correct_format = all("lora" in key for key in state_dict.keys())
4019+
if not is_correct_format:
4020+
raise ValueError("Invalid LoRA checkpoint.")
4021+
4022+
self.load_lora_into_transformer(
4023+
state_dict,
4024+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4025+
adapter_name=adapter_name,
4026+
_pipeline=self,
4027+
low_cpu_mem_usage=low_cpu_mem_usage,
4028+
)
4029+
4030+
@classmethod
4031+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
4032+
def load_lora_into_transformer(
4033+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
4034+
):
4035+
"""
4036+
This will load the LoRA layers specified in `state_dict` into `transformer`.
4037+
4038+
Parameters:
4039+
state_dict (`dict`):
4040+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4041+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4042+
encoder lora layers.
4043+
transformer (`HunyuanVideoTransformer3DModel`):
4044+
The Transformer model to load the LoRA layers into.
4045+
adapter_name (`str`, *optional*):
4046+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4047+
`default_{i}` where i is the total number of adapters being loaded.
4048+
low_cpu_mem_usage (`bool`, *optional*):
4049+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4050+
weights.
4051+
"""
4052+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4053+
raise ValueError(
4054+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4055+
)
4056+
4057+
# Load the layers corresponding to transformer.
4058+
logger.info(f"Loading {cls.transformer_name}.")
4059+
transformer.load_lora_adapter(
4060+
state_dict,
4061+
network_alphas=None,
4062+
adapter_name=adapter_name,
4063+
_pipeline=_pipeline,
4064+
low_cpu_mem_usage=low_cpu_mem_usage,
4065+
)
4066+
4067+
@classmethod
4068+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4069+
def save_lora_weights(
4070+
cls,
4071+
save_directory: Union[str, os.PathLike],
4072+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
4073+
is_main_process: bool = True,
4074+
weight_name: str = None,
4075+
save_function: Callable = None,
4076+
safe_serialization: bool = True,
4077+
):
4078+
r"""
4079+
Save the LoRA parameters corresponding to the UNet and text encoder.
4080+
4081+
Arguments:
4082+
save_directory (`str` or `os.PathLike`):
4083+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
4084+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4085+
State dict of the LoRA layers corresponding to the `transformer`.
4086+
is_main_process (`bool`, *optional*, defaults to `True`):
4087+
Whether the process calling this is the main process or not. Useful during distributed training and you
4088+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4089+
process to avoid race conditions.
4090+
save_function (`Callable`):
4091+
The function to use to save the state dictionary. Useful during distributed training when you need to
4092+
replace `torch.save` with another method. Can be configured with the environment variable
4093+
`DIFFUSERS_SAVE_MODE`.
4094+
safe_serialization (`bool`, *optional*, defaults to `True`):
4095+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4096+
"""
4097+
state_dict = {}
4098+
4099+
if not transformer_lora_layers:
4100+
raise ValueError("You must pass `transformer_lora_layers`.")
4101+
4102+
if transformer_lora_layers:
4103+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4104+
4105+
# Save the model
4106+
cls.write_lora_layers(
4107+
state_dict=state_dict,
4108+
save_directory=save_directory,
4109+
is_main_process=is_main_process,
4110+
weight_name=weight_name,
4111+
save_function=save_function,
4112+
safe_serialization=safe_serialization,
4113+
)
4114+
4115+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
4116+
def fuse_lora(
4117+
self,
4118+
components: List[str] = ["transformer", "text_encoder"],
4119+
lora_scale: float = 1.0,
4120+
safe_fusing: bool = False,
4121+
adapter_names: Optional[List[str]] = None,
4122+
**kwargs,
4123+
):
4124+
r"""
4125+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4126+
4127+
<Tip warning={true}>
4128+
4129+
This is an experimental API.
4130+
4131+
</Tip>
4132+
4133+
Args:
4134+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
4135+
lora_scale (`float`, defaults to 1.0):
4136+
Controls how much to influence the outputs with the LoRA parameters.
4137+
safe_fusing (`bool`, defaults to `False`):
4138+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
4139+
adapter_names (`List[str]`, *optional*):
4140+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
4141+
4142+
Example:
4143+
4144+
```py
4145+
from diffusers import DiffusionPipeline
4146+
import torch
4147+
4148+
pipeline = DiffusionPipeline.from_pretrained(
4149+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
4150+
).to("cuda")
4151+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
4152+
pipeline.fuse_lora(lora_scale=0.7)
4153+
```
4154+
"""
4155+
super().fuse_lora(
4156+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4157+
)
4158+
4159+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
4160+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
4161+
r"""
4162+
Reverses the effect of
4163+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4164+
4165+
<Tip warning={true}>
4166+
4167+
This is an experimental API.
4168+
4169+
</Tip>
4170+
4171+
Args:
4172+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4173+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
4174+
unfuse_text_encoder (`bool`, defaults to `True`):
4175+
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
4176+
LoRA parameters then it won't have any effect.
4177+
"""
4178+
super().unfuse_lora(components=components)
4179+
4180+
38734181
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
38744182
def __init__(self, *args, **kwargs):
38754183
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"FluxTransformer2DModel": lambda model_cls, weights: weights,
5454
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
5555
"MochiTransformer3DModel": lambda model_cls, weights: weights,
56+
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
5657
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
5758
"SanaTransformer2DModel": lambda model_cls, weights: weights,
5859
}

0 commit comments

Comments
 (0)