@@ -118,54 +118,39 @@ struct FlashPrefillMma<
118118 using ElementAccumulator = typename TiledMmaQK::ValTypeC;
119119 static constexpr bool CausalMask = CausalMask_;
120120 static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
121-
122121 using MmaAtomShape = typename MmaAtom::Shape_MNK;
123-
124122 static constexpr auto PV_ATOM_M =
125123 decltype (get<0 >(SubgroupLayout{}.shape()))::value;
126124 static constexpr auto PV_ATOM_N =
127125 decltype (get<1 >(SubgroupLayout{}.shape()))::value;
128126 static constexpr auto PV_ATOM_K =
129127 decltype (get<2 >(SubgroupLayout{}.shape()))::value;
130-
131128 using SubgroupTileShapePV =
132129 decltype (cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape())));
133-
134130 static constexpr auto QK_BLK_M = get<0 >(TileShapeQK{});
135131 static constexpr auto QK_BLK_N = get<1 >(TileShapeQK{});
136132 static constexpr auto QK_BLK_K = get<2 >(TileShapeQK{});
137-
138- // This TiledMma is only required to serve the specific tiling requirements
139- // for matrix K. This is due to the consumption of matrix K by all subgroups
140- // within a workgroup.
141- static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8
142- static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1
143- static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1
144-
145- using SubgroupTileShapeQK = decltype (cute::shape_div(
146- TileShapeQK{},
147- SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 )
148-
133+ static constexpr auto QK_ATOM_M = PV_ATOM_M;
134+ static constexpr auto QK_ATOM_N = PV_ATOM_N;
135+ static constexpr auto QK_ATOM_K = PV_ATOM_K;
136+ using SubgroupTileShapeQK =
137+ decltype (cute::shape_div(TileShapeQK{}, SubgroupLayout{}.shape()));
149138 static constexpr auto QK_SG_M = get<0 >(SubgroupTileShapeQK{});
150139 static constexpr auto QK_SG_N = get<1 >(SubgroupTileShapeQK{});
151140 static constexpr auto QK_SG_K = get<2 >(SubgroupTileShapeQK{});
152-
153141 static constexpr bool is_var_len =
154142 cutlass::fmha::collective::is_variable_length_v<
155143 tuple_element_t <3 , ProblemShapeType>>;
156-
157144 using FragsShapeS = decltype (cute::shape_div(
158145 take<0 , 2 >(SubgroupTileShapeQK{}),
159- take<0 , 2 >(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4)
146+ take<0 , 2 >(MmaAtomShape())));
160147 static constexpr int Vec =
161- (get<0 >(MmaAtomShape()) * get<1 >(MmaAtomShape())) / SubgroupSize; // 8
148+ (get<0 >(MmaAtomShape()) * get<1 >(MmaAtomShape())) / SubgroupSize;
162149 static constexpr int FragsM = get<0 >(FragsShapeS{});
163- static constexpr int FragsNS = get<1 >(FragsShapeS{}); // 4
164-
150+ static constexpr int FragsNS = get<1 >(FragsShapeS{});
165151 static constexpr uint32_t MaxThreadsPerBlock =
166152 size (SubgroupLayout{}) * SubgroupSize;
167153 using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
168-
169154 using traits_load_Q = Copy_Traits<GmemTiledCopyQ, StrideQ>;
170155 using atom_load_Q = Copy_Atom<traits_load_Q, ElementQ>;
171156 using val_layout_load_Q = decltype (make_layout(
@@ -174,7 +159,6 @@ struct FlashPrefillMma<
174159 atom_load_Q{},
175160 Layout<CopyThreadShape>{},
176161 val_layout_load_Q{}));
177-
178162 using traits_load_K = Copy_Traits<GmemTiledCopyK, StrideK>;
179163 using atom_load_K = Copy_Atom<traits_load_K, ElementK>;
180164 using val_layout_load_K = decltype (make_layout(
@@ -183,7 +167,6 @@ struct FlashPrefillMma<
183167 atom_load_K{},
184168 Layout<CopyThreadShape>{},
185169 val_layout_load_K{}));
186-
187170 using traits_load_V = Copy_Traits<GmemTiledCopyV, StrideV>;
188171 using atom_load_V = Copy_Atom<traits_load_V, ElementV>;
189172 using val_layout_load_V = decltype (make_layout(
@@ -195,6 +178,7 @@ struct FlashPrefillMma<
195178 template <typename T>
196179 static constexpr bool is_fp8_v =
197180 cute::is_same_v<T, float_e4m3_t > || cute::is_same_v<T, float_e5m2_t >;
181+
198182 // Host side kernel arguments
199183 struct Arguments {
200184 ElementQ const * ptr_Q;
@@ -222,7 +206,6 @@ struct FlashPrefillMma<
222206 Arguments const & args,
223207 void * workspace) {
224208 (void )workspace;
225-
226209 auto
227210 [batch,
228211 num_heads_q,
@@ -231,7 +214,6 @@ struct FlashPrefillMma<
231214 seq_len_kv,
232215 head_size_qk,
233216 head_size_vo] = problem_shape;
234-
235217 auto tensorQ = make_tensor (
236218 make_gmem_ptr (args.ptr_Q ),
237219 make_layout (
@@ -250,7 +232,6 @@ struct FlashPrefillMma<
250232 XE_Copy_Q copyQ{XE_Copy_Q{}.with (tensorQ)};
251233 XE_Copy_K copyK{XE_Copy_K{}.with (tensorK)};
252234 XE_Copy_V copyV{XE_Copy_V{}.with (tensorV)};
253-
254235 return Params{copyQ, copyK, copyV};
255236 }
256237
@@ -265,22 +246,16 @@ struct FlashPrefillMma<
265246 int thread_idx = static_cast <int >(ThreadIdxX ());
266247 auto thr_copy_Q = params.gmem_tiled_copy_q .get_slice (thread_idx);
267248 auto thr_copy_K = params.gmem_tiled_copy_k .get_slice (thread_idx);
268- // Instantiate the MMA object
269249 TiledMmaQK tiled_mma;
270- // To make all threads in a warp have the same global tensors pass in the
271- // index of thread 0 in each warp
272250 auto sg = compat::get_nd_item<1 >().get_sub_group ();
273251 auto first_thread_in_sg_idx =
274252 sg.get_group_id ()[0 ] * DispatchPolicy::SubgroupSize;
275253 auto thread_mma_q = tiled_mma.get_slice (first_thread_in_sg_idx);
276254 auto thread_mma_k = tiled_mma.get_slice (0 );
277-
278255 // Partition
279256 Tensor tCgQ = thread_mma_q.partition_A (gQ );
280257 Tensor tCgK = thread_mma_k.partition_B (gK );
281-
282258 // Create fragments
283- // TODO(Codeplay): fix this, this is probably not general
284259 using TCrQ_Type =
285260 cute::conditional_t <is_fp8_v<ElementQ>, uint8_t , ElementQ>;
286261 using TCrK_Type =
@@ -289,68 +264,18 @@ struct FlashPrefillMma<
289264 params.gmem_tiled_copy_q , take<0 , 3 >(tCgQ.shape ())));
290265 Tensor tCrK = make_tensor<TCrK_Type>(make_fragment_layout (
291266 params.gmem_tiled_copy_k , take<0 , 3 >(tCgK.shape ())));
292-
293267 // Retile registers for copies
294268 Tensor tQrQ = thr_copy_Q.retile_D (tCrQ);
295269 Tensor tKrK = thr_copy_K.retile_D (tCrK);
296-
297270 // Retile global tile for copies
298271 Tensor tQgQ = thr_copy_Q.retile_S (tCgQ);
299272 Tensor tKgK = thr_copy_K.retile_S (tCgK);
300273
301- #if CUTLASS_ENABLE_DEBUG_PRINTS
302- #define PRINT (x ) \
303- print (#x " : " ); \
304- print (x); \
305- print (" \n " );
306- if (cute::thread (LOG_THREAD, LOG_GROUP)) {
307- print (" ======================= Q: \n " );
308- PRINT (gQ );
309- PRINT (tCrQ);
310- PRINT (tCgQ);
311- PRINT (tQrQ);
312- PRINT (tQgQ);
313-
314- print (" ===================== K :\n " );
315- PRINT (gK );
316- PRINT (tCrK);
317- PRINT (tCgK);
318- PRINT (tKrK);
319- PRINT (tKgK);
320-
321- print (" ===================== Config: \n " );
322- PRINT (MaxThreadsPerBlock);
323- PRINT (SubgroupTileShapeQK{});
324- }
325- #undef PRINT
326- #endif
327-
328- //
329274 // Mainloop
330- //
331-
332275 for (int k_tile = 0 ; k_tile < k_tile_count; ++k_tile) {
333276 copy (params.gmem_tiled_copy_q , tQgQ (_, _, _, k_tile), tQrQ);
334277 copy (params.gmem_tiled_copy_k , tKgK (_, _, _, k_tile), tKrK);
335- if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
336- auto tCrQ_ = make_fragment_like<half_t >(tCrQ);
337- convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_);
338- auto tCrK_ = make_fragment_like<half_t >(tCrK);
339- convert_FP8_to_FP16<ElementK>(tCrK, tCrK_);
340- cute::gemm (tiled_mma, accum, tCrQ_, tCrK_, frag_src);
341-
342- } else if constexpr (is_fp8_v<ElementQ> && !is_fp8_v<ElementK>) {
343- auto tCrQ_ = make_fragment_like<half_t >(tCrQ);
344- convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_);
345- cute::gemm (tiled_mma, accum, tCrQ_, tCrK, frag_src);
346-
347- } else if constexpr (!is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
348- auto tCrK_ = make_fragment_like<half_t >(tCrK);
349- convert_FP8_to_FP16<ElementK>(tCrK, tCrK_);
350- cute::gemm (tiled_mma, accum, tCrQ, tCrK_, frag_src);
351- } else {
352- cute::gemm (tiled_mma, accum, tCrQ, tCrK, frag_src);
353- }
278+ cute::gemm (tiled_mma, accum, tCrQ, tCrK, frag_src);
354279 }
355280 }
356281 template <
@@ -366,10 +291,7 @@ struct FlashPrefillMma<
366291 FragSrc const & frag_src,
367292 Params const & params) {
368293 int thread_idx = static_cast <int >(ThreadIdxX ());
369- // Instantiate the MMA object
370294 TiledMmaPV tiled_mma;
371- // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid
372- // Register spill
373295 Tensor gV_ = take<0 , 3 >(
374296 local_tile (gV , select<1 , 2 >(TileShapePV{}), make_coord (_, _)));
375297 auto sg = compat::get_nd_item<1 >().get_sub_group ();
@@ -381,49 +303,20 @@ struct FlashPrefillMma<
381303 cute::conditional_t <is_fp8_v<ElementV>, uint8_t , ElementV>;
382304 Tensor tCrV = make_tensor<TCrV_Type>(make_fragment_layout (
383305 params.gmem_tiled_copy_v , take<0 , 3 >(tCgV.shape ())));
384-
385306 // Partition the copying of A and B tiles across the threads
386307 auto gmem_thr_copy_V = params.gmem_tiled_copy_v .get_slice (thread_idx);
387308 Tensor tVrV = gmem_thr_copy_V.retile_D (tCrV);
388309 Tensor tVgV = gmem_thr_copy_V.retile_S (tCgV);
389310
390- #if CUTLASS_ENABLE_DEBUG_PRINTS
391- #define PRINT (x ) \
392- print (#x " : " ); \
393- print (x); \
394- print (" \n " );
395- if (cute::thread (LOG_THREAD, LOG_GROUP)) {
396- print (" ===================== V :\n " );
397- PRINT (gV );
398- PRINT (tCrV);
399- PRINT (tCgV);
400- PRINT (tVrV);
401- PRINT (tVgV);
402-
403- print (" ===================== Config: \n " );
404- PRINT (MaxThreadsPerBlock);
405- PRINT (SubgroupTileShapePV{});
406- }
407- #undef PRINT
408- #endif
409-
410- // 7) Convert S to P (FP32 -> BF16)
311+ // Convert S to P (FP32 -> BF16)
411312 Tensor tPr = convert_type<typename TiledMmaPV::ValTypeA>(tSr);
412313 //
413314 // Mainloop
414315 //
415316 CUTLASS_PRAGMA_UNROLL
416317 for (int i = 0 ; i < tile_count; i++) {
417318 copy (params.gmem_tiled_copy_v , tVgV (_, _, _, i), tVrV);
418- if constexpr (is_fp8_v<ElementV>) {
419- auto tCrV_ = make_fragment_like<half_t >(tCrV);
420- convert_FP8_to_FP16<ElementV>(tCrV, tCrV_);
421- cute::gemm (
422- tiled_mma, accum (_, _, _, i), tPr, tCrV_, frag_src (_, _, _, i));
423- } else {
424- cute::gemm (
425- tiled_mma, accum (_, _, _, i), tPr, tCrV, frag_src (_, _, _, i));
426- }
319+ cute::gemm (tiled_mma, accum (_, _, _, i), tPr, tCrV, frag_src (_, _, _, i));
427320 }
428321 }
429322
@@ -496,3 +389,5 @@ struct FlashPrefillMma<
496389};
497390
498391} // namespace cutlass::flash_attention::collective
392+
393+ // ///////////////////////////////////////////////////////////////////////////////////////////////
0 commit comments