Skip to content

Commit 978281d

Browse files
author
XCLiu
committed
clean up test
1 parent 2e7bc12 commit 978281d

16 files changed

+102
-108
lines changed

examples/community/README.md

100644100755
File mode changed.

examples/community/dps_pipeline.py

100644100755
File mode changed.

examples/community/latent_consistency_txt2img.py

100644100755
File mode changed.

examples/community/one_step_unet.py

100644100755
File mode changed.

examples/community/sd_text2img_k_diffusion.py

100644100755
File mode changed.

examples/community/stable_diffusion_tensorrt_img2img.py

100644100755
File mode changed.

examples/community/stable_diffusion_tensorrt_inpaint.py

100644100755
File mode changed.

examples/community/stable_diffusion_tensorrt_txt2img.py

100644100755
File mode changed.

scripts/convert_dance_diffusion_to_diffusers.py

100644100755
File mode changed.

src/diffusers/models/attention.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def __init__(
192192
super().__init__()
193193

194194
# Define 3 blocks. Each block has its own normalization layer.
195+
# NOTE: when new version comes, check norm2 and norm3
195196
# 1. Self-Attn
196197
self.norm1 = FP32_Layernorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
197198

@@ -208,7 +209,7 @@ def __init__(
208209
)
209210

210211
# 2. Cross-Attn
211-
self.norm3 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)
212+
self.norm2 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)
212213

213214
self.attn2 = Attention(
214215
query_dim=dim,
@@ -221,9 +222,7 @@ def __init__(
221222
processor= HunyuanAttnProcessor2_0(),
222223
)
223224
# 3. Feed-forward
224-
self.norm2 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)
225-
226-
### TODO: switch norm2 and norm3 in the state dict
225+
self.norm3 = FP32_Layernorm(dim, norm_eps, norm_elementwise_affine)
227226

228227
self.ff = FeedForward(
229228
dim,
@@ -283,13 +282,13 @@ def forward(
283282

284283
# 2. Cross-Attention
285284
hidden_states = hidden_states + self.attn2(
286-
self.norm3(hidden_states),
285+
self.norm2(hidden_states),
287286
encoder_hidden_states = encoder_hidden_states,
288287
temb = freq_cis_img,
289288
)
290289

291290
# FFN Layer ### TODO: switch norm2 and norm3 in the state dict
292-
mlp_inputs = self.norm2(hidden_states)
291+
mlp_inputs = self.norm3(hidden_states)
293292
hidden_states = hidden_states + self.ff(mlp_inputs)
294293

295294
return hidden_states

src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py

100644100755
File mode changed.

src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py

100644100755
File mode changed.

test_hunyuan_dit.py

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,107 @@
1+
# integration test (hunyuan dit)
12
import torch
23
from diffusers import HunyuanDiTPipeline
34

4-
pipe = HunyuanDiTPipeline.from_pretrained("XCLiu/HunyuanDiT-0523", torch_dtype=torch.float32)
5+
import torch
6+
from huggingface_hub import hf_hub_download
7+
from diffusers import HunyuanDiT2DModel
8+
import safetensors.torch
9+
10+
device = "cuda"
11+
model_config = HunyuanDiT2DModel.load_config("XCLiu/HunyuanDiT-0523", subfolder="transformer")
12+
model = HunyuanDiT2DModel.from_config(model_config).to(device)
13+
14+
ckpt_path = hf_hub_download(
15+
"XCLiu/HunyuanDiT-0523",
16+
filename ="diffusion_pytorch_model.safetensors",
17+
subfolder="transformer",
18+
)
19+
state_dict = safetensors.torch.load_file(ckpt_path)
20+
21+
num_layers = 40
22+
for i in range(num_layers):
23+
24+
# attn1
25+
# Wkqv -> to_q, to_k, to_v
26+
q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
27+
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
28+
state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
29+
state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
30+
state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
31+
state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
32+
state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
33+
state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
34+
state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
35+
state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
36+
37+
# q_norm, k_norm -> norm_q, norm_k
38+
state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
39+
state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
40+
state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
41+
state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
42+
43+
state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
44+
state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
45+
state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
46+
state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
47+
48+
# out_proj -> to_out
49+
state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
50+
state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
51+
state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
52+
state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
53+
54+
# attn2
55+
# kv_proj -> to_k, to_v
56+
k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
57+
k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
58+
state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
59+
state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
60+
state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
61+
state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
62+
state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
63+
state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
64+
65+
# q_proj -> to_q
66+
state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
67+
state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
68+
state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
69+
state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
70+
71+
# q_norm, k_norm -> norm_q, norm_k
72+
state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
73+
state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
74+
state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
75+
state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
76+
77+
state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
78+
state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
79+
state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
80+
state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
81+
82+
# out_proj -> to_out
83+
state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
84+
state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
85+
state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
86+
state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
87+
88+
# switch norm 2 and norm 3
89+
norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
90+
norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
91+
state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
92+
state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
93+
state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
94+
state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
95+
96+
model.load_state_dict(state_dict)
97+
98+
pipe = HunyuanDiTPipeline.from_pretrained("XCLiu/HunyuanDiT-0523", transformer=model, torch_dtype=torch.float32)
599
pipe.to('cuda')
6100

7101
### NOTE: HunyuanDiT supports both Chinese and English inputs
8102
prompt = "一个宇航员在骑马"
9103
#prompt = "An astronaut riding a horse"
10-
generator=torch.Generator(device="cuda").manual_seed(3456)
104+
generator=torch.Generator(device="cuda").manual_seed(0)
11105
image = pipe(height=1024, width=1024, prompt=prompt, generator=generator).images[0]
12106

13-
image.save("./img.png")
107+
image.save("img.png")

test_hunyuan_dit_yiyi_attention.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

tests/others/test_utils.py

100644100755
File mode changed.

tests/schedulers/test_schedulers.py

100644100755
File mode changed.

0 commit comments

Comments
 (0)