import math
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers.models.attention_dispatch import TemplatedUnifiedAttention

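# Toy, CPU-only sketch of running diffusers' TemplatedUnifiedAttention with a
# combined ring + Ulysses context-parallel setup. The forward/backward ops
# below are deliberately simplified (no softmax), so the focus is on the
# dispatch and caching plumbing rather than on exact attention math.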
def run(rank, world_size):
    dist.init_process_group(
        backend="gloo",  # CPU-friendly backend for the toy run
        rank=rank,
        world_size=world_size,
    )

    torch.manual_seed(0)

    B, S, H, D = 2, 8, 4, 16  # small toy sizes: batch, seq, heads, head dim
    q = torch.randn(B, S, H, D)
    k = torch.randn(B, S, H, D)
    v = torch.randn(B, S, H, D)

    q.requires_grad_(True)

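    # ParallelConfig / ContextParallelConfig come from a private diffusers
    # module (note the leading underscore), so the import path may change
    # between releases. Here ring_degree * ulysses_degree (2 * 2) matches
    # the world size of 4.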
    from diffusers.models._modeling_parallel import (
        ParallelConfig,
        ContextParallelConfig,
    )

    pc = ParallelConfig(
        context_parallel_config=ContextParallelConfig(
            ring_degree=2,
            ulysses_degree=2,
        )
    )

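    # Build a named 2x2 device mesh so every rank sits in exactly one
    # "ring" group and one "ulysses" group, and hand it to the config.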
    pc.context_parallel_config.setup(
        rank=rank,
        world_size=world_size,
        device=torch.device("cpu"),
        mesh=dist.device_mesh.init_device_mesh(
            "cpu",
            (2, 2),
            mesh_dim_names=("ring", "ulysses"),
        ),
    )

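    # A minimal stand-in for the forward_op hook: it computes scaled scores
    # and their log-sum-exp, and caches q/k/v on ctx so the matching
    # backward op can recompute what it needs.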
    def dummy_forward_op(
        ctx,
        q,
        k,
        v,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        return_lse,
        *,
        _save_ctx=True,
        _parallel_config=None,
    ):
        head_scale = math.sqrt(D)
        attn = (q @ k.transpose(-1, -2)) / head_scale
        out = attn @ v
        lse = torch.logsumexp(attn, dim=-1)

        # _save_ctx marks the call that owns the autograd ctx: it resets
        # the cache, while later calls only append to it.
        if _save_ctx:
            ctx.save_for_backward(q, k, v)
            ctx._cached_qkv = []
            ctx._cached_iter = 0

        if not hasattr(ctx, "_cached_qkv"):
            ctx._cached_qkv = []

        ctx._cached_qkv.append((q.detach(), k.detach(), v.detach()))

        return (out, lse) if return_lse else out

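    # The matching backward op replays the cached forwards in order and
    # returns gradients with respect to q, k, and v.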
    def dummy_backward_op(ctx, grad_out, *args, **kwargs):
        if not hasattr(ctx, "_cached_qkv"):
            raise RuntimeError("No cached tensors for backward.")

        if not hasattr(ctx, "_cached_iter"):
            ctx._cached_iter = 0

        if ctx._cached_iter >= len(ctx._cached_qkv):
            raise RuntimeError("Backward called more times than cached forwards.")

        q, k, v = ctx._cached_qkv[ctx._cached_iter]
        ctx._cached_iter += 1

        head_scale = math.sqrt(D)
        attn = (q @ k.transpose(-1, -2)) / head_scale

        # With out = attn @ v and attn = (q @ k^T) / head_scale:
        grad_v = attn.transpose(-1, -2) @ grad_out
        grad_attn = grad_out @ v.transpose(-1, -2)
        grad_q = (grad_attn @ k) / head_scale
        grad_k = (grad_attn.transpose(-1, -2) @ q) / head_scale

        return (
            grad_q,
            grad_k,
            grad_v,
        )

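    # Drive the templated attention with the custom ops and the parallel
    # config; each rank should report the same shapes as its local inputs.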
    attn = TemplatedUnifiedAttention()

    out = attn(
        None,
        q, k, v, None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
        return_lse=False,
        forward_op=dummy_forward_op,
        backward_op=dummy_backward_op,
        _parallel_config=pc,
    )

    print(f"[RANK {rank}] output:", out.shape)

    out.sum().backward()
    print(f"[RANK {rank}] grad:", q.grad.shape)

    dist.destroy_process_group()

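# Entry point: spawn one process per rank. world_size is 4 here so it
# matches ring_degree * ulysses_degree from the config above.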
if __name__ == "__main__":
    world_size = 4
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    mp.spawn(run, args=(world_size,), nprocs=world_size)