Commit 834bfc6

add conversion script
1 parent 32ab1c9 commit 834bfc6

File tree

1 file changed: +130 -13

scripts/convert_wan_to_diffusers.py

Lines changed: 130 additions & 13 deletions
@@ -1,6 +1,6 @@
 import argparse
 import pathlib
-from typing import Any, Dict
+from typing import Any, Dict, Tuple

 import torch
 from accelerate import init_empty_weights
@@ -14,6 +14,8 @@
     WanImageToVideoPipeline,
     WanPipeline,
     WanTransformer3DModel,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
 )


@@ -59,7 +61,52 @@
     "attn2.norm_k_img": "attn2.norm_added_k",
 }

+VACE_TRANSFORMER_KEYS_RENAME_DICT = {
+    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+    "time_projection.1": "condition_embedder.time_proj",
+    "head.modulation": "scale_shift_table",
+    "head.head": "proj_out",
+    "modulation": "scale_shift_table",
+    "ffn.0": "ffn.net.0.proj",
+    "ffn.2": "ffn.net.2",
+    # Hack to swap the layer names
+    # The original model calls the norms in the following order: norm1, norm3, norm2
+    # We convert it to: norm1, norm2, norm3
+    "norm2": "norm__placeholder",
+    "norm3": "norm2",
+    "norm__placeholder": "norm3",
+    # # For the I2V model
+    # "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+    # "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+    # "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+    # "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # # For the FLF2V model
+    # "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+    # Attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
+    "before_proj": "proj_in",
+    "after_proj": "proj_out",
+}
+
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}


 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
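
Note: these rename tables are applied by plain substring replacement over every key of the original state dict, in dictionary insertion order. That ordering is what makes the norm__placeholder indirection work: "norm2" is first parked under the placeholder, freeing "norm3" to become "norm2" before the placeholder is finally renamed to "norm3". A minimal, self-contained sketch of the mechanism (toy keys, not tied to the real checkpoint):

# Replacements run in dict insertion order, mirroring the rename loop in
# convert_transformer; the placeholder keeps norm2/norm3 from clobbering
# each other during the swap.
RENAMES = {
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
}

def rename_key(key: str) -> str:
    for old, new in RENAMES.items():
        key = key.replace(old, new)
    return key

assert rename_key("blocks.0.norm3.weight") == "blocks.0.norm2.weight"
assert rename_key("blocks.0.norm2.weight") == "blocks.0.norm3.weight"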
@@ -74,7 +121,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
     return state_dict


-def get_transformer_config(model_type: str) -> Dict[str, Any]:
+def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
     if model_type == "Wan-T2V-1.3B":
         config = {
             "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
@@ -94,6 +141,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-T2V-14B":
         config = {
             "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
@@ -113,6 +162,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-I2V-14B-480p":
         config = {
             "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
@@ -133,6 +184,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-I2V-14B-720p":
         config = {
             "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
@@ -153,6 +206,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-FLF2V-14B-720P":
         config = {
             "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
@@ -175,28 +230,80 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "pos_embed_seq_len": 257 * 2,
             },
         }
-    return config
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan-VACE-1.3B":
+        config = {
+            "model_id": "Wan-AI/Wan2.1-VACE-1.3B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 12,
+                "num_layers": 30,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
+                "vace_in_channels": 96,
+            },
+        }
+        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan-VACE-14B":
+        config = {
+            "model_id": "Wan-AI/Wan2.1-VACE-14B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+                "vace_in_channels": 96,
+            },
+        }
+        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    return config, RENAME_DICT, SPECIAL_KEYS_REMAP


 def convert_transformer(model_type: str):
-    config = get_transformer_config(model_type)
+    config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
+
     diffusers_config = config["diffusers_config"]
     model_id = config["model_id"]
     model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))

     original_state_dict = load_sharded_safetensors(model_dir)

     with init_empty_weights():
-        transformer = WanTransformer3DModel.from_config(diffusers_config)
+        if "VACE" not in model_type:
+            transformer = WanTransformer3DModel.from_config(diffusers_config)
+        else:
+            transformer = WanVACETransformer3DModel.from_config(diffusers_config)

     for key in list(original_state_dict.keys()):
         new_key = key[:]
-        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+        for replace_key, rename_key in RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_(original_state_dict, key, new_key)

     for key in list(original_state_dict.keys()):
-        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+        for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
             if special_key not in key:
                 continue
             handler_fn_inplace(key, original_state_dict)
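
Note: init_empty_weights from accelerate builds the module on PyTorch's meta device, so the freshly constructed transformer allocates no real parameter storage; the converted state dict supplies the actual tensors afterwards. A rough illustration with a toy module (not the Wan transformer; the assign=True load path assumes PyTorch >= 2.1):

import torch
from accelerate import init_empty_weights

# Parameters get shapes and dtypes, but no storage is allocated.
with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)  # meta

# Real tensors are attached later; assign=True swaps them in directly
# instead of copying into (non-existent) meta storage.
state = {"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)}
layer.load_state_dict(state, assign=True)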
@@ -412,7 +519,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_type", type=str, default=None)
    parser.add_argument("--output_path", type=str, required=True)
-    parser.add_argument("--dtype", default="fp32")
+    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
     return parser.parse_args()


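Note: with the new flag values, an invocation might look as follows (the output path is hypothetical); passing --dtype none now skips the cast entirely and keeps whatever dtypes the original shards use:

python scripts/convert_wan_to_diffusers.py \
    --model_type Wan-VACE-1.3B \
    --output_path ./Wan2.1-VACE-1.3B-Diffusers \
    --dtype none
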
@@ -426,18 +533,20 @@ def get_args():
 if __name__ == "__main__":
     args = get_args()

-    transformer = None
-    dtype = DTYPE_MAPPING[args.dtype]
-
-    transformer = convert_transformer(args.model_type).to(dtype=dtype)
+    transformer = convert_transformer(args.model_type)
     vae = convert_vae()
-    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
     flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
     scheduler = UniPCMultistepScheduler(
         prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
     )

+    # If the user has specified "none", we keep the original dtypes of the state dict without any conversion
+    if args.dtype != "none":
+        dtype = DTYPE_MAPPING[args.dtype]
+        transformer.to(dtype)
+
     if "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
@@ -452,6 +561,14 @@ def get_args():
             image_encoder=image_encoder,
             image_processor=image_processor,
         )
+    elif "VACE" in args.model_type:
+        pipe = WanVACEPipeline(
+            transformer=transformer,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+        )
     else:
         pipe = WanPipeline(
             transformer=transformer,
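
Note: assuming the script saves the assembled pipeline with save_pretrained further down (not shown in this hunk), the converted VACE checkpoint should load like any other diffusers pipeline. A hedged sketch, with a hypothetical local path:

import torch
from diffusers import WanVACEPipeline

# Load the folder written to --output_path by this script (path is hypothetical).
pipe = WanVACEPipeline.from_pretrained("./Wan2.1-VACE-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")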
