Commit 9bcbd65

rebase fwd kernel to 32f7463e5fcbe8a958204c90fcf16379fb6dad6e
1 parent ef82015 commit 9bcbd65

File tree

5 files changed: +195 -479 lines changed

src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_prefill_mma_bshd.h

Lines changed: 14 additions & 119 deletions
@@ -118,54 +118,39 @@ struct FlashPrefillMma<
   using ElementAccumulator = typename TiledMmaQK::ValTypeC;
   static constexpr bool CausalMask = CausalMask_;
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
-
   using MmaAtomShape = typename MmaAtom::Shape_MNK;
-
   static constexpr auto PV_ATOM_M =
       decltype(get<0>(SubgroupLayout{}.shape()))::value;
   static constexpr auto PV_ATOM_N =
       decltype(get<1>(SubgroupLayout{}.shape()))::value;
   static constexpr auto PV_ATOM_K =
       decltype(get<2>(SubgroupLayout{}.shape()))::value;
-
   using SubgroupTileShapePV =
       decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape())));
-
   static constexpr auto QK_BLK_M = get<0>(TileShapeQK{});
   static constexpr auto QK_BLK_N = get<1>(TileShapeQK{});
   static constexpr auto QK_BLK_K = get<2>(TileShapeQK{});
-
-  // This TiledMma is only required to serve the specific tiling requirements
-  // for matrix K. This is due to the consumption of matrix K by all subgroups
-  // within a workgroup.
-  static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8
-  static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1
-  static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1
-
-  using SubgroupTileShapeQK = decltype(cute::shape_div(
-      TileShapeQK{},
-      SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 )
-
+  static constexpr auto QK_ATOM_M = PV_ATOM_M;
+  static constexpr auto QK_ATOM_N = PV_ATOM_N;
+  static constexpr auto QK_ATOM_K = PV_ATOM_K;
+  using SubgroupTileShapeQK =
+      decltype(cute::shape_div(TileShapeQK{}, SubgroupLayout{}.shape()));
   static constexpr auto QK_SG_M = get<0>(SubgroupTileShapeQK{});
   static constexpr auto QK_SG_N = get<1>(SubgroupTileShapeQK{});
   static constexpr auto QK_SG_K = get<2>(SubgroupTileShapeQK{});
-
   static constexpr bool is_var_len =
       cutlass::fmha::collective::is_variable_length_v<
           tuple_element_t<3, ProblemShapeType>>;
-
   using FragsShapeS = decltype(cute::shape_div(
       take<0, 2>(SubgroupTileShapeQK{}),
-      take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4)
+      take<0, 2>(MmaAtomShape())));
   static constexpr int Vec =
-      (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8
+      (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;
   static constexpr int FragsM = get<0>(FragsShapeS{});
-  static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4
-
+  static constexpr int FragsNS = get<1>(FragsShapeS{});
   static constexpr uint32_t MaxThreadsPerBlock =
       size(SubgroupLayout{}) * SubgroupSize;
   using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
-
   using traits_load_Q = Copy_Traits<GmemTiledCopyQ, StrideQ>;
   using atom_load_Q = Copy_Atom<traits_load_Q, ElementQ>;
   using val_layout_load_Q = decltype(make_layout(
@@ -174,7 +159,6 @@ struct FlashPrefillMma<
       atom_load_Q{},
       Layout<CopyThreadShape>{},
       val_layout_load_Q{}));
-
   using traits_load_K = Copy_Traits<GmemTiledCopyK, StrideK>;
   using atom_load_K = Copy_Atom<traits_load_K, ElementK>;
   using val_layout_load_K = decltype(make_layout(
@@ -183,7 +167,6 @@ struct FlashPrefillMma<
       atom_load_K{},
       Layout<CopyThreadShape>{},
       val_layout_load_K{}));
-
   using traits_load_V = Copy_Traits<GmemTiledCopyV, StrideV>;
   using atom_load_V = Copy_Atom<traits_load_V, ElementV>;
   using val_layout_load_V = decltype(make_layout(
@@ -195,6 +178,7 @@ struct FlashPrefillMma<
   template <typename T>
   static constexpr bool is_fp8_v =
       cute::is_same_v<T, float_e4m3_t> || cute::is_same_v<T, float_e5m2_t>;
+
   // Host side kernel arguments
   struct Arguments {
     ElementQ const* ptr_Q;
@@ -222,7 +206,6 @@ struct FlashPrefillMma<
       Arguments const& args,
       void* workspace) {
     (void)workspace;
-
     auto
         [batch,
          num_heads_q,
@@ -231,7 +214,6 @@ struct FlashPrefillMma<
          seq_len_kv,
          head_size_qk,
          head_size_vo] = problem_shape;
-
     auto tensorQ = make_tensor(
         make_gmem_ptr(args.ptr_Q),
         make_layout(
@@ -250,7 +232,6 @@ struct FlashPrefillMma<
     XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
     XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
     XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};
-
     return Params{copyQ, copyK, copyV};
   }
 
@@ -265,22 +246,16 @@ struct FlashPrefillMma<
     int thread_idx = static_cast<int>(ThreadIdxX());
     auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx);
     auto thr_copy_K = params.gmem_tiled_copy_k.get_slice(thread_idx);
-    // Instantiate the MMA object
     TiledMmaQK tiled_mma;
-    // To make all threads in a warp have the same global tensors pass in the
-    // index of thread 0 in each warp
     auto sg = compat::get_nd_item<1>().get_sub_group();
     auto first_thread_in_sg_idx =
         sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
     auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);
     auto thread_mma_k = tiled_mma.get_slice(0);
-
     // Partition
     Tensor tCgQ = thread_mma_q.partition_A(gQ);
     Tensor tCgK = thread_mma_k.partition_B(gK);
-
     // Create fragments
-    // TODO(Codeplay): fix this, this is probably not general
     using TCrQ_Type =
         cute::conditional_t<is_fp8_v<ElementQ>, uint8_t, ElementQ>;
     using TCrK_Type =
@@ -289,68 +264,18 @@ struct FlashPrefillMma<
         params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape())));
     Tensor tCrK = make_tensor<TCrK_Type>(make_fragment_layout(
         params.gmem_tiled_copy_k, take<0, 3>(tCgK.shape())));
-
     // Retile registers for copies
     Tensor tQrQ = thr_copy_Q.retile_D(tCrQ);
     Tensor tKrK = thr_copy_K.retile_D(tCrK);
-
     // Retile global tile for copies
     Tensor tQgQ = thr_copy_Q.retile_S(tCgQ);
     Tensor tKgK = thr_copy_K.retile_S(tCgK);
 
-#if CUTLASS_ENABLE_DEBUG_PRINTS
-#define PRINT(x) \
-  print(#x ": "); \
-  print(x); \
-  print("\n");
-    if (cute::thread(LOG_THREAD, LOG_GROUP)) {
-      print("======================= Q: \n");
-      PRINT(gQ);
-      PRINT(tCrQ);
-      PRINT(tCgQ);
-      PRINT(tQrQ);
-      PRINT(tQgQ);
-
-      print("===================== K :\n");
-      PRINT(gK);
-      PRINT(tCrK);
-      PRINT(tCgK);
-      PRINT(tKrK);
-      PRINT(tKgK);
-
-      print("===================== Config: \n");
-      PRINT(MaxThreadsPerBlock);
-      PRINT(SubgroupTileShapeQK{});
-    }
-#undef PRINT
-#endif
-
-    //
     // Mainloop
-    //
-
     for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
       copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ);
       copy(params.gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK);
-      if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
-        auto tCrQ_ = make_fragment_like<half_t>(tCrQ);
-        convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_);
-        auto tCrK_ = make_fragment_like<half_t>(tCrK);
-        convert_FP8_to_FP16<ElementK>(tCrK, tCrK_);
-        cute::gemm(tiled_mma, accum, tCrQ_, tCrK_, frag_src);
-
-      } else if constexpr (is_fp8_v<ElementQ> && !is_fp8_v<ElementK>) {
-        auto tCrQ_ = make_fragment_like<half_t>(tCrQ);
-        convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_);
-        cute::gemm(tiled_mma, accum, tCrQ_, tCrK, frag_src);
-
-      } else if constexpr (!is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
-        auto tCrK_ = make_fragment_like<half_t>(tCrK);
-        convert_FP8_to_FP16<ElementK>(tCrK, tCrK_);
-        cute::gemm(tiled_mma, accum, tCrQ, tCrK_, frag_src);
-      } else {
-        cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src);
-      }
+      cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src);
     }
   }
   template <
@@ -366,10 +291,7 @@ struct FlashPrefillMma<
       FragSrc const& frag_src,
       Params const& params) {
     int thread_idx = static_cast<int>(ThreadIdxX());
-    // Instantiate the MMA object
     TiledMmaPV tiled_mma;
-    // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid
-    // Register spill
     Tensor gV_ = take<0, 3>(
         local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _)));
     auto sg = compat::get_nd_item<1>().get_sub_group();
@@ -381,49 +303,20 @@ struct FlashPrefillMma<
         cute::conditional_t<is_fp8_v<ElementV>, uint8_t, ElementV>;
     Tensor tCrV = make_tensor<TCrV_Type>(make_fragment_layout(
         params.gmem_tiled_copy_v, take<0, 3>(tCgV.shape())));
-
     // Partition the copying of A and B tiles across the threads
     auto gmem_thr_copy_V = params.gmem_tiled_copy_v.get_slice(thread_idx);
     Tensor tVrV = gmem_thr_copy_V.retile_D(tCrV);
     Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV);
 
-#if CUTLASS_ENABLE_DEBUG_PRINTS
-#define PRINT(x) \
-  print(#x ": "); \
-  print(x); \
-  print("\n");
-    if (cute::thread(LOG_THREAD, LOG_GROUP)) {
-      print("===================== V :\n");
-      PRINT(gV);
-      PRINT(tCrV);
-      PRINT(tCgV);
-      PRINT(tVrV);
-      PRINT(tVgV);
-
-      print("===================== Config: \n");
-      PRINT(MaxThreadsPerBlock);
-      PRINT(SubgroupTileShapePV{});
-    }
-#undef PRINT
-#endif
-
-    // 7) Convert S to P (FP32 -> BF16)
+    // Convert S to P (FP32 -> BF16)
     Tensor tPr = convert_type<typename TiledMmaPV::ValTypeA>(tSr);
     //
     // Mainloop
     //
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < tile_count; i++) {
       copy(params.gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV);
-      if constexpr (is_fp8_v<ElementV>) {
-        auto tCrV_ = make_fragment_like<half_t>(tCrV);
-        convert_FP8_to_FP16<ElementV>(tCrV, tCrV_);
-        cute::gemm(
-            tiled_mma, accum(_, _, _, i), tPr, tCrV_, frag_src(_, _, _, i));
-      } else {
-        cute::gemm(
-            tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i));
-      }
+      cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i));
     }
   }
 
@@ -496,3 +389,5 @@ struct FlashPrefillMma<
 };
 
 } // namespace cutlass::flash_attention::collective
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
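The inline comments deleted in this file recorded the concrete tile arithmetic the header relies on: with TileShapeQK = (128, 64, 32) and a (16, 1, 1) SubgroupLayout, cute::shape_div gives a per-subgroup QK tile of (8, 64, 32); an (8, 16, 16) MMA atom then covers its (M, N) face in (1, 4) fragments with Vec = 8. Below is a minimal sketch of that arithmetic in plain constexpr C++ standing in for CuTe; the shape values come from the deleted comments, not from any build configuration.

#include <array>
#include <cstdio>

// Element-wise division, a stand-in for cute::shape_div on rank-3 shapes.
constexpr std::array<int, 3> shape_div(std::array<int, 3> a, std::array<int, 3> b) {
  return {a[0] / b[0], a[1] / b[1], a[2] / b[2]};
}

int main() {
  constexpr std::array<int, 3> tile_shape_qk{128, 64, 32}; // workgroup QK tile (assumed)
  constexpr std::array<int, 3> subgroup_layout{16, 1, 1};  // subgroups per dimension (assumed)
  constexpr auto sg_tile = shape_div(tile_shape_qk, subgroup_layout);
  static_assert(sg_tile[0] == 8 && sg_tile[1] == 64 && sg_tile[2] == 32,
                "SubgroupTileShapeQK = (8, 64, 32)");

  // FragsShapeS divides the subgroup tile's (M, N) by the MMA atom's (M, N) = (8, 16):
  constexpr int frags_m = sg_tile[0] / 8;  // 1 (FragsM)
  constexpr int frags_n = sg_tile[1] / 16; // 4 (FragsNS)
  // Vec = atom M * atom N / SubgroupSize = 8 * 16 / 16 = 8
  constexpr int vec = (8 * 16) / 16;
  std::printf("SG tile (%d, %d, %d), frags (%d, %d), Vec %d\n",
              sg_tile[0], sg_tile[1], sg_tile[2], frags_m, frags_n, vec);
}

Compiled as C++17, the static_assert documents the same (8, 64, 32) result the deleted comment recorded.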

src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_epilogue.h

Lines changed: 4 additions & 29 deletions
@@ -166,6 +166,7 @@ class FlashPrefillEpilogue<
     return true;
   }
 
+  // The main operator
   CUTLASS_HOST_DEVICE
   FlashPrefillEpilogue(Params const& params_, TensorStorage const&)
       : params(params_) {}
@@ -187,16 +188,13 @@ class FlashPrefillEpilogue<
       int const& q_head_coord,
       float softmax_scale) {
     using namespace cute;
-
     static constexpr bool is_var_len =
         cutlass::fmha::collective::is_variable_length_v<
             tuple_element_t<2, ProblemShape>>;
-
     using FragOutLayout = typename FragOut::layout_type;
     constexpr int Vec = shape<0>(FragOutLayout{});
     constexpr int FragsM = shape<1>(FragOutLayout{});
     constexpr int FragsN = size(select<2, 3>(shape(FragOutLayout{})));
-
     auto g = compat::get_nd_item<1>().get_sub_group();
     auto out_reg = make_tensor(
         static_cast<decltype(out)&&>(out).data(),
@@ -231,14 +229,9 @@ class FlashPrefillEpilogue<
     // Indexing variables
     auto [batch, num_heads_q, head_size_vo] = select<0, 1, 6>(problem_shape);
     auto [seq_len_qo] = select<0>(sequence_length_shape);
-    // Represent the full output tensor
-    // Tensor mO_mnl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo,
-    // (is_var_len ? batch : 1) * num_heads_q));
     Tensor mO_mnl =
         cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo, 1));
-
     auto [m_coord, n_coord, k_coord, l_coord] = tile_coord;
-    // Tile the output tensor per WG
     Tensor g_wg_O = local_tile(
         mO_mnl,
         select<0, 1>(TileShapeOutput{}),
@@ -247,21 +240,14 @@ class FlashPrefillEpilogue<
         get<2>(typename TiledMmaOutput::ThrLayoutVMNK{}.shape());
     auto m_sg = get_sub_group_id() / ATOM_N;
     auto n_sg = get_sub_group_id() % ATOM_N;
-    // Tile the output tensor per SG
     Tensor gO = local_tile(
         g_wg_O,
         SubgroupTileShape{},
         make_coord(m_sg, n_sg, _),
         Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
     auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
     Tensor tOgO = thread_xe_store_o.partition_D(gO);
-
     Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
-    // iff ElementOutput == ElementAccumulator, then convert_type doesn't do the
-    // right conversion iff ElementOutput == fp8, there is no NumericConverter
-    // specialization available for both the above cases, we call copy() which
-    // internally performs a static_cast op on the data. for ElementOutput ==
-    // bf16 | fp16, convert_type calls relevant NumericConverter specialization.
     if constexpr (
         cute::is_any_of_v<
             ElementOutput,
@@ -280,30 +266,17 @@ class FlashPrefillEpilogue<
     int lane_id = static_cast<int>(sg.get_local_linear_id());
     int sub_group_id = get_sub_group_id();
     const int BLK_M = size(select<0>(TileShapeOutput{}));
-
-    // write along the sequence.
-    // use the entire sub_group to write lse since all
-    // work items within subgroup have the same sum() data stored
-    // in registers
     auto blk_m_coord = get<0>(tile_coord); // seq_len_blk_idx
-
     size_t lse_offset =
         k_coord * num_heads_q * seq_len_qo + // shift the batch -- batch_idx *
                                              // num_heads_q * seq_len_qo -- OK
         q_head_coord *
             seq_len_qo + // shift the head -- head_q * seq_len_qo -- ok
         m_coord * BLK_M; // shift to the particular tile
-
     int localtile_seq_coord = 0;
-
-    // Calculate the sequence coordinate
-    // The coordinate value should be within [0.. seq_len_qo - 1]
     localtile_seq_coord = sub_group_id * SubgroupSize +
-        lane_id; // one subgroup will handle 16 (usually) sequence
-
-    // checked
+        lane_id; // one subgroup will handle 16 sequence
     int seq_coord = m_coord * BLK_M + localtile_seq_coord;
-
     // Check that if this is within the seq_len_qo
     if (seq_coord < seq_len_qo) {
       auto cur_sum = rowsum[lane_id];
@@ -356,3 +329,5 @@ class FlashPrefillEpilogue<
 } // namespace collective
 } // namespace flash_attention
 } // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
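The LSE indexing kept by this file linearizes writes as batch, then head, then tile: lse_offset = k_coord * num_heads_q * seq_len_qo + q_head_coord * seq_len_qo + m_coord * BLK_M, with each work item in a subgroup covering one sequence row and the store guarded by seq_coord < seq_len_qo. A standalone sketch of that indexing follows; num_heads_q, seq_len_qo, BLK_M, SubgroupSize, and the coordinate values are illustrative assumptions, not values taken from the kernel.

#include <cstddef>
#include <cstdio>

int main() {
  const int num_heads_q = 8, seq_len_qo = 1000;         // assumed problem sizes
  const int BLK_M = 128, SubgroupSize = 16;             // assumed tile/subgroup sizes
  const int k_coord = 1, q_head_coord = 3, m_coord = 7; // batch, head, tile coords

  // Same linearization as the epilogue: batch shift, head shift, tile shift.
  std::size_t lse_offset =
      static_cast<std::size_t>(k_coord) * num_heads_q * seq_len_qo +
      static_cast<std::size_t>(q_head_coord) * seq_len_qo +
      static_cast<std::size_t>(m_coord) * BLK_M;

  for (int sub_group_id = 0; sub_group_id < 2; ++sub_group_id) {
    for (int lane_id = 0; lane_id < 2; ++lane_id) {
      int localtile_seq_coord = sub_group_id * SubgroupSize + lane_id;
      int seq_coord = m_coord * BLK_M + localtile_seq_coord;
      if (seq_coord < seq_len_qo) // same bound check as the epilogue
        std::printf("lse[%zu + %d] <- row %d\n", lse_offset,
                    localtile_seq_coord, seq_coord);
    }
  }
}

The bound check mirrors the epilogue's guard, which keeps partial tiles at the end of the sequence from writing out of range.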
