Skip to content

Commit 3a407d8

Browse files
KarthikSundar2002Bissmella
authored andcommitted
feat: Adds Test for Unified SP Attention and Fixes a bug in Template Ring Attention
1 parent e0ed41e commit 3a407d8

File tree

2 files changed

+132
-1
lines changed

2 files changed

+132
-1
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,7 @@ def backward(
12051205

12061206
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
12071207

1208-
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
1208+
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
12091209

12101210

12111211
class TemplatedUlyssesAttention(torch.autograd.Function):
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import math
2+
import torch
3+
import torch.distributed as dist
4+
import torch.multiprocessing as mp
5+
from diffusers.models.attention_dispatch import TemplatedUnifiedAttention
6+
import os
7+
8+
def run(rank, world_size):
9+
dist.init_process_group(
10+
backend="gloo",
11+
rank=rank,
12+
world_size=world_size
13+
)
14+
15+
torch.manual_seed(0)
16+
17+
B, S, H, D = 2, 8, 4, 16 # small toy
18+
q = torch.randn(B, S, H, D)
19+
k = torch.randn(B, S, H, D)
20+
v = torch.randn(B, S, H, D)
21+
22+
q.requires_grad_(True)
23+
24+
from diffusers.models._modeling_parallel import (
25+
ParallelConfig,
26+
ContextParallelConfig
27+
)
28+
29+
pc = ParallelConfig(
30+
context_parallel_config=ContextParallelConfig(
31+
ring_degree=2,
32+
ulysses_degree=2,
33+
)
34+
)
35+
36+
pc.context_parallel_config.setup(
37+
rank=rank,
38+
world_size=world_size,
39+
device=torch.device("cpu"),
40+
mesh=dist.device_mesh.init_device_mesh("cpu",
41+
(2,2),
42+
mesh_dim_names=["ring", "ulysses"],
43+
)
44+
)
45+
46+
def dummy_forward_op(
47+
ctx,
48+
q,
49+
k,
50+
v,
51+
attn_mask,
52+
dropout_p,
53+
is_causal,
54+
scale,
55+
enable_gqa,
56+
return_lse,
57+
*,
58+
_save_ctx=True,
59+
_parallel_config=None,
60+
):
61+
head_scale = math.sqrt(D)
62+
attn = (q @ k.transpose(-1, -2)) / head_scale
63+
out = attn @ v
64+
lse = torch.logsumexp(attn, dim=-1)
65+
66+
if _save_ctx:
67+
ctx.save_for_backward(q, k, v)
68+
ctx._cached_qkv = []
69+
ctx._cached_iter = 0
70+
71+
if not hasattr(ctx, "_cached_qkv"):
72+
ctx._cached_qkv = []
73+
74+
ctx._cached_qkv.append((q.detach(), k.detach(), v.detach()))
75+
76+
return (out, lse) if return_lse else out
77+
78+
def dummy_backward_op(ctx, grad_out, *args, **kwargs):
79+
if not hasattr(ctx, "_cached_qkv"):
80+
raise RuntimeError("No cached tensors for backward.")
81+
82+
if not hasattr(ctx, "_cached_iter"):
83+
ctx._cached_iter = 0
84+
85+
if ctx._cached_iter >= len(ctx._cached_qkv):
86+
raise RuntimeError("Backward called more times than cached forwards.")
87+
88+
q, k, v = ctx._cached_qkv[ctx._cached_iter]
89+
ctx._cached_iter += 1
90+
91+
head_scale = math.sqrt(D)
92+
attn = (q @ k.transpose(-1, -2)) / head_scale
93+
94+
grad_v = attn.transpose(-1, -2) @ grad_out
95+
grad_attn = grad_out @ v.transpose(-1, -2)
96+
grad_q = (grad_attn @ k) / head_scale
97+
grad_k = (grad_attn.transpose(-1, -2) @ q) / head_scale
98+
99+
return (
100+
grad_q,
101+
grad_k,
102+
grad_v,
103+
)
104+
105+
attn = TemplatedUnifiedAttention()
106+
107+
out = attn(
108+
None,
109+
q, k, v, None,
110+
dropout_p=0.0,
111+
is_causal=False,
112+
scale=None,
113+
enable_gqa=False,
114+
return_lse=False,
115+
forward_op=dummy_forward_op,
116+
backward_op=dummy_backward_op,
117+
_parallel_config=pc,
118+
)
119+
120+
print(f"[RANK {rank}] output:", out.shape)
121+
122+
out.sum().backward()
123+
print(f"[RANK {rank}] grad:", q.grad.shape)
124+
125+
dist.destroy_process_group()
126+
127+
if __name__ == "__main__":
128+
world_size = 4
129+
os.environ["MASTER_ADDR"] = "localhost"
130+
os.environ["MASTER_PORT"] = "12355"
131+
mp.spawn(run, args=(world_size,), nprocs=world_size)

0 commit comments

Comments
 (0)