@@ -1703,7 +1703,8 @@ def lora_state_dict(
The subfolder location of a model file within a larger model repository on the Hub or locally.

"""
- # Load the main state dict first which has the LoRA layers for transformer
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
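For orientation, a minimal sketch of the prefix convention the new comment refers to: the loaded state dict can mix transformer and text encoder entries, keyed by component prefix. The keys and tensor shapes below are made up for illustration; only the `transformer.` / `text_encoder.` prefixes are taken from this file.

    import torch

    # Made-up keys; real checkpoints use the same prefixes but longer module paths.
    state_dict = {
        "transformer.blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 64),
        "transformer.blocks.0.attn.to_q.lora_B.weight": torch.zeros(64, 4),
        "text_encoder.layers.0.q_proj.lora_A.weight": torch.zeros(4, 32),
    }

    # Route each entry to the component it belongs to.
    transformer_sd = {k.removeprefix("transformer."): v for k, v in state_dict.items() if k.startswith("transformer.")}
    text_encoder_sd = {k.removeprefix("text_encoder."): v for k, v in state_dict.items() if k.startswith("text_encoder.")}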
@@ -1724,7 +1725,7 @@ def lora_state_dict(
"framework": "pytorch",
}

- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1739,6 +1740,12 @@ def lora_state_dict(
allow_pickle=allow_pickle,
)

+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
return state_dict

def load_lora_weights(
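As a standalone sketch of what the `dora_scale` handling added above does, run against a toy state dict (the keys are invented and the tensor values are elided with `...`):

    import logging

    logger = logging.getLogger(__name__)

    # Toy LoRA state dict containing one DoRA-only entry.
    state_dict = {
        "transformer.blocks.0.attn.to_q.lora_A.weight": ...,
        "transformer.blocks.0.attn.to_q.lora_B.weight": ...,
        "transformer.blocks.0.attn.to_q.dora_scale": ...,
    }

    # Mirror of the check-and-filter logic in the hunk above.
    if any("dora_scale" in k for k in state_dict):
        logger.warning("DoRA checkpoints are not supported here; dropping 'dora_scale' keys.")
        state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

    assert all("dora_scale" not in k for k in state_dict)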
@@ -1787,7 +1794,9 @@ def load_lora_weights(

@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
- def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1801,68 +1810,24 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
- keys = list(state_dict.keys())
-
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if len(state_dict.keys()) > 0:
- # check with first key if is not in peft format
- first_key = next(iter(state_dict.keys()))
- if "lora_A" not in first_key:
- state_dict = convert_unet_state_dict_to_peft(state_dict)
-
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
-
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
- if incompatible_keys is not None:
- # check only for unexpected keys
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- logger.warning(
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
- f" {unexpected_keys}. "
- )
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )

- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )

@classmethod
def save_lora_weights(
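A hedged end-to-end usage sketch of the paths this commit touches. The repository ids, weight file name, and adapter name are placeholders, the pipeline is assumed to carry the loader mixin shown in this diff, and `low_cpu_mem_usage=True` assumes `peft>=0.13.0` as enforced above.

    import torch
    from diffusers import DiffusionPipeline

    # Placeholder checkpoint id; any pipeline backed by this LoRA loader mixin behaves the same way.
    pipe = DiffusionPipeline.from_pretrained("org/base-model", torch_dtype=torch.float16)

    # Inspect the LoRA weights first; any 'dora_scale' entries are dropped with a warning.
    sd = pipe.lora_state_dict("org/some-lora", weight_name="pytorch_lora_weights.safetensors")
    print(f"{len(sd)} LoRA tensors loaded")

    # Load them into the transformer; low_cpu_mem_usage skips re-initializing the
    # injected LoRA layers before the pretrained weights are copied in.
    pipe.load_lora_weights(
        "org/some-lora",
        weight_name="pytorch_lora_weights.safetensors",
        adapter_name="demo_lora",
        low_cpu_mem_usage=True,
    )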