@@ -121,65 +121,81 @@ def manual_attn(q, k, v):
 @gpu_func(emit=True)
 @canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
 def flash_attention(
-    Q: T.memref(B * nh * N * d, T.f32()),
+    Q: T.memref(B, nh, N, d, T.f32()),
     K: T.memref(B, nh, N, d, T.f32()),
-    V: T.memref(B * nh * N * d, T.f32()),
-    l: T.memref(B * nh * N, T.f32()),
-    m: T.memref(B * nh * N, T.f32()),
-    O: T.memref(B * nh * N * d, T.f32()),
+    V: T.memref(B, nh, N, d, T.f32()),
+    l: T.memref(B, nh, N, T.f32()),
+    m: T.memref(B, nh, N, T.f32()),
+    O: T.memref(B, nh, N, d, T.f32()),
 ):
     tx = thread_idx.x
     # batch idx, head_idx
     bx, by = block_idx.x, block_idx.y
     # gpu.printf("bx %ld, by %ld\n", bx, by)

     # Offset into Q,K,V,O,l,m - different for each batch and head
-    qkv_offset = bx * nh * N * d + by * N * d
-    lm_offset = bx * nh * N + by * N  # offset for l and m
+    K_ = K[bx, by, :, :]
+    V_ = V[bx, by, :, :]
+    Q_ = Q[bx, by, :, :]
+    O_ = O[bx, by, :, :]
+    l_ = l[bx, by, :]
+    m_ = m[bx, by, :]

     # Define SRAM for Q,K,V,S
     sram = gpu.dynamic_shared_memory()
-    Qi = memref.view(sram, (Br * d,), dtype=T.f32())
-    Kj = memref.view(sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements)
+    Qi = memref.view(
+        sram,
+        (Br, d),
+        dtype=T.f32(),
+    )
+    Kj = memref.view(
+        sram,
+        (Bc, d),
+        dtype=T.f32(),
+        shift=Qi.n_elements,
+    )
     Vj = memref.view(
-        sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements
+        sram,
+        (Bc, d),
+        dtype=T.f32(),
+        shift=Qi.n_elements + Kj.n_elements,
     )
     S = memref.view(
         sram,
-        (Br * Bc,),
+        (Br, Bc),
         dtype=T.f32(),
         shift=Qi.n_elements + Kj.n_elements + Vj.n_elements,
     )

-    # K_ = memref.reinterpret_cast(K, [0], [B, nh, N, d])
-    # K_ = K_[bx : bx + 1, by : by + 1, :, :]
-    for j in scf.range_(0, Tc):
-        K_ = K[bx, by, :, :]
+    for bc in scf.range_(0, N, Bc):
         # Load Kj, Vj to SRAM
+        K_ = K_[:, :, bc : bc + 1, :]
+        V_ = V_[:, :, bc : bc + 1, :]
         for x in scf.range_(0, d):
-            # Kj[tx * d + x] = K[qkv_offset + Bc * d * j + tx * d + x]
-            K_ = K_[:, :, j * Bc : (j + 1) * Bc, :]
-            Vj[tx * d + x] = V[qkv_offset + Bc * d * j + tx * d + x]
+            Kj[tx, x] = K_[0, 0, tx, x]
+            Vj[tx, x] = V_[0, 0, tx, x]

-        for i in scf.range_(0, Tr):
+        for br in scf.range_(0, N, Br):
             # Load Qi to SRAM, l and m to registers
+            Q_ = Q_[:, :, br : br + 1, :]
             for x in scf.range_(0, d):
-                ii = qkv_offset + Bc * d * i + tx * d + x
-                Qi[tx * d + x] = Q[ii]
+                Qi[tx, x] = Q_[0, 0, tx, x]

-            row_m_prev = m[lm_offset + Br * i + tx]
-            row_l_prev = l[lm_offset + Br * i + tx]
+            l_ = l_[:, :, br : br + 1]
+            m_ = m_[:, :, br : br + 1]
+            row_l_prev = l_[0, 0, tx]
+            row_m_prev = m_[0, 0, tx]

             # S = QK^T, row_m = rowmax(S)
             row_m: T.f32() = float(np.finfo(np.float32).min)
             for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]):
                 sum: T.f32() = 0.0
                 for x, sum, _ in scf.range_(0, d, iter_args=[sum]):
-                    sum += Qi[tx * d + x] * Kj[y * d + x]
+                    sum += Qi[tx, x] * Kj[y, x]
                     sum = yield sum

                 sum *= softmax_scale
-                S[Bc * tx + y] = sum
+                S[tx, y] = sum

                 if sum > row_m:
                     row_m_ = yield sum
@@ -191,8 +207,8 @@ def flash_attention(
             # P = exp(S - row_m), row_l = rowsum(P)
             row_l: T.f32() = 0.0
             for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]):
-                S[Bc * tx + y] = math.exp(S[Bc * tx + y] - row_m)
-                row_l += S[Bc * tx + y]
+                S[tx, y] = math.exp(S[tx, y] - row_m)
+                row_l += S[tx, y]
                 row_l = yield row_l

             # Compute new m and l
@@ -205,24 +221,26 @@ def flash_attention(
             c = row_l_prev * math.exp(row_m_prev - row_m_new)

             # Write O, l, m to HBM
+            O_ = O_[:, :, br : br + 1, :]
             for x in scf.range_(0, d):
                 pv: T.f32() = 0.0  # Pij * Vj
                 for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
-                    pv += S[Bc * tx + y] * Vj[y * d + x]
+                    pv += S[tx, y] * Vj[y, x]
                     pv = yield pv

-                ii = qkv_offset + Bc * d * i + tx * d + x
-                O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)
+                O_[0, 0, tx, x] = div * (
+                    c * O_[0, 0, tx, x] + math.exp(row_m - row_m_new) * pv
+                )

-            m[lm_offset + Br * i + tx] = row_m_new
-            l[lm_offset + Br * i + tx] = row_l_new
+            l_[0, 0, tx] = row_l_new
+            m_[0, 0, tx] = row_m_new

         gpu.barrier()


 ip.__exit__(None, None, None)

-print(gpu_module)
+# print(gpu_module)

 sram_size = 4 * Bc * d * np.float32().itemsize

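The `sram_size` line above budgets the dynamic shared memory that the kernel slices into its four views: `Qi` is `(Br, d)`, `Kj` and `Vj` are `(Bc, d)`, and `S` is `(Br, Bc)`. A small helper to make that accounting explicit; this is only a sketch, not part of the commit, and it assumes `Br`, `Bc`, and `d` are the plain Python ints defined earlier in the script:

```python
import numpy as np

def flash_attn_sram_bytes(Br: int, Bc: int, d: int) -> int:
    """Bytes of dynamic shared memory covering the Qi, Kj, Vj and S tiles."""
    n_elements = Br * d + Bc * d + Bc * d + Br * Bc  # Qi + Kj + Vj + S
    return n_elements * np.float32().itemsize
```

With `Br == Bc` and `d == Bc` this sum collapses to `4 * Bc * d` elements, which is where the `4 * Bc * d * np.float32().itemsize` expression above comes from.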
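The comments in the kernel name the standard FlashAttention online-softmax steps: `S = QK^T` with `row_m = rowmax(S)`, then `P = exp(S - row_m)` with `row_l = rowsum(P)`, then the rescaled accumulation into `O`. Below is a plain NumPy sketch of that per-tile update, reference arithmetic only, not the DSL; the function name and shapes are hypothetical, assuming one `(Br, d)` query tile against one `(Bc, d)` key/value tile:

```python
import numpy as np

def online_softmax_block_update(Qi, Kj, Vj, O_i, l_i, m_i, softmax_scale):
    """One (Br, Bc) tile of the online-softmax update.

    Qi: (Br, d) query tile; Kj, Vj: (Bc, d) key/value tiles.
    O_i: (Br, d) running output; l_i, m_i: (Br,) running row sums / row maxes.
    Returns the updated (O_i, l_i, m_i).
    """
    # S = Q K^T (scaled) and its per-row max over this tile.
    S = softmax_scale * (Qi @ Kj.T)          # (Br, Bc)
    row_m = S.max(axis=1)                    # (Br,)

    # P = exp(S - row_m), row_l = rowsum(P).
    P = np.exp(S - row_m[:, None])           # (Br, Bc)
    row_l = P.sum(axis=1)                    # (Br,)

    # New running max and normalizer (the "compute new m and l" step).
    row_m_new = np.maximum(m_i, row_m)
    row_l_new = np.exp(m_i - row_m_new) * l_i + np.exp(row_m - row_m_new) * row_l

    # Un-normalize the old output, add this tile's contribution, re-normalize.
    c = l_i * np.exp(m_i - row_m_new)        # matches `c` in the kernel
    O_new = (
        c[:, None] * O_i + np.exp(row_m - row_m_new)[:, None] * (P @ Vj)
    ) / row_l_new[:, None]
    return O_new, row_l_new, row_m_new


# Sanity check: streaming over Bc-wide key/value tiles (with the whole Q as a
# single query tile) reproduces ordinary softmax attention.
N, d, Bc = 8, 4, 4
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)).astype(np.float32) for _ in range(3))
scale = 1.0 / np.sqrt(d)

O = np.zeros((N, d), dtype=np.float32)
l = np.zeros(N, dtype=np.float32)
m = np.full(N, -np.inf, dtype=np.float32)
for j in range(0, N, Bc):
    O, l, m = online_softmax_block_update(Q, K[j : j + Bc], V[j : j + Bc], O, l, m, scale)

S_full = scale * (Q @ K.T)
P_full = np.exp(S_full - S_full.max(axis=1, keepdims=True))
ref = (P_full / P_full.sum(axis=1, keepdims=True)) @ V
assert np.allclose(O, ref, atol=1e-5)
```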