# integration test (hunyuan dit)
#
# Loads the HunyuanDiT transformer from the original "XCLiu/HunyuanDiT-0523"
# checkpoint, remaps the original attention/norm state-dict keys to the
# diffusers naming scheme, and runs the pipeline end to end on a sample prompt.
import torch


def convert_state_dict(state_dict, num_layers=40):
    """Remap original HunyuanDiT transformer keys to diffusers names, in place.

    For each of the ``num_layers`` transformer blocks:
      * attn1: fused ``Wqkv`` weight/bias are chunked (dim 0) into separate
        ``to_q`` / ``to_k`` / ``to_v``; ``q_norm``/``k_norm`` become
        ``norm_q``/``norm_k``; ``out_proj`` becomes ``to_out.0``.
      * attn2 (cross-attention): fused ``kv_proj`` is chunked (dim 0) into
        ``to_k`` / ``to_v``; ``q_proj`` becomes ``to_q``; norms and out
        projection are renamed as in attn1.
      * ``norm2`` and ``norm3`` weights/biases are swapped (the two layers
        trade places in the diffusers block layout).

    Args:
        state_dict: mapping of checkpoint key -> tensor. Mutated in place;
            original keys are removed as they are converted.
        num_layers: number of transformer blocks to convert (40 for the
            HunyuanDiT-0523 checkpoint).

    Returns:
        The same ``state_dict`` object, for call-chaining convenience.
    """
    for i in range(num_layers):
        prefix = f"blocks.{i}"

        # attn1: Wqkv -> to_q, to_k, to_v (fused projection split into thirds)
        q, k, v = torch.chunk(state_dict.pop(f"{prefix}.attn1.Wqkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"{prefix}.attn1.Wqkv.bias"), 3, dim=0)
        state_dict[f"{prefix}.attn1.to_q.weight"] = q
        state_dict[f"{prefix}.attn1.to_q.bias"] = q_bias
        state_dict[f"{prefix}.attn1.to_k.weight"] = k
        state_dict[f"{prefix}.attn1.to_k.bias"] = k_bias
        state_dict[f"{prefix}.attn1.to_v.weight"] = v
        state_dict[f"{prefix}.attn1.to_v.bias"] = v_bias

        # attn1: q_norm, k_norm -> norm_q, norm_k
        state_dict[f"{prefix}.attn1.norm_q.weight"] = state_dict.pop(f"{prefix}.attn1.q_norm.weight")
        state_dict[f"{prefix}.attn1.norm_q.bias"] = state_dict.pop(f"{prefix}.attn1.q_norm.bias")
        state_dict[f"{prefix}.attn1.norm_k.weight"] = state_dict.pop(f"{prefix}.attn1.k_norm.weight")
        state_dict[f"{prefix}.attn1.norm_k.bias"] = state_dict.pop(f"{prefix}.attn1.k_norm.bias")

        # attn1: out_proj -> to_out.0
        state_dict[f"{prefix}.attn1.to_out.0.weight"] = state_dict.pop(f"{prefix}.attn1.out_proj.weight")
        state_dict[f"{prefix}.attn1.to_out.0.bias"] = state_dict.pop(f"{prefix}.attn1.out_proj.bias")

        # attn2: kv_proj -> to_k, to_v (fused projection split into halves)
        k, v = torch.chunk(state_dict.pop(f"{prefix}.attn2.kv_proj.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"{prefix}.attn2.kv_proj.bias"), 2, dim=0)
        state_dict[f"{prefix}.attn2.to_k.weight"] = k
        state_dict[f"{prefix}.attn2.to_k.bias"] = k_bias
        state_dict[f"{prefix}.attn2.to_v.weight"] = v
        state_dict[f"{prefix}.attn2.to_v.bias"] = v_bias

        # attn2: q_proj -> to_q
        state_dict[f"{prefix}.attn2.to_q.weight"] = state_dict.pop(f"{prefix}.attn2.q_proj.weight")
        state_dict[f"{prefix}.attn2.to_q.bias"] = state_dict.pop(f"{prefix}.attn2.q_proj.bias")

        # attn2: q_norm, k_norm -> norm_q, norm_k
        state_dict[f"{prefix}.attn2.norm_q.weight"] = state_dict.pop(f"{prefix}.attn2.q_norm.weight")
        state_dict[f"{prefix}.attn2.norm_q.bias"] = state_dict.pop(f"{prefix}.attn2.q_norm.bias")
        state_dict[f"{prefix}.attn2.norm_k.weight"] = state_dict.pop(f"{prefix}.attn2.k_norm.weight")
        state_dict[f"{prefix}.attn2.norm_k.bias"] = state_dict.pop(f"{prefix}.attn2.k_norm.bias")

        # attn2: out_proj -> to_out.0
        state_dict[f"{prefix}.attn2.to_out.0.weight"] = state_dict.pop(f"{prefix}.attn2.out_proj.weight")
        state_dict[f"{prefix}.attn2.to_out.0.bias"] = state_dict.pop(f"{prefix}.attn2.out_proj.bias")

        # swap norm2 and norm3
        state_dict[f"{prefix}.norm2.weight"], state_dict[f"{prefix}.norm3.weight"] = (
            state_dict[f"{prefix}.norm3.weight"],
            state_dict[f"{prefix}.norm2.weight"],
        )
        state_dict[f"{prefix}.norm2.bias"], state_dict[f"{prefix}.norm3.bias"] = (
            state_dict[f"{prefix}.norm3.bias"],
            state_dict[f"{prefix}.norm2.bias"],
        )

    return state_dict


def main():
    """Download the checkpoint, convert it, and generate one image to img.png."""
    # Third-party hub/pipeline imports are kept local so convert_state_dict()
    # can be imported and unit-tested without these packages installed.
    from huggingface_hub import hf_hub_download
    from diffusers import HunyuanDiT2DModel, HunyuanDiTPipeline
    import safetensors.torch

    device = "cuda"
    repo_id = "XCLiu/HunyuanDiT-0523"

    # Build the transformer from config only, then load converted weights.
    model_config = HunyuanDiT2DModel.load_config(repo_id, subfolder="transformer")
    model = HunyuanDiT2DModel.from_config(model_config).to(device)

    ckpt_path = hf_hub_download(
        repo_id,
        filename="diffusion_pytorch_model.safetensors",
        subfolder="transformer",
    )
    state_dict = safetensors.torch.load_file(ckpt_path)
    convert_state_dict(state_dict, num_layers=40)
    model.load_state_dict(state_dict)

    pipe = HunyuanDiTPipeline.from_pretrained(repo_id, transformer=model, torch_dtype=torch.float32)
    pipe.to(device)

    ### NOTE: HunyuanDiT supports both Chinese and English inputs
    prompt = "一个宇航员在骑马"
    # prompt = "An astronaut riding a horse"
    generator = torch.Generator(device="cuda").manual_seed(0)
    image = pipe(height=1024, width=1024, prompt=prompt, generator=generator).images[0]

    image.save("img.png")


if __name__ == "__main__":
    main()