@@ -1,6 +1,6 @@
 import argparse
 import pathlib
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import torch
 from accelerate import init_empty_weights
@@ -14,6 +14,8 @@
     WanImageToVideoPipeline,
     WanPipeline,
     WanTransformer3DModel,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
 )
 
 
5961 "attn2.norm_k_img" : "attn2.norm_added_k" ,
6062}
6163
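+# Maps parameter names in the original Wan-VACE checkpoints to their diffusers
+# equivalents; convert_transformer() applies these as plain substring replacements.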
+VACE_TRANSFORMER_KEYS_RENAME_DICT = {
+    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+    "time_projection.1": "condition_embedder.time_proj",
+    "head.modulation": "scale_shift_table",
+    "head.head": "proj_out",
+    "modulation": "scale_shift_table",
+    "ffn.0": "ffn.net.0.proj",
+    "ffn.2": "ffn.net.2",
+    # Hack to swap the layer names
+    # The original model calls the norms in the following order: norm1, norm3, norm2
+    # We convert it to: norm1, norm2, norm3
+    "norm2": "norm__placeholder",
+    "norm3": "norm2",
+    "norm__placeholder": "norm3",
+    # # For the I2V model
+    # "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+    # "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+    # "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+    # "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # # for the FLF2V model
+    # "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
+    "before_proj": "proj_in",
+    "after_proj": "proj_out",
+}
+
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
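+# Empty for now: no VACE checkpoint keys need custom in-place handlers.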
+VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
 
 
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -74,7 +121,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
     return state_dict
 
 
-def get_transformer_config(model_type: str) -> Dict[str, Any]:
+def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
     if model_type == "Wan-T2V-1.3B":
         config = {
             "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
@@ -94,6 +141,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
94141 "text_dim" : 4096 ,
95142 },
96143 }
144+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
145+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
97146 elif model_type == "Wan-T2V-14B" :
98147 config = {
99148 "model_id" : "StevenZhang/Wan2.1-T2V-14B-Diff" ,
@@ -113,6 +162,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
113162 "text_dim" : 4096 ,
114163 },
115164 }
165+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
166+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
116167 elif model_type == "Wan-I2V-14B-480p" :
117168 config = {
118169 "model_id" : "StevenZhang/Wan2.1-I2V-14B-480P-Diff" ,
@@ -133,6 +184,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
133184 "text_dim" : 4096 ,
134185 },
135186 }
187+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
188+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
136189 elif model_type == "Wan-I2V-14B-720p" :
137190 config = {
138191 "model_id" : "StevenZhang/Wan2.1-I2V-14B-720P-Diff" ,
@@ -153,6 +206,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
153206 "text_dim" : 4096 ,
154207 },
155208 }
209+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
210+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
156211 elif model_type == "Wan-FLF2V-14B-720P" :
157212 config = {
158213 "model_id" : "ypyp/Wan2.1-FLF2V-14B-720P" , # This is just a placeholder
@@ -175,28 +230,80 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
175230 "pos_embed_seq_len" : 257 * 2 ,
176231 },
177232 }
178- return config
233+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
234+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
235+ elif model_type == "Wan-VACE-1.3B" :
236+ config = {
237+ "model_id" : "Wan-AI/Wan2.1-VACE-1.3B" ,
238+ "diffusers_config" : {
239+ "added_kv_proj_dim" : None ,
240+ "attention_head_dim" : 128 ,
241+ "cross_attn_norm" : True ,
242+ "eps" : 1e-06 ,
243+ "ffn_dim" : 8960 ,
244+ "freq_dim" : 256 ,
245+ "in_channels" : 16 ,
246+ "num_attention_heads" : 12 ,
247+ "num_layers" : 30 ,
248+ "out_channels" : 16 ,
249+ "patch_size" : [1 , 2 , 2 ],
250+ "qk_norm" : "rms_norm_across_heads" ,
251+ "text_dim" : 4096 ,
252+ "vace_layers" : [0 , 2 , 4 , 6 , 8 , 10 , 12 , 14 , 16 , 18 , 20 , 22 , 24 , 26 , 28 ],
253+ "vace_in_channels" : 96 ,
254+ },
255+ }
256+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
257+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
258+ elif model_type == "Wan-VACE-14B" :
259+ config = {
260+ "model_id" : "Wan-AI/Wan2.1-VACE-14B" ,
261+ "diffusers_config" : {
262+ "added_kv_proj_dim" : None ,
263+ "attention_head_dim" : 128 ,
264+ "cross_attn_norm" : True ,
265+ "eps" : 1e-06 ,
266+ "ffn_dim" : 13824 ,
267+ "freq_dim" : 256 ,
268+ "in_channels" : 16 ,
269+ "num_attention_heads" : 40 ,
270+ "num_layers" : 40 ,
271+ "out_channels" : 16 ,
272+ "patch_size" : [1 , 2 , 2 ],
273+ "qk_norm" : "rms_norm_across_heads" ,
274+ "text_dim" : 4096 ,
275+ "vace_layers" : [0 , 5 , 10 , 15 , 20 , 25 , 30 , 35 ],
276+ "vace_in_channels" : 96 ,
277+ },
278+ }
279+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
280+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
281+ return config , RENAME_DICT , SPECIAL_KEYS_REMAP
 
 
 def convert_transformer(model_type: str):
-    config = get_transformer_config(model_type)
+    config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
+
     diffusers_config = config["diffusers_config"]
     model_id = config["model_id"]
     model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
 
     original_state_dict = load_sharded_safetensors(model_dir)
 
     with init_empty_weights():
-        transformer = WanTransformer3DModel.from_config(diffusers_config)
+        if "VACE" not in model_type:
+            transformer = WanTransformer3DModel.from_config(diffusers_config)
+        else:
+            transformer = WanVACETransformer3DModel.from_config(diffusers_config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
-        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+        for replace_key, rename_key in RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_(original_state_dict, key, new_key)
 
     for key in list(original_state_dict.keys()):
-        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+        for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
             if special_key not in key:
                 continue
             handler_fn_inplace(key, original_state_dict)
@@ -412,7 +519,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_type", type=str, default=None)
     parser.add_argument("--output_path", type=str, required=True)
-    parser.add_argument("--dtype", default="fp32")
+    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
     return parser.parse_args()
 
 
@@ -426,18 +533,20 @@ def get_args():
 if __name__ == "__main__":
     args = get_args()
 
-    transformer = None
-    dtype = DTYPE_MAPPING[args.dtype]
-
-    transformer = convert_transformer(args.model_type).to(dtype=dtype)
+    transformer = convert_transformer(args.model_type)
     vae = convert_vae()
-    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
     flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
     scheduler = UniPCMultistepScheduler(
         prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
     )
 
+    # If the user has specified "none", keep the original dtypes of the state dict without any conversion
+    if args.dtype != "none":
+        dtype = DTYPE_MAPPING[args.dtype]
+        transformer.to(dtype)
+
     if "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
@@ -452,6 +561,14 @@ def get_args():
             image_encoder=image_encoder,
             image_processor=image_processor,
         )
+    elif "VACE" in args.model_type:
+        pipe = WanVACEPipeline(
+            transformer=transformer,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+        )
     else:
         pipe = WanPipeline(
             transformer=transformer,
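
For reference, a minimal sketch of how the new VACE path might be exercised end to end. The output directory is hypothetical, and it assumes the script lives at scripts/convert_wan_to_diffusers.py and saves the assembled pipeline to --output_path the way the other model-type branches do:

    # Hypothetical conversion run (model types match the new branches above):
    #   python scripts/convert_wan_to_diffusers.py \
    #       --model_type Wan-VACE-1.3B --output_path ./Wan2.1-VACE-1.3B-Diffusers --dtype none
    import torch
    from diffusers import WanVACEPipeline

    # Load the converted checkpoint like any other diffusers pipeline.
    pipe = WanVACEPipeline.from_pretrained("./Wan2.1-VACE-1.3B-Diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")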