From a7deb96bd6ddb3ab53eaebb1ef9adafb26207f70 Mon Sep 17 00:00:00 2001
From: "fengqing.lu"
Date: Tue, 11 Nov 2025 18:16:15 -0800
Subject: [PATCH 1/7] mha fwd/bwd kernel integration

---
 src/ATen/CMakeLists.txt                       |    4 +-
 .../transformers/xpu/flash_attn/flash_api.cpp |  109 ++
 .../transformers/xpu/flash_attn/flash_api.h   |   43 +
 .../xe_flash_attn_prefill_mma_bshd.h          |  498 ++++++
 .../xe_flash_attn_sdpa_fwd_bshd_epilogue.h    |  358 ++++
 ...lash_attn_sdpa_fwd_bshd_softmax_epilogue.h |  175 ++
 .../xpu/flash_attn/sycltla/flash_api.h        |   41 +
 .../kernel/tile_scheduler_sdpa_fwd_bshd.h     |  263 +++
 .../sycltla/kernel/xe_sdpa_fwd_bshd.h         |  761 ++++++++
 .../xpu/flash_attn/sycltla/mha_bwd.cpp        | 1541 +++++++++++++++++
 .../xpu/flash_attn/sycltla/mha_bwd.h          |  457 +++++
 .../xpu/flash_attn/sycltla/mha_common.h       |   50 +
 .../xpu/flash_attn/sycltla/mha_fwd.cpp        |  547 ++++++
 .../xpu/flash_attn/sycltla/mha_fwd.h          |   13 +
 .../transformers/xpu/flash_attn/utils.h       |  135 ++
 15 files changed, 4993 insertions(+), 2 deletions(-)
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/flash_api.cpp
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/flash_api.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_prefill_mma_bshd.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_epilogue.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_softmax_epilogue.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/flash_api.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/tile_scheduler_sdpa_fwd_bshd.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_common.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.h
 create mode 100644 src/ATen/native/transformers/xpu/flash_attn/utils.h

diff --git a/src/ATen/CMakeLists.txt b/src/ATen/CMakeLists.txt
index 9863027cd6..961a5065b1 100644
--- a/src/ATen/CMakeLists.txt
+++ b/src/ATen/CMakeLists.txt
@@ -1,9 +1,9 @@
 # ATen XPU sources
 file(GLOB xpu_cpp "xpu/*.cpp")
-file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" ${TORCH_ROOT}/aten/src/ATen/native/transformers/xpu/flash_attn/*.cpp)
+file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/transformers/xpu/flash_attn/*.cpp")
 file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/nested/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")
-file(GLOB xpu_sycltla "${TORCH_ROOT}/aten/src/ATen/native/transformers/xpu/flash_attn/sycltla/*.cpp")
+file(GLOB xpu_sycltla "native/transformers/xpu/flash_attn/sycltla/*.cpp")
 
 list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
 
 if(USE_ONEMKL_XPU)

diff --git a/src/ATen/native/transformers/xpu/flash_attn/flash_api.cpp b/src/ATen/native/transformers/xpu/flash_attn/flash_api.cpp
new file
mode 100644 index 0000000000..3facb81e50 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/flash_api.cpp @@ -0,0 +1,109 @@ +#include +#include + +namespace sycltla { + +bool is_flash_attention_available() { +#ifndef USE_SYCLTLA + return false; +#else + return true; +#endif +} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor> +flash_attention_forward( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const double dropout, + const bool is_causal, + const float scale) { +#ifndef USE_SYCLTLA + TORCH_CHECK( + false, + "flash_attention_forward: Torch XPU was not compiled with SYCLTLA support."); + return std::make_tuple( + at::Tensor(), + at::Tensor(), + at::Tensor(), + at::Tensor(), + c10::SymInt(0), + c10::SymInt(0), + at::Tensor(), + at::Tensor()); +#else + auto + [attention, + logsumexp, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + philox_seed, + philox_offset] = + flash_attention_forward_sycltla( + query, key, value, dropout, is_causal, scale); + return std::make_tuple( + std::move(attention), + std::move(logsumexp), + std::move(cumulative_sequence_length_q), + std::move(cumulative_sequence_length_k), + std::move(max_seqlen_batch_q), + std::move(max_seqlen_batch_k), + std::move(philox_seed), + std::move(philox_offset)); +#endif +} + +std::tuple flash_attention_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cumulative_sequence_length_q, + const at::Tensor& cumulative_sequence_length_k, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_k, + const double dropout, + const bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + const float scale) { +#ifndef USE_SYCLTLA + TORCH_CHECK( + false, + "flash_attention_backward: Torch XPU was not compiled with SYCLTLA support."); + return std::make_tuple(at::Tensor(), at::Tensor(), at::Tensor()); +#else + auto [grad_query, grad_key, grad_value] = flash_attention_backward_sycltla( + grad_out, + query, + key, + value, + out, + logsumexp, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + dropout, + is_causal, + philox_seed, + philox_offset, + scale); + return std::make_tuple( + std::move(grad_query), std::move(grad_key), std::move(grad_value)); +#endif +} +} // namespace sycltla diff --git a/src/ATen/native/transformers/xpu/flash_attn/flash_api.h b/src/ATen/native/transformers/xpu/flash_attn/flash_api.h new file mode 100644 index 0000000000..d97a975796 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/flash_api.h @@ -0,0 +1,43 @@ +#pragma once + +#include + +namespace sycltla { + +bool is_flash_attention_available(); + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor> +flash_attention_forward( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const double dropout, + const bool is_causal, + const float scale); + +std::tuple flash_attention_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cumulative_sequence_length_q, + const at::Tensor& 
cumulative_sequence_length_k, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_k, + const double dropout, + const bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + const float scale); + +} // namespace sycltla diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_prefill_mma_bshd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_prefill_mma_bshd.h new file mode 100644 index 0000000000..2099793293 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_prefill_mma_bshd.h @@ -0,0 +1,498 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fp8_to_fp16.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "flash_attention_v2/collective/fmha_fusion.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::flash_attention::collective { +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class ProblemShapeType_, + class ElementQ_, + class StrideQ_, + class ElementK_, + class StrideK_, + class ElementV_, + class StrideV_, + class MMAOperation_, + class TileShapeQK_, + class TileShapePV_, + class SubgroupLayout_, + class GmemTiledCopyQ_, + class GmemTiledCopyK_, + class GmemTiledCopyV_, + bool CausalMask_> +struct FlashPrefillMma { + static_assert( + cutlass::detail::dependent_false, + "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class ProblemShapeType_, + class ElementQ_, + class StrideQ_, + class ElementK_, + class StrideK_, + class ElementV_, + class StrideV_, + class MMAOperation_, + class TileShapeQK_, + class TileShapePV_, + class SubgroupLayout_, + class GmemTiledCopyQ_, + class GmemTiledCopyK_, + class GmemTiledCopyV_, + bool CausalMask_> +struct FlashPrefillMma< + gemm::MainloopIntelXeXMX16, + ProblemShapeType_, + ElementQ_, + StrideQ_, + ElementK_, + StrideK_, + ElementV_, + StrideV_, + MMAOperation_, + TileShapeQK_, + TileShapePV_, + SubgroupLayout_, + GmemTiledCopyQ_, + GmemTiledCopyK_, + GmemTiledCopyV_, + CausalMask_> { + // + // Type Aliases + // + using DispatchPolicy = gemm::MainloopIntelXeXMX16; + using TileShapeQK = TileShapeQK_; + using TileShapePV = TileShapePV_; + using SubgroupLayout = SubgroupLayout_; + using ProblemShapeType = ProblemShapeType_; + using ElementQ = ElementQ_; + using StrideQ = StrideQ_; + using ElementK = ElementK_; + using StrideK = StrideK_; + using ElementV = ElementV_; + using StrideV = StrideV_; + using GmemTiledCopyQ = GmemTiledCopyQ_; + using GmemTiledCopyK = GmemTiledCopyK_; + using GmemTiledCopyV = 
GmemTiledCopyV_; + using ArchTag = typename DispatchPolicy::ArchTag; + using MmaAtom = MMA_Atom; + using TiledMmaQK = + typename TiledMMAHelper, SubgroupLayout>:: + TiledMMA; + using TiledMmaPV = + typename TiledMMAHelper, SubgroupLayout>:: + TiledMMA; + using ElementAccumulator = typename TiledMmaQK::ValTypeC; + static constexpr bool CausalMask = CausalMask_; + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename MmaAtom::Shape_MNK; + + static constexpr auto PV_ATOM_M = + decltype(get<0>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_N = + decltype(get<1>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_K = + decltype(get<2>(SubgroupLayout{}.shape()))::value; + + using SubgroupTileShapePV = + decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); + + static constexpr auto QK_BLK_M = get<0>(TileShapeQK{}); + static constexpr auto QK_BLK_N = get<1>(TileShapeQK{}); + static constexpr auto QK_BLK_K = get<2>(TileShapeQK{}); + + // This TiledMma is only required to serve the specific tiling requirements + // for matrix K. This is due to the consumption of matrix K by all subgroups + // within a workgroup. + static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 + static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 + static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 + + using SubgroupTileShapeQK = decltype(cute::shape_div( + TileShapeQK{}, + SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) + + static constexpr auto QK_SG_M = get<0>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_N = get<1>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_K = get<2>(SubgroupTileShapeQK{}); + + static constexpr bool is_var_len = + cutlass::fmha::collective::is_variable_length_v< + tuple_element_t<3, ProblemShapeType>>; + + using FragsShapeS = decltype(cute::shape_div( + take<0, 2>(SubgroupTileShapeQK{}), + take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) + static constexpr int Vec = + (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 + static constexpr int FragsM = get<0>(FragsShapeS{}); + static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 + + static constexpr uint32_t MaxThreadsPerBlock = + size(SubgroupLayout{}) * SubgroupSize; + using CopyThreadShape = Shape<_1, Int>; + + using traits_load_Q = Copy_Traits; + using atom_load_Q = Copy_Atom; + using val_layout_load_Q = decltype(make_layout( + shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_Q = decltype(make_tiled_copy( + atom_load_Q{}, + Layout{}, + val_layout_load_Q{})); + + using traits_load_K = Copy_Traits; + using atom_load_K = Copy_Atom; + using val_layout_load_K = decltype(make_layout( + shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_K = decltype(make_tiled_copy( + atom_load_K{}, + Layout{}, + val_layout_load_K{})); + + using traits_load_V = Copy_Traits; + using atom_load_V = Copy_Atom; + using val_layout_load_V = decltype(make_layout( + shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_V = decltype(make_tiled_copy( + atom_load_V{}, + Layout{}, + val_layout_load_V{})); + template + static constexpr bool is_fp8_v = + cute::is_same_v || cute::is_same_v; + // Host side kernel arguments + struct Arguments { + ElementQ const* ptr_Q; + StrideQ dQ; + ElementK const* ptr_K; + StrideK dK; + ElementV const* ptr_V; + StrideV dV; + }; + + struct Params { + XE_Copy_Q gmem_tiled_copy_q; + XE_Copy_K 
gmem_tiled_copy_k; + XE_Copy_V gmem_tiled_copy_v; + }; + + // + // Methods + // + + FlashPrefillMma() = default; + + static constexpr Params to_underlying_arguments( + ProblemShapeType const& problem_shape, + Arguments const& args, + void* workspace) { + (void)workspace; + + auto + [batch, + num_heads_q, + num_heads_kv, + seq_len_qo, + seq_len_kv, + head_size_qk, + head_size_vo] = problem_shape; + + auto tensorQ = make_tensor( + make_gmem_ptr(args.ptr_Q), + make_layout( + make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), + args.dQ)); + auto tensorK = make_tensor( + make_gmem_ptr(args.ptr_K), + make_layout( + make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch), + args.dK)); + auto tensorV = make_tensor( + make_gmem_ptr(args.ptr_V), + make_layout( + make_shape(num_heads_kv * head_size_vo, seq_len_kv, batch), + args.dV)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; + XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; + + return Params{copyQ, copyK, copyV}; + } + + template + CUTLASS_DEVICE void mmaQK( + FragQccum& accum, + TensorQ gQ, + TensorK gK, + FragSrc const& frag_src, + int const& k_tile_count, + Params const& params) { + int thread_idx = static_cast(ThreadIdxX()); + auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); + auto thr_copy_K = params.gmem_tiled_copy_k.get_slice(thread_idx); + // Instantiate the MMA object + TiledMmaQK tiled_mma; + // To make all threads in a warp have the same global tensors pass in the + // index of thread 0 in each warp + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx); + auto thread_mma_k = tiled_mma.get_slice(0); + + // Partition + Tensor tCgQ = thread_mma_q.partition_A(gQ); + Tensor tCgK = thread_mma_k.partition_B(gK); + + // Create fragments + // TODO(Codeplay): fix this, this is probably not general + using TCrQ_Type = + cute::conditional_t, uint8_t, ElementQ>; + using TCrK_Type = + cute::conditional_t, uint8_t, ElementK>; + Tensor tCrQ = make_tensor(make_fragment_layout( + params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape()))); + Tensor tCrK = make_tensor(make_fragment_layout( + params.gmem_tiled_copy_k, take<0, 3>(tCgK.shape()))); + + // Retile registers for copies + Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); + Tensor tKrK = thr_copy_K.retile_D(tCrK); + + // Retile global tile for copies + Tensor tQgQ = thr_copy_Q.retile_S(tCgQ); + Tensor tKgK = thr_copy_K.retile_S(tCgK); + +#if CUTLASS_ENABLE_DEBUG_PRINTS +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= Q: \n"); + PRINT(gQ); + PRINT(tCrQ); + PRINT(tCgQ); + PRINT(tQrQ); + PRINT(tQgQ); + + print("===================== K :\n"); + PRINT(gK); + PRINT(tCrK); + PRINT(tCgK); + PRINT(tKrK); + PRINT(tKgK); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapeQK{}); + } +#undef PRINT +#endif + + // + // Mainloop + // + + for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { + copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ); + copy(params.gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); + if constexpr (is_fp8_v && is_fp8_v) { + auto tCrQ_ = make_fragment_like(tCrQ); + convert_FP8_to_FP16(tCrQ, tCrQ_); + auto tCrK_ = make_fragment_like(tCrK); + convert_FP8_to_FP16(tCrK, tCrK_); + cute::gemm(tiled_mma, accum, tCrQ_, tCrK_, frag_src); + 
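+      // Note: both operands are widened from FP8 to FP16 via
+      // convert_FP8_to_FP16 before the MMA, presumably because the Xe XMX
+      // MMA atom behind TiledMmaQK consumes 16-bit inputs; the mixed
+      // branches below widen only the FP8-typed operand.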
+ } else if constexpr (is_fp8_v && !is_fp8_v) { + auto tCrQ_ = make_fragment_like(tCrQ); + convert_FP8_to_FP16(tCrQ, tCrQ_); + cute::gemm(tiled_mma, accum, tCrQ_, tCrK, frag_src); + + } else if constexpr (!is_fp8_v && is_fp8_v) { + auto tCrK_ = make_fragment_like(tCrK); + convert_FP8_to_FP16(tCrK, tCrK_); + cute::gemm(tiled_mma, accum, tCrQ, tCrK_, frag_src); + } else { + cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); + } + } + } + template < + int tile_count, + class FragQccum, + class FragS, + class TensorV, + class FragSrc> + CUTLASS_DEVICE void mmaPV( + FragQccum& accum, + FragS const& tSr, + TensorV gV, + FragSrc const& frag_src, + Params const& params) { + int thread_idx = static_cast(ThreadIdxX()); + // Instantiate the MMA object + TiledMmaPV tiled_mma; + // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid + // Register spill + Tensor gV_ = take<0, 3>( + local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _))); + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + Tensor tCgV = thread_mma.partition_B(gV_); + using TCrV_Type = + cute::conditional_t, uint8_t, ElementV>; + Tensor tCrV = make_tensor(make_fragment_layout( + params.gmem_tiled_copy_v, take<0, 3>(tCgV.shape()))); + + // Partition the copying of A and B tiles across the threads + auto gmem_thr_copy_V = params.gmem_tiled_copy_v.get_slice(thread_idx); + Tensor tVrV = gmem_thr_copy_V.retile_D(tCrV); + Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV); + +#if CUTLASS_ENABLE_DEBUG_PRINTS +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(LOG_THREAD, LOG_GROUP)) { + print("===================== V :\n"); + PRINT(gV); + PRINT(tCrV); + PRINT(tCgV); + PRINT(tVrV); + PRINT(tVgV); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapePV{}); + } +#undef PRINT +#endif + + // 7) Convert S to P (FP32 -> BF16) + Tensor tPr = convert_type(tSr); + // + // Mainloop + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tile_count; i++) { + copy(params.gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); + if constexpr (is_fp8_v) { + auto tCrV_ = make_fragment_like(tCrV); + convert_FP8_to_FP16(tCrV, tCrV_); + cute::gemm( + tiled_mma, accum(_, _, _, i), tPr, tCrV_, frag_src(_, _, _, i)); + } else { + cute::gemm( + tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); + } + } + } + + // SequenceLengthShape = Shape + // For Fixed Sequence Length, ProblemShape = Shape For Variable Sequence Length, ProblemShape = Shape + template + CUTLASS_DEVICE static constexpr Params get_updated_copies( + Params const& params, + ProblemShape const& problem_shape, + SequenceLengthShape const& sequence_length_shape, + int const& l_coord, + int const& q_head_coord = 0) { + auto [num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = + select<1, 2, 5, 6>(problem_shape); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + auto q_group_size = num_heads_q / num_heads_kv; + auto kv_head_coord = q_head_coord / q_group_size; + int offset_q = 0, offset_k = 0, offset_v = 0; + + if constexpr (is_var_len) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + auto kv_cumulative_length = get<4>(problem_shape).cumulative_length; + // auto kv_cached_cumulative_length = + // get<5>(problem_shape).cumulative_length; + + offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + + 
q_head_coord * head_size_qk; + offset_k = num_heads_kv * head_size_qk * kv_cumulative_length[l_coord] + + kv_head_coord * head_size_qk; + offset_v = num_heads_kv * head_size_vo * kv_cumulative_length[l_coord] + + kv_head_coord * head_size_vo; + } else { + offset_q = num_heads_q * head_size_qk * seq_len_qo * l_coord + + q_head_coord * head_size_qk; + offset_k = num_heads_kv * head_size_qk * seq_len_kv * l_coord + + kv_head_coord * head_size_qk; + offset_v = num_heads_kv * head_size_vo * seq_len_kv * l_coord + + kv_head_coord * head_size_vo; + } + + auto q_traits = static_cast(params.gmem_tiled_copy_q); + const ElementQ* q_ptr = (const ElementQ*)q_traits.base_ptr; + auto k_traits = static_cast(params.gmem_tiled_copy_k); + const ElementK* k_ptr = (const ElementK*)k_traits.base_ptr; + auto v_traits = static_cast(params.gmem_tiled_copy_v); + const ElementV* v_ptr = (const ElementV*)v_traits.base_ptr; + auto shape_q = + make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); + StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); + auto shape_k = make_shape( + static_cast(seq_len_kv), num_heads_kv * head_size_qk, 1); + StrideK stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_k); + auto shape_v = make_shape( + head_size_vo * num_heads_kv, static_cast(seq_len_kv), 1); + StrideV stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_v); + + auto tensorQ = make_tensor( + make_gmem_ptr(q_ptr + offset_q), make_layout(shape_q, stride_q)); + auto tensorK = make_tensor( + make_gmem_ptr(k_ptr + offset_k), make_layout(shape_k, stride_k)); + auto tensorV = make_tensor( + make_gmem_ptr(v_ptr + offset_v), make_layout(shape_v, stride_v)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; + XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; + return Params{copyQ, copyK, copyV}; + } +}; + +} // namespace cutlass::flash_attention::collective diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_epilogue.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_epilogue.h new file mode 100644 index 0000000000..8f163142d7 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_epilogue.h @@ -0,0 +1,358 @@ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class MMAOperation_, + class TileShapeOutput_, + class SubgroupLayout_, + class... 
Args> +class FlashPrefillEpilogue { + static_assert( + cutlass::detail::dependent_false, + "Could not find an epilogue specialization."); +}; + +template < + class MMAOperation_, + class TileShapeOutput_, + class SubgroupLayout_, + class ElementCompute_, + class ElementO_, + class StrideO_, + class ElementLSE_, + class CopyOpO_> +class FlashPrefillEpilogue< + epilogue::IntelXeXMX16, + MMAOperation_, + TileShapeOutput_, + SubgroupLayout_, + ElementCompute_, + ElementO_, + StrideO_, + ElementLSE_, + CopyOpO_> { + public: + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using ElementO = ElementO_; + using StrideO = StrideO_; + using ElementLSE = ElementLSE_; + using CopyOpO = CopyOpO_; + using SubgroupLayout = SubgroupLayout_; + using TileShapeOutput = TileShapeOutput_; + using TiledMmaOutput = typename TiledMMAHelper< + MMA_Atom, + Layout, + SubgroupLayout>::TiledMMA; + using GmemTiledCopyO = CopyOpO; + using ElementOutput = ElementO_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementCompute_; + using SubgroupTileShape = + decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape()))); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert( + cute::rank(TileShapeOutput{}) == 3, + "TileShapeOutput must be rank-3: [CTA_M_QO, CTA_N_VO, CTA_K_PV]"); + static_assert( + cute::rank(StrideO{}) == 3, + "StrideO must be rank-3: [seq_len_qo, head_size_vo, batch * num_heads]"); + + using CopyThreadShape = Shape<_1, Int>; + + using traits_store_O = Copy_Traits; + using atom_load_O = Copy_Atom; + using val_layout_load_O = decltype(make_layout( + shape_div(typename traits_store_O::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_O = decltype(make_tiled_copy( + atom_load_O{}, + Layout{}, + val_layout_load_O{})); + + private: + constexpr static bool is_destination_supported = + not cute::is_void_v; + + public: + using EmptyType = cute::tuple<>; + + struct TensorStorageImpl : cute::tuple {}; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + ElementO const* ptr_O; + StrideO dO; + float* ptr_LSE; + }; + + // Device side epilogue params + struct Params { + XE_Copy_O xe_store_o; + float* ptr_LSE; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + auto + [batch, + num_heads_q, + num_heads_kv, + seq_len_qo, + seq_len_kv, + head_size_qk, + head_size_vo] = problem_shape; + auto tensorO = make_tensor( + make_gmem_ptr(static_cast(args.ptr_O)), + make_layout( + make_shape(seq_len_qo, num_heads_q * head_size_vo, batch), + args.dO)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return {xe_store_o, args.ptr_LSE}; + } + + template + static size_t get_workspace_size( + ProblemShape const& problem_shape, + Arguments const& args) { + return 0; + } + + template + static cutlass::Status initialize_workspace( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + FlashPrefillEpilogue(Params const& 
params_, TensorStorage const&) + : params(params_) {} + + template < + class ProblemShape, + class SequenceLengthShape, + class TileCoord, + class FragOut, + class FragMax, + class FragSum> + CUTLASS_DEVICE void operator()( + ProblemShape problem_shape, + SequenceLengthShape sequence_length_shape, + TileCoord tile_coord, + FragOut& out, + FragMax const& max, + FragSum& sum, + int const& q_head_coord, + float softmax_scale) { + using namespace cute; + + static constexpr bool is_var_len = + cutlass::fmha::collective::is_variable_length_v< + tuple_element_t<2, ProblemShape>>; + + using FragOutLayout = typename FragOut::layout_type; + constexpr int Vec = shape<0>(FragOutLayout{}); + constexpr int FragsM = shape<1>(FragOutLayout{}); + constexpr int FragsN = size(select<2, 3>(shape(FragOutLayout{}))); + + auto g = compat::get_nd_item<1>().get_sub_group(); + auto out_reg = make_tensor( + static_cast(out).data(), + Shape, Int, Int>{}); + float tLSE_reg = {-INFINITY}; + auto rowsum = make_fragment_like(sum); + + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y < FragsM; y++) { + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x < Vec; x++) { + int indx = y * Vec + x; + auto cur_sum = reduce_over_group(g, sum(indx), sycl::plus<>()); + auto cur_scale = (cur_sum == 0.f || cur_sum != cur_sum) + ? 1.0f + : sycl::native::recip(cur_sum); + rowsum(indx) = cur_sum; + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + // Handle -nan for bottom right masking. + // It will generate -nan when the whole sequence is masked. We need to + // manually assign 0 to the output + if (std::isnan(out_reg(x, y, z))) { + out_reg(x, y, z) = 0; + } else { + out_reg(x, y, z) *= cur_scale; + } + } + } + } + + // Indexing variables + auto [batch, num_heads_q, head_size_vo] = select<0, 1, 6>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + // Represent the full output tensor + // Tensor mO_mnl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo, + // (is_var_len ? batch : 1) * num_heads_q)); + Tensor mO_mnl = + cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo, 1)); + + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord; + // Tile the output tensor per WG + Tensor g_wg_O = local_tile( + mO_mnl, + select<0, 1>(TileShapeOutput{}), + make_coord(m_coord, n_coord, 0)); // (BLK_M,BLK_N,m,n,l) + static constexpr auto ATOM_N = + get<2>(typename TiledMmaOutput::ThrLayoutVMNK{}.shape()); + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + // Tile the output tensor per SG + Tensor gO = local_tile( + g_wg_O, + SubgroupTileShape{}, + make_coord(m_sg, n_sg, _), + Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX()); + Tensor tOgO = thread_xe_store_o.partition_D(gO); + + Tensor final_out_reg = make_fragment_like(out_reg); + // iff ElementOutput == ElementAccumulator, then convert_type doesn't do the + // right conversion iff ElementOutput == fp8, there is no NumericConverter + // specialization available for both the above cases, we call copy() which + // internally performs a static_cast op on the data. for ElementOutput == + // bf16 | fp16, convert_type calls relevant NumericConverter specialization. 
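+    // Summarized (a sketch of the dispatch below; the exact element types
+    // depend on the instantiation):
+    //   ElementOutput == fp8 (e5m2 / e4m3)  -> copy(): per-element static_cast
+    //   ElementOutput == ElementAccumulator -> copy(): no conversion needed
+    //   otherwise (bf16 / fp16)             -> convert_type(): NumericConverter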
+ if constexpr ( + cute::is_any_of_v< + ElementOutput, + cute::float_e5m2_t, + cute::float_e4m3_t> || + cute::is_same_v) { + copy(out_reg, final_out_reg); + } else { + Tensor temp = convert_type(out_reg); + copy(temp, final_out_reg); + } + copy(params.xe_store_o, final_out_reg, tOgO); + + // Generating the LSE for backward training + auto sg = compat::get_nd_item<1>().get_sub_group(); + int lane_id = static_cast(sg.get_local_linear_id()); + int sub_group_id = get_sub_group_id(); + const int BLK_M = size(select<0>(TileShapeOutput{})); + + // write along the sequence. + // use the entire sub_group to write lse since all + // work items within subgroup have the same sum() data stored + // in registers + auto blk_m_coord = get<0>(tile_coord); // seq_len_blk_idx + + size_t lse_offset = + k_coord * num_heads_q * seq_len_qo + // shift the batch -- batch_idx * + // num_heads_q * seq_len_qo -- OK + q_head_coord * + seq_len_qo + // shift the head -- head_q * seq_len_qo -- ok + m_coord * BLK_M; // shift to the particular tile + + int localtile_seq_coord = 0; + + // Calculate the sequence coordinate + // The coordinate value should be within [0.. seq_len_qo - 1] + localtile_seq_coord = sub_group_id * SubgroupSize + + lane_id; // one subgroup will handle 16 (usually) sequence + + // checked + int seq_coord = m_coord * BLK_M + localtile_seq_coord; + + // Check that if this is within the seq_len_qo + if (seq_coord < seq_len_qo) { + auto cur_sum = rowsum[lane_id]; + tLSE_reg = + cur_sum == 0.f ? -INFINITY : max * softmax_scale + logf(cur_sum); + *(params.ptr_LSE + lse_offset + localtile_seq_coord) = + std::isnan(tLSE_reg) ? 0 : tLSE_reg; + } + } + + // SequenceLengthShapeType = Shape + // For Fixed Sequence Length, ProblemShapeType = Shape For Variable Sequence Length, ProblemShapeType = Shape + template + CUTLASS_DEVICE static constexpr Params get_updated_copies( + Params const& params, + ProblemShapeType const& problem_shape, + SequenceLengthShapeType const& sequence_length_shape, + int const& l_coord, + int const& q_head_coord) { + auto [num_heads_q, head_size_vo] = select<1, 6>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + int offset_o = 0; + if constexpr (VarLen) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord] + + q_head_coord * head_size_vo; + } else { + offset_o = num_heads_q * head_size_vo * seq_len_qo * l_coord + + q_head_coord * head_size_vo; + } + auto store_traits = static_cast(params.xe_store_o); + ElementO* base_ptr = (ElementO*)store_traits.base_ptr; + auto shape_o = + make_shape(static_cast(seq_len_qo), num_heads_q * head_size_vo, 1); + StrideO stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_o); + auto tensorO = make_tensor( + make_gmem_ptr(base_ptr + offset_o), make_layout(shape_o, stride_o)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return Params{xe_store_o, params.ptr_LSE}; + } + + private: + Params const& params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_softmax_epilogue.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_softmax_epilogue.h new file mode 100644 index 0000000000..62b6844b3e --- /dev/null +++ 
b/src/ATen/native/transformers/xpu/flash_attn/sycltla/collective/xe_flash_attn_sdpa_fwd_bshd_softmax_epilogue.h @@ -0,0 +1,175 @@ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class FlashPrefillSoftmaxEpilogue { + static_assert( + cutlass::detail::dependent_false, + "Could not find an epilogue specialization."); +}; + +template +class FlashPrefillSoftmaxEpilogue< + CausalMask_, + epilogue::IntelXeXMX16, + Element_> { + public: + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using Element = Element_; + + static constexpr bool CausalMask = CausalMask_; + + using GmemTiledCopyOut = void; + + // Host side epilogue arguments + struct Arguments { + Element const scale; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + static constexpr Params to_underlying_arguments(Arguments const& args) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + Element val = args.scale * static_cast(kLog2e); + return Params{val}; + } + + template + static size_t get_workspace_size() { + return 0; + } + + template + static cutlass::Status initialize_workspace() { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement() { + return true; + } + + CUTLASS_HOST_DEVICE + FlashPrefillSoftmaxEpilogue(Params const& params_) : params(params_) {} + + template < + int Vec, + int FragsM, + int FragsN, + class FragAcc, + class FragMax, + class FragSum> + CUTLASS_DEVICE void scale_exp_log2( + FragAcc& frag_s, + FragMax const& max, + FragSum& sum) { + auto g = compat::get_nd_item<1>().get_sub_group(); + const auto max_scale = max * params.scale; + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + const auto max_scale_bcast = group_broadcast(g, max_scale, indx); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_indx = indx + (z * Vec * FragsM); + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + sum(indx) += frag_s(base_indx); + } + } + } + + template + CUTLASS_DEVICE void reduce_max(FragSrc& src, FragMax& max) { + auto g = compat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + auto maxptr = group_broadcast(g, max, indx); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_indx = indx + (z * Vec * FragsM); + maxptr = sycl::max(maxptr, src(base_indx)); + src(base_indx) *= params.scale; + } + maxptr = reduce_over_group(g, maxptr, sycl::maximum<>()); + if (indx == g.get_local_id()[0]) { + max = maxptr; + } + } + } + + template + CUTLASS_DEVICE void operator()( + bool is_first, + FragAcc& frag_s, + FragMax& max, + FragSum& sum, + FragOut& out) { + auto max_prev = max; + using FragAccLayout = typename FragAcc::layout_type; + using FragOutLayout = typename FragOut::layout_type; + constexpr int Vec = get<0>(FragAccLayout{}.shape()); + constexpr int FragsM = get<1>(FragAccLayout{}.shape()); + constexpr int FragsNAcc = 
get<2>(FragAccLayout{}.shape()); + constexpr int FragsNOut = size(select<2, 3>(FragOutLayout{}.shape())); + reduce_max(frag_s, max); + static_assert( + Vec * FragsM % 8 == 0, + " No. of attention rows per subgroup should be >= 1 MMA Atom " + "worth of rows."); + if (!is_first) { + auto g = compat::get_nd_item<1>().get_sub_group(); + Element max_scale{max * params.scale}; + Element exp_scale{ + sycl::native::exp2(max_prev * params.scale - max_scale)}; + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { // 16 rows in total + auto max_scale_bcast = group_broadcast(g, max_scale, indx); + auto exp_scale_bcast = group_broadcast(g, exp_scale, indx); + sum(indx) *= exp_scale_bcast; + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNAcc; z++) { + auto base_indx = indx + (z * Vec * FragsM); + frag_s(base_indx) = + sycl::native::exp2((frag_s(base_indx) - max_scale_bcast)); + sum(indx) += frag_s(base_indx); + } + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNOut; z++) { + auto base_indx = indx + (z * Vec * FragsM); // z * 16 rows + out(base_indx) *= exp_scale_bcast; + } + } + } else { + scale_exp_log2(frag_s, max, sum); + } + } + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/flash_api.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/flash_api.h new file mode 100644 index 0000000000..5756a8ce67 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/flash_api.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace sycltla { + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor> +flash_attention_forward_sycltla( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const double dropout, + const bool is_causal, + const float scale); + +std::tuple flash_attention_backward_sycltla( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cumulative_sequence_length_q, + const at::Tensor& cumulative_sequence_length_k, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_k, + const double dropout, + const bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + const float scale); + +} // namespace sycltla diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/tile_scheduler_sdpa_fwd_bshd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/tile_scheduler_sdpa_fwd_bshd.h new file mode 100644 index 0000000000..9708ee70e8 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/tile_scheduler_sdpa_fwd_bshd.h @@ -0,0 +1,263 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::flash_attention { + +namespace kernel { + +struct XeFlashIndividualTileScheduler { + struct Params { + dim3 grid; + // FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, + KernelHardwareInfo hw_info, + TileShape const& tile_shape) { + using namespace cute; 
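+    // Illustrative mapping (a sketch; the tile size depends on the
+    // instantiation): grid.x walks the query sequence in output-tile-M
+    // chunks, grid.y is num_heads_q and grid.z is batch. For example, with
+    // seq_len_qo = 1024, BLK_M = 128, num_heads_q = 16 and batch = 2, the
+    // grid below is (8, 16, 2).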
+ dim3 grid( + size(ceil_div( + shape<3>(problem_size), + shape<0>(tile_shape))), // seq_len_qo / 128 + size(shape<1>(problem_size)), // num_heads_q + size(shape<0>(problem_size))); // batch + return Params{grid}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +struct XeFlashDecodeIndividualTileScheduler { + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeFlashDecodeIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, + KernelHardwareInfo hw_info, + TileShape const& tile_shape) { + using namespace cute; + dim3 grid( + size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))), + size(ceil_div( + shape<3>(problem_size), + 8)), // we want to process only 8 tokens per workgroup + size(shape<0>(problem_size) * shape<1>(problem_size))); + return Params{grid, {shape<1>(problem_size)}}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = BlockIdxZ(); + int bidh; + params.divmod_num_heads(block_decode, bidh, block_decode); + return make_coord(BlockIdxX(), BlockIdxY(), block_decode, bidh); + } + + CUTLASS_DEVICE + XeFlashDecodeIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct XeFlashPersistentTileScheduler { + struct Params { + int num_blocks; + FastDivmod divmod_seq_len_block; + FastDivmod divmod_head_size_block; + FastDivmod divmod_num_heads; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler(Params const& params) + : block_idx(BlockIdxX()), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, + KernelHardwareInfo hw_info, + TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments " + "KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + + CUTLASS_TRACE_HOST( + "to_underlying_arguments(): Setting persistent grid SM count to " + << sm_count); + hw_info.sm_count = sm_count; + + int num_head_size_blocks = + size(ceil_div(shape<6>(problem_size), shape<1>(tile_shape))); + int num_seq_len_blocks = + size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))); + int num_blocks = num_seq_len_blocks * num_head_size_blocks * + size(shape<0>(problem_size) * shape<1>(problem_size)); + + return Params{ + num_blocks, + {num_seq_len_blocks}, + {num_head_size_blocks}, + {shape<1>(problem_size)}, + hw_info}; + } + + template + static dim3 get_grid_shape(Params const& params) { + auto queue = 
compat::get_default_queue(); + auto dev = queue.get_device(); + const size_t maxSubgroups = + dev.template get_info(); + // TODO (Codeplay): revert this back to std::min(params.num_blocks, + // params.hw_info.sm_count) once performance issue is fixed. + dim3 grid( + std::min( + params.num_blocks, + ceil_div(params.hw_info.sm_count * maxSubgroups, Num_SGs)), + 1, + 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int seq_len_block, head_size_block, bidh; + params.divmod_head_size_block(block_decode, head_size_block, block_decode); + params.divmod_seq_len_block(block_decode, seq_len_block, block_decode); + params.divmod_num_heads(block_decode, bidh, block_decode); + return make_coord(head_size_block, seq_len_block, block_decode, bidh); + } + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler& operator++() { + block_idx += GridDimX(); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace kernel + +struct IndividualScheduler {}; +struct PersistentScheduler {}; +struct FlashDecodeIndividualScheduler {}; + +namespace detail { + +template +struct TileSchedulerSelector { + static_assert( + cutlass::detail::dependent_false, + "Could not select a tile scheduler for given parameters."); +}; + +// Default (void) maps to XeFlashIndividualTileScheduler +template +struct TileSchedulerSelector< + void, + ArchTag, + cute::enable_if_t>> { + using Scheduler = + typename TileSchedulerSelector::Scheduler; +}; + +template +struct TileSchedulerSelector< + IndividualScheduler, + ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashIndividualTileScheduler; +}; + +template +struct TileSchedulerSelector< + PersistentScheduler, + ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashPersistentTileScheduler; +}; + +template +struct TileSchedulerSelector< + FlashDecodeIndividualScheduler, + ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashDecodeIndividualTileScheduler; +}; +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h new file mode 100644 index 0000000000..f2f14eaffe --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h @@ -0,0 +1,761 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "../collective/xe_flash_attn_prefill_mma_bshd.h" + +namespace cutlass::flash_attention::kernel { + +template < + class ProblemShape, + class CollectiveMainloop, + class CollectiveSoftmaxEpilogue_, + class CollectiveEpilogue, + class TileScheduler_ = void> +class FMHAPrefill; + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveSoftmaxEpilogue_, + class CollectiveEpilogue_, + class TileScheduler_> +class FMHAPrefill { + public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + static_assert( + rank(ProblemShape{}) == 7, + "ProblemShape{} should be "); + + // Mainloop derived types + using 
CollectiveMainloop = CollectiveMainloop_; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using TiledMmaQK = typename CollectiveMainloop::TiledMmaQK; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementQ = typename CollectiveMainloop::ElementQ; + using StrideQ = typename CollectiveMainloop::StrideQ; + using ElementK = typename CollectiveMainloop::ElementK; + using StrideK = typename CollectiveMainloop::StrideK; + using ElementV = typename CollectiveMainloop::ElementV; + using StrideV = typename CollectiveMainloop::StrideV; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_; + using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; + using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params; + + static_assert( + cute::is_void_v or + cute::is_same_v or + cute::is_same_v, + "Unsupported TileScheduler for Intel Xe."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail:: + TileSchedulerSelector::Scheduler; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementO = typename CollectiveEpilogue::ElementO; + using StrideO = typename CollectiveEpilogue::StrideO; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + using TileShapeOutput = typename CollectiveEpilogue::TileShapeOutput; + using TiledMmaOutput = typename CollectiveEpilogue::TiledMmaOutput; + + static_assert( + cute::is_same_v< + ElementAccumulator, + typename CollectiveEpilogue::ElementAccumulator>, + "Mainloop and epilogue do not agree on accumulator value type."); + // MSVC requires the cast to fix a warning-as-error. 
+ static constexpr int SharedStorageSize = 0; + + static constexpr bool CausalMask = CollectiveMainloop::CausalMask; + static constexpr int SubgroupSize = + CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = + CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 + + static constexpr int QK_BLK_M = CollectiveMainloop::QK_BLK_M; + static constexpr int QK_BLK_N = CollectiveMainloop::QK_BLK_N; + static constexpr int QK_BLK_K = CollectiveMainloop::QK_BLK_K; + + static constexpr int QK_ATOM_N = CollectiveMainloop::QK_ATOM_N; + static constexpr int QK_ATOM_K = CollectiveMainloop::QK_ATOM_K; + + static constexpr int QK_SG_M = CollectiveMainloop::QK_SG_M; + + static constexpr int Epilogue_BLK_N = get<1>(TileShapeOutput{}); + static constexpr int Epilogue_BLK_K = get<2>(TileShapeOutput{}); + + static constexpr int PV_ATOM_M = CollectiveMainloop::PV_ATOM_M; + static constexpr int PV_ATOM_N = CollectiveMainloop::PV_ATOM_N; + static constexpr int PV_ATOM_K = CollectiveMainloop::PV_ATOM_K; + + static constexpr auto Num_SGs = PV_ATOM_N * PV_ATOM_M * PV_ATOM_K; + static constexpr int Vec = CollectiveMainloop::Vec; + static constexpr int FragsM = CollectiveMainloop::FragsM; + // The FragsN here used for Creation of S matrix so we use the FragsN for S + // shape + static constexpr int FragsN = CollectiveMainloop::FragsNS; + + static constexpr int VSlicer = get<1>(TileShapeOutput{}) / + (get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); + using AccumeShape = decltype(make_shape( + Int{}, + Int{}, + get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), + Int{})); + + static constexpr bool is_var_len = CollectiveMainloop::is_var_len; + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + // Device side arguments + struct Arguments { + gemm::GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + SoftmaxArguments softmax{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + float softmax_scale; + }; + + // Kernel entry point API + struct Params { + gemm::GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + SoftmaxParams softmax; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + float softmax_scale; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the + // aliased type. 
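+  // In addition to copying the aliased fields, this fans out to the
+  // mainloop, softmax, epilogue and tile-scheduler collectives so that each
+  // builds its own device-side Params (e.g. the pre-configured tiled-copy
+  // objects).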
+ static Params to_underlying_arguments( + Arguments const& args, + void* workspace) { + (void)workspace; + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, workspace), + CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments( + args.problem_shape, args.hw_info, TileShapeOutput{}), + args.softmax_scale}; + } + + static bool can_implement(Arguments const& args) { + bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or + (args.mode == gemm::GemmUniversalMode::kBatched && + rank(ProblemShape{}) == 4); + return mode_implementable; + } + + static int get_workspace_size(Arguments const& args) { + return 0; + } + + static cutlass::Status initialize_workspace( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + Shape get_sequence_length_shape( + ProblemShape const& problem_shape, + int const& batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length( + select<3, 4>(problem_shape), batch); + } else { + return select<3, 4>(problem_shape); + } + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) { + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // Separate out problem shape for convenience + auto& batch = get<0>(params.problem_shape); + auto& num_heads_q = get<1>(params.problem_shape); + auto& num_head_kv = get<2>(params.problem_shape); + auto group_heads_q = num_heads_q / num_head_kv; + auto& head_size_qk = get<5>(params.problem_shape); + auto& head_size_vo = get<6>(params.problem_shape); + auto& softmax_scale = params.softmax_scale; + + // Preconditions + static_assert( + cute::rank(StrideQ{}) == 3, + "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " + "num_heads_q]."); + static_assert( + cute::rank(StrideK{}) == 3, + "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " + "num_heads_kv]."); + static_assert( + cute::rank(StrideV{}) == 3, + "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " + "num_heads_kv]."); + + int thread_idx = int(ThreadIdxX()); + int sub_group_id = thread_idx / SubgroupSize; + + TileScheduler tile_scheduler{params.scheduler}; + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = + tile_scheduler + .get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, + // batch_blk_idx, num_heads_blk_idx + + auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx + auto q_head_coord = get<1>(blk_coord); // q_heads_idx + auto batch_coord = get<2>(blk_coord); // batch_blk_idx + auto blk_n_coord = 0; // nums_head_blk_idx - not defined in TileScheduler + + // For variable sequence length case, batch is considered to be 1 (same as + // group gemm). For fixed sequence length case, the l_coord is the + // weighted sum of both batch_coord and num_heads_coord. Flash Attention + // implementation combines batch and num_heads to calculate the total + // batch_size. 
iff is_var_len: batch_size = num_heads (as each batch would + // have it's own seq_len_qo and seq_len_kv) iff !is_var_len: batch_size = + // batch * num_heads + // auto blk_l_coord = is_var_len ? num_heads_coord : batch_coord * + // num_heads_q + num_heads_coord; + + // Get problem shape for the current batch_blk_idx. For variable sequence + // length, it loads the sequence length from Global memory for the given + // batch_blk_idx and returns the appropriate problem_shape. For fixed + // sequence length, sequence_length_shape == select<3, + // 4>(params.problem_shape). sequence_length_shape = [seq_len_qo, + // seq_len_kv] + auto sequence_length_shape = + get_sequence_length_shape(params.problem_shape, batch_coord); + + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + + // This is for the bottom right masking, which happens when training with + // speculative decoding. In that case, the `is_causal` masking behavior + // will be changed and we need to adjust the main loop to perform + // appropriate calculations + if (seq_len_qo > seq_len_kv && CausalMask) { + int first_non_masked_sequence = seq_len_qo - seq_len_kv; + + int seq_coord = cute::min( + seq_len_qo, + (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % + seq_len_qo); + + // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) + // and check if it is still within bounds of the actual seq_len_qo + // (get<0>(sequence_length_shape)). + if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { + continue; + } + + // calculate the last seq_len_qo of this subblock + int last_seq_coord = seq_coord + QK_SG_M - 1; // 5 + + if (last_seq_coord < + first_non_masked_sequence) { // no need to perform calculation as + // those sections are masked + continue; + } + + // The main idea is to calculate the longest non-masked elements for + // this subgroup It is calculated by leveraging the property of bottom + // right mask + + // Calculate the longest length of the non-masked sequences for this + // subblock. The sequence is always the last one of subblock. + int longest_non_masked_length = cute::min( + seq_len_kv, + cute::max(0, last_seq_coord - first_non_masked_sequence + 1)); + + const int seq_len = cute::min(seq_len_kv, longest_non_masked_length); + + const int nblock_limit = cute::ceil_div(seq_len, QK_BLK_N); + + Tensor mQ_mkl = cute::get_xe_tensor( + make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + Tensor mK_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l) + Tensor mV_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l) + Tensor mQ_mk = mQ_mkl(_, _, 0); + Tensor mK_nk = mK_nkl(_, _, 0); // (n,k) + Tensor mV_nk = mV_nkl(_, _, 0); + + auto gQ = local_tile( + mQ_mk, + TileShapeQK{}, + make_coord(blk_m_coord, _, _), + Step<_1, X, _1>{}); + auto gK = local_tile( + mK_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gV = local_tile( + mV_nk, + TileShapeOutput{}, + make_coord(_, blk_n_coord, _), + Step{}); + + auto mainloop_params = CollectiveMainloop::get_updated_copies( + params.mainloop, + params.problem_shape, + sequence_length_shape, + batch_coord, + q_head_coord); + // we limit the horisontal size to two subgroup, the empirical resutls + // show that reading the two cacheline side by side in gives better + // performance and anything after that does not have an effect on + // performance. 
// (64 here for float b float when possible and loop + // over to cover all the data needed) + auto tiled_prefetch_q = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_q); + auto tiled_prefetch_k = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_k); + auto tiled_prefetch_v = cute::prefetch_selector< + Shape< + Int, + Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_v); + auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); + auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); + auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); + auto pQgQ = thr_prefetch_Q.partition_S(gQ); + auto pKgK = thr_prefetch_K.partition_S(gK); + auto pVgV = thr_prefetch_V.partition_S(gV); + + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + for (int j = 0; j < size<4>(pKgK); j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < DispatchPolicy::Stages; i++) { + prefetch(tiled_prefetch_k, pKgK(_, _, _, i, j)); + } + } + + // Allocate the tiled_mma and the accumulators for the (M,N) + // workgroup_shape + Tensor out_reg = make_tensor(AccumeShape{}); + + // There are 16 workitem and 16 max per subgroup, each worktime containt + // 1 max and cumulatively, they calculate the max per subgroup + ElementAccumulator max_reg{-INFINITY}; + + // The sum reg each contains a 2d tesnor for 8 x 2 This is number of + // sequence lenght process per subgroup + Tensor sum_reg = + make_tensor(Shape, Int>{}); + + clear(sum_reg); + clear(out_reg); + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + // when causal mask is true. It is not possible to set the scope + // of the barrier to workgroup level as the number n block is + // different for each subgroup due to triangular nature of causal based + // operation + static constexpr int barrier_scope = CausalMask ? 3 : 2; + // MAIN LOOP: loop over K and V, perform fused attention + online + // softmax + for (int nblock = 0; + nblock < nblock_limit - static_cast(CausalMask); + nblock++) { + barrier_arrive(barrier_scope); + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK( + tSr, + gQ, + gK(_, _, nblock, _), + tSr, + ceil_div(head_size_qk, QK_BLK_K), + mainloop_params); + + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. 
because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock)); + } + + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax(nblock == 0, tSr, max_reg, sum_reg, out_reg); + + collective_mma.template mmaPV( + out_reg, tSr, gV(_, _, nblock), out_reg, mainloop_params); + + // Prefetch the next K tile + // there is no need to gaurd it with if statememt as prefetch will + // ignore out of bound reading + for (int j = 0; j < size<4>(pKgK); j++) { + prefetch( + tiled_prefetch_k, + pKgK(_, _, _, nblock + DispatchPolicy::Stages, j)); + } + barrier_wait(barrier_scope); + } + + if constexpr (CausalMask) { + // BAND Matrix + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK( + tSr, + gQ, + gK(_, _, nblock_limit - 1, _), + tSr, + ceil_div(head_size_qk, QK_BLK_K), + mainloop_params); + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_limit - 1)); + } + // mask the elements of each tile where j > i + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++, row_idx++) { // 8 + if (col_idx > row_idx - first_non_masked_sequence || + row_idx < first_non_masked_sequence) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax((nblock_limit - 1) == 0, tSr, max_reg, sum_reg, out_reg); + + collective_mma.template mmaPV( + out_reg, + tSr, + gV(_, _, nblock_limit - 1), + out_reg, + mainloop_params); + } + + auto epilogue_params = + CollectiveEpilogue::template get_updated_copies( + params.epilogue, + params.problem_shape, + sequence_length_shape, + batch_coord, + q_head_coord); + CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; + auto blk_coord_mnkl = + make_coord(blk_m_coord, blk_n_coord, batch_coord, 0); + epilogue( + params.problem_shape, + sequence_length_shape, + blk_coord_mnkl, + out_reg, + max_reg, + sum_reg, + q_head_coord, + softmax_scale); + } + // seq_len_kv == seq_len_qo + else { + const int seq_coord = cute::min( + seq_len_qo, + (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % + seq_len_qo); + + // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) + // and check if it is still within bounds of the actual seq_len_qo + // (get<0>(sequence_length_shape)). + if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { + continue; + } + + auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024) + auto discard_seq_coord = seq_len_qo - offset; // 1024 + auto full_tile_offset = seq_len_kv - offset; // 0 + + const int seq_len = CausalMask ? 
full_tile_offset + + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + QK_SG_M + : seq_len_kv; + const int nblock_limit = cute::ceil_div(seq_len, QK_BLK_N); + if (CausalMask && seq_coord < discard_seq_coord) { // 1024 =0 + continue; + } + + Tensor mQ_mkl = cute::get_xe_tensor( + make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + Tensor mK_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l) + Tensor mV_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l) + Tensor mQ_mk = mQ_mkl(_, _, 0); + Tensor mK_nk = mK_nkl(_, _, 0); // (n,k) + Tensor mV_nk = mV_nkl(_, _, 0); + + auto gQ = local_tile( + mQ_mk, + TileShapeQK{}, + make_coord(blk_m_coord, _, _), + Step<_1, X, _1>{}); + auto gK = local_tile( + mK_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gV = local_tile( + mV_nk, + TileShapeOutput{}, + make_coord(_, blk_n_coord, _), + Step{}); + + auto mainloop_params = CollectiveMainloop::get_updated_copies( + params.mainloop, + params.problem_shape, + sequence_length_shape, + batch_coord, + q_head_coord); + // we limit the horisontal size to two subgroup, the empirical resutls + // show that reading the two cacheline side by side in gives better + // performance and anything after that does not have an effect on + // performance. // (64 here for float b float when possible and loop + // over to cover all the data needed) + auto tiled_prefetch_q = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_q); + auto tiled_prefetch_k = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_k); + auto tiled_prefetch_v = cute::prefetch_selector< + Shape< + Int, + Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_v); + auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); + auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); + auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); + auto pQgQ = thr_prefetch_Q.partition_S(gQ); + auto pKgK = thr_prefetch_K.partition_S(gK); + auto pVgV = thr_prefetch_V.partition_S(gV); + + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + for (int j = 0; j < size<4>(pKgK); j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < DispatchPolicy::Stages; i++) { + prefetch(tiled_prefetch_k, pKgK(_, _, _, i, j)); + } + } + + // Allocate the tiled_mma and the accumulators for the (M,N) + // workgroup_shape + Tensor out_reg = make_tensor(AccumeShape{}); + + // There are 16 workitem and 16 max per subgroup, each worktime containt + // 1 max and cumulatively, they calculate the max per subgroup + ElementAccumulator max_reg{-INFINITY}; + + // The sum reg each contains a 2d tesnor for 8 x 2 This is number of + // sequence lenght process per subgroup + Tensor sum_reg = + make_tensor(Shape, Int>{}); + + clear(sum_reg); + clear(out_reg); + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + // when causal mask is true. It is not possible to set the scope + // of the barrier to workgroup level as the number n block is + // different for each subgroup due to triangular nature of causal based + // operation + static constexpr int barrier_scope = CausalMask ? 
3 : 2; + // MAIN LOOP: loop over K and V, perform fused attention + online + // softmax + for (int nblock = 0; + nblock < nblock_limit - static_cast(CausalMask); + nblock++) { + barrier_arrive(barrier_scope); + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK( + tSr, + gQ, + gK(_, _, nblock, _), + tSr, + ceil_div(head_size_qk, QK_BLK_K), + mainloop_params); + + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock)); + } + + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax(nblock == 0, tSr, max_reg, sum_reg, out_reg); + + collective_mma.template mmaPV( + out_reg, tSr, gV(_, _, nblock), out_reg, mainloop_params); + + // Prefetch the next K tile + // there is no need to gaurd it with if statememt as prefetch will + // ignore out of bound reading + for (int j = 0; j < size<4>(pKgK); j++) { + prefetch( + tiled_prefetch_k, + pKgK(_, _, _, nblock + DispatchPolicy::Stages, j)); + } + barrier_wait(barrier_scope); + } + + if constexpr (CausalMask) { + // BAND Matrix + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK( + tSr, + gQ, + gK(_, _, nblock_limit - 1, _), + tSr, + ceil_div(head_size_qk, QK_BLK_K), + mainloop_params); + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_limit - 1)); + } + // mask the elements of each tile where j > i + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++, row_idx++) { // 8 + if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax((nblock_limit - 1) == 0, tSr, max_reg, sum_reg, out_reg); + + collective_mma.template mmaPV( + out_reg, + tSr, + gV(_, _, nblock_limit - 1), + out_reg, + mainloop_params); + } + + auto epilogue_params = + CollectiveEpilogue::template get_updated_copies( + params.epilogue, + params.problem_shape, + sequence_length_shape, + batch_coord, + q_head_coord); + CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; + auto blk_coord_mnkl = + make_coord(blk_m_coord, blk_n_coord, batch_coord, 0); + epilogue( + params.problem_shape, + sequence_length_shape, + blk_coord_mnkl, + // out_reg, max_reg, sum_reg, q_head_coord, 0.125); + out_reg, + max_reg, + sum_reg, + q_head_coord, + softmax_scale); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention::kernel diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp 
b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp new file mode 100644 index 0000000000..d25117d846 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp @@ -0,0 +1,1541 @@ +#include +#include + +// batch, numhead_qo,numhead_kv,seqlen_qo,seqlen_kv,headsize_qk,headsize_vo +using ProblemShapeRegular = cute::tuple; + +namespace cute { + +template +auto convert_layout_2d_layout(Layout layout) { + auto l = + make_layout(make_layout(get<0>(layout), get<1>(layout)), get<2>(layout)); + return l; +} + +template +void compute_o_dot_do( + T& trait, + Param& param, + const int m_block, + const int bidb, + const int bidh) { + // The thread index. + constexpr int kBlockM = T::kBlockM; + constexpr int kBlockN = T::kBlockN; + constexpr int kHeadDim = T::kHeadDim; + constexpr int kNSGs = T::kNSGs; + constexpr int SubgroupSize = T::SubgroupSize; + using DType = typename T::DType; + using VType = typename T::VType; + + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto group = compat::get_nd_item<1>().get_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + auto bofst = Boffset(param); + + const index_t o_offset = bofst.o_offset(bidb, bidh, m_block * kBlockM); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t dpsum_offset = bofst.lse_offset(bidb, bidh, m_block * kBlockM); + + using ShapeO = + Shape, int>, Int>; + using ShapeP = Shape, int>>; + ShapeO O_shape; + ShapeP dP_shape; + if constexpr (Is_even_M) { + O_shape = make_shape(Int{}, Int{}); + dP_shape = make_shape(Int{}); + } else { + O_shape = make_shape(param.tail_m, Int{}); + dP_shape = make_shape(param.tail_m); + } + Shape dQ_shape = make_shape(Int{}, Int{}); + + Tensor mdO = make_tensor( + make_gmem_ptr(param.do_ptr + o_offset), + make_layout(O_shape, make_stride(param.o_r_stride, _1{}))); + Tensor mO = make_tensor( + make_gmem_ptr(param.o_ptr + o_offset), + make_layout(O_shape, make_stride(param.o_r_stride, _1{}))); + Tensor mdQaccum = make_tensor( + make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout( + make_shape(Int{}, Int{}), + make_stride(param.dq_r_stride, _1{}))); + Tensor mdPsum = make_tensor( + make_gmem_ptr(param.odo_ptr + dpsum_offset), + make_layout(dP_shape, Stride<_1>{})); + + auto tileload_odo = make_tiled_copy( + Copy_Atom, DType>{}, + Layout< + Shape, Int>, + Stride, _1>>{}, + Layout>{}); + auto tileload_dq = make_tiled_copy( + Copy_Atom, VType>{}, + Layout, Int>>{}, + Layout>{}); + auto thr_load_odo = tileload_odo.get_thread_slice(ThreadIdxX()); + auto thr_load_dq = tileload_dq.get_thread_slice(ThreadIdxX()); + + Tensor thr_tile_do_S = thr_load_odo.partition_S(mdO); + Tensor thr_tile_o_S = thr_load_odo.partition_S(mO); + Tensor thr_tile_dq_D = thr_load_dq.partition_D(mdQaccum); + Tensor rdQ = make_fragment_like(thr_tile_dq_D); + Tensor rdO = make_fragment_like(rdQ); + Tensor rO = make_fragment_like(rdQ); + clear(rdQ); + copy(tileload_dq, rdQ, thr_tile_dq_D); + + Tensor cO = make_identity_tensor(dQ_shape); + Tensor tcO = thr_load_odo.partition_S(cO); + Tensor tcO_row = logical_divide(tcO, Shape<_1>{})(make_coord(0, 0), _, 0); + Tensor rdO_2d = + make_tensor(rdO.data(), convert_layout_2d_layout(rdO.layout())); + Tensor rO_2d = make_tensor(rO.data(), convert_layout_2d_layout(rO.layout())); + if constexpr (Is_even_M) { + copy(tileload_odo, thr_tile_do_S, rdO); + copy(tileload_odo, thr_tile_o_S, rO); + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(rdO_2d); ++mi) { + float accum = 0.0f; + 
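For reference, this is the quantity the surrounding tiled kernel produces per (batch, head): the per-row dot product D[m] = sum_d dO[m,d] * O[m,d] later consumed by the softmax backward as dP_sum. A host-style sketch under simplifying assumptions (plain float pointers and a row-major [seq_len_q, head_dim] layout, ignoring the strided/subgroup machinery used here):

    // Reference semantics of the dO.O precompute ('mdPsum' in the kernel).
    // Accumulation is in float regardless of the fp16/bf16 element type.
    void o_dot_do_reference(const float* dO, const float* O, float* D,
                            int seq_len_q, int head_dim) {
      for (int m = 0; m < seq_len_q; ++m) {
        float acc = 0.f;
        for (int d = 0; d < head_dim; ++d) {
          acc += dO[m * head_dim + d] * O[m * head_dim + d];
        }
        D[m] = acc;  // one scalar per query row
      }
    }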
CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(rdO_2d); ++ni) { + accum = accum + (float)rdO_2d(mi, ni) * (float)rO_2d(mi, ni); + } + accum = sycl::reduce_over_group(sg, accum, sycl::plus<>()); + if (sg.get_local_id() == 0) { + mdPsum(get<0>(tcO_row(mi))) = accum; + } + } + } else { + for (int mi = 0; mi < size<0>(rdO_2d); ++mi) { + if (get<0>(tcO_row(mi)) < param.tail_m) { + copy(tileload_odo, thr_tile_do_S(_, mi, _), rdO(_, mi, _)); + copy(tileload_odo, thr_tile_o_S(_, mi, _), rO(_, mi, _)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(rdO_2d); ++mi) { + float accum = 0.0f; + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(rdO_2d); ++ni) { + accum = accum + (float)rdO_2d(mi, ni) * (float)rO_2d(mi, ni); + } + accum = sycl::reduce_over_group(sg, accum, sycl::plus<>()); + if (sg.get_local_id() == 0 and get<0>(tcO_row(mi)) < param.tail_m) + mdPsum(get<0>(tcO_row(mi))) = accum; + } + } +} + +template +void mha_dot_do_o(T trait, Param param) { + // The block index for the M dimension. + const int m_block = BlockIdxX(); + // The block index for the batch. + const int bidb = BlockIdxZ(); + // The block index for the head. + const int bidh = BlockIdxY(); + ; + if (m_block == param.m_block - 1 and param.tail_m > 0) { + compute_o_dot_do(trait, param, m_block, bidb, bidh); + } else { + compute_o_dot_do(trait, param, m_block, bidb, bidh); + } +} + +template < + typename Tensor0, + typename Tensor1, + typename Tensor2, + typename Tensor3, + typename Tensor4, + typename Tensor5, + typename Tensor6, + typename Tensor7, + typename Tensor8, + typename TiledMma, + typename TileMNK, + typename TiledCopyA, + typename TiledCopyB> +CUTLASS_DEVICE void gemm_ker( + Tensor0& tCrCmn, + Tensor1& tCrA, + Tensor2& tCrB, + Tensor3& tAgAmk, + Tensor4& tArA, + Tensor5& gA, + Tensor6& tBgBnk, + Tensor7& tBrB, + Tensor8& gB, + TiledMma& tiled_mma, + TileMNK& tile_mnk, + TiledCopyA& copy_a, + TiledCopyB& copy_b) { + constexpr int barrier_scope = 2; + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<3>(tAgAmk); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<3>(tBgBnk); ++n) { + auto tCrC = tCrCmn(_, _, _, m, n); + auto tAgA = tAgAmk(_, _, _, m, _); + auto tBgB = tBgBnk(_, _, _, n, _); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<3>(tAgA); ++k) { + barrier_arrive(barrier_scope); + cute::copy(copy_a, tAgA(_, _, _, k), tArA); + cute::copy(copy_b, tBgB(_, _, _, k), tBrB); + cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + barrier_wait(barrier_scope); + } + } + } +} + +template < + typename Engine0, + typename Layout0, + typename Engine1, + typename Layout1> +CUTLASS_DEVICE void apply_mask_causal( + Tensor& tensor, + Tensor& rC, + int m_offset, + int n_offset) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto group = compat::get_nd_item<1>().get_group(); + int sg_local_id = sg.get_local_id(); + int sg_group_id = sg.get_group_id(); + Tensor rC_2d = make_tensor(rC.data(), convert_layout_2d_layout(rC.layout())); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<1>(tensor); ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tensor); ++m) { + int x = n_offset + get<1>(rC_2d(m, n)) + sg_local_id; + int y = m_offset + get<0>(rC_2d(m, n)); + if (x > y) { + tensor(m, n) = -INFINITY; + } + } + } + return; +} + +template < + bool Is_even_MN, + class TileCopy, + class Engine0, + class Layout0, + class Engine1, + class Layout1> +CUTLASS_DEVICE void mha_save( + TileCopy& tile_copy, + Tensor& src, + Tensor& dst) { + static_assert(Layout0::rank == 5, "Only support Tensor with 5 
ranks"); + static_assert( + Layout0::rank == Layout1::rank, "Only support same rank Tensor"); + if constexpr (Is_even_MN) { + copy(tile_copy, src, dst); + } else { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<3>(dst); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<4>(dst); ++n) { + auto src_block = src(_, _, _, m, n); + auto dst_block = dst(_, _, _, m, n); + copy(tile_copy, src_block, dst_block); + } + } + } +} + +template < + bool Is_even_MN, + class TileCopy, + class Engine0, + class Layout0, + class Engine1, + class Layout1> +CUTLASS_DEVICE void mha_load( + TileCopy& tile_copy, + Tensor& src, + Tensor& dst) { + static_assert(Layout0::rank == 5, "Only support Tensor with 5 ranks"); + static_assert( + Layout0::rank == Layout1::rank, "Only support same rank Tensor"); + if constexpr (Is_even_MN) { + copy(tile_copy, src, dst); + } else { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<3>(src); ++m) { + auto src_block = src(_, _, _, m, _); + auto dst_block = dst(_, _, _, m, _); + copy(tile_copy, src_block, dst_block); + } + } +} + +template +CUTLASS_DEVICE void load_1colvec( + Tensor0& reg, + Tensor1& mT, + Tensor2& coord_row, + int tail_m = 0) { + if constexpr (Is_even_M) { + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size(reg); ++mi) { + reg(mi) = mT(get<0>(coord_row(mi))); + } + } else { + for (int mi = 0; mi < size(reg); ++mi) { + int row = get<0>(coord_row(mi)); + if (row < tail_m) { + reg(mi) = mT(row); + } + } + } +} + +template +CUTLASS_DEVICE auto convert_layout_acc_layout(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 8); + static_assert(decltype(rank(acc_layout))::value == 5); + auto l = logical_divide(acc_layout, Shape<_1>{}); // ((2, 2), MMA_M, MMA_N) + auto l2 = make_layout( + make_layout(get<0, 1>(l), get<1>(l), get<3>(l)), + make_layout(get<0, 0>(l), get<4>(l))); + return l2; +} + +template +CUTLASS_DEVICE void scale_apply_exp2( + Tensor& tensor, + Tensor& max, + const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const float max_scaled = max(mi) == -INFINITY ? 
0.f : max(mi) * M_LOG2E; + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(tensor); ++ni) { + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +template +CUTLASS_DEVICE void softmax_backward( + Tensor0& P, + Tensor1& dP_sum, + Tensor2& dP, + const float scale) { + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(dP); ++mi) { + CUTLASS_PRAGMA_UNROLL + for (int mj = 0; mj < size<1>(dP); ++mj) { + dP(mi, mj) = P(mi, mj) * (dP(mi, mj) - dP_sum(mi)) * scale; + } + } +} + +template +CUTLASS_DEVICE auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast*>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +template +void dq_dk_dv_1colblock( + Trait& trait, + Param& param, + const int bidb, + const int bidh, + const int bidhkv, + const int n_block, + const int tail_n = 0) { + using T = typename Trait::DType; + using V = typename Trait::VType; + constexpr int kHeadDim = Trait::kHeadDim; + constexpr int kBlockM = Trait::kBlockM; + constexpr int kBlockN = Trait::kBlockN; + constexpr bool is_causal = Trait::is_causal; + constexpr int kNSGs = Trait::kNSGs; + constexpr int SubgroupSize = Trait::SubgroupSize; + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto group = compat::get_nd_item<1>().get_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + auto bofst = Boffset(param); + + const index_t q_offset = bofst.q_offset(bidb, bidh, 0); + const index_t k_offset = bofst.k_offset(bidb, bidhkv, n_block * kBlockN); + const index_t v_offset = bofst.v_offset(bidb, bidhkv, n_block * kBlockN); + const index_t dk_offset = bofst.dk_offset(bidb, bidh, n_block * kBlockN); + const index_t dv_offset = bofst.dv_offset(bidb, bidh, n_block * kBlockN); + const index_t o_offset = bofst.o_offset(bidb, bidh, 0); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, 0); + const index_t lse_offset = bofst.lse_offset(bidb, bidh, 0); + // buff offset + const index_t pb_offset = + bidb * param.num_head_q * param.seq_len_kv_pad * kBlockM + + bidh * param.seq_len_kv_pad * kBlockM + n_block * kBlockN * kBlockM; + + const index_t s_offset = bofst.ps_offset(bidb, bidh, 0, n_block * kBlockN); + + const auto block_n_dim = tail_n == 0 ? 
Int{} : tail_n; + using Shape1 = Shape< + std::conditional_t, int>, + Int, + Int<1>>; + using Shape2 = Shape< + Int, + std::conditional_t, int>, + Int<1>>; + Shape shapeQ = make_shape(kBlockM, Int{}, _1{}); + Shape shapedQ = Shape, Int, _1>{}; + Shape1 shapeKtV; + Shape2 shapeK; + if constexpr (Is_even_N) { + shapeKtV = make_shape(Int{}, Int{}, _1{}); + shapeK = make_shape(Int{}, Int{}, _1{}); + } else { + shapeKtV = make_shape(tail_n, Int{}, _1{}); + shapeK = make_shape(Int{}, tail_n, _1{}); + } + Shape shapeO = make_shape(kBlockM, Int{}, _1{}); + Shape shapeQtOt = make_shape(Int{}, kBlockM, _1{}); + + Shape shapeSP = make_shape(kBlockM, block_n_dim, _1{}); + + Shape shapePt = make_shape(block_n_dim, kBlockM, _1{}); + + Tensor mQ = make_tensor( + make_gmem_ptr(param.q_ptr + q_offset), + make_layout(shapeQ, make_stride(param.q_r_stride, _1{}, _1{}))); + Tensor mKt = make_tensor( + make_gmem_ptr(param.k_ptr + k_offset), + make_layout(shapeKtV, make_stride(param.k_r_stride, _1{}, _1{}))); + Tensor mV = make_tensor( + make_gmem_ptr(param.v_ptr + v_offset), + make_layout(shapeKtV, make_stride(param.v_r_stride, _1{}, _1{}))); + Tensor mdO = make_tensor( + make_gmem_ptr(param.do_ptr + o_offset), + make_layout(shapeO, make_stride(param.o_r_stride, _1{}, _1{}))); + // intermediate buffer + Tensor mP = make_tensor( + make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout(shapeSP, make_stride(block_n_dim, _1{}, _1{}))); + Tensor mPt = make_tensor( + make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout(shapePt, make_stride(_1{}, block_n_dim, _1{}))); + Tensor mdOt = make_tensor( + make_gmem_ptr(param.do_ptr + o_offset), + make_layout(shapeQtOt, make_stride(_1{}, param.o_r_stride, _1{}))); + Tensor mK = make_tensor( + make_gmem_ptr(param.k_ptr + k_offset), + make_layout(shapeK, make_stride(_1{}, param.k_r_stride, _1{}))); + Tensor mdPt = make_tensor( + make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout(shapePt, make_stride(_1{}, block_n_dim, _1{}))); + Tensor mQt = make_tensor( + make_gmem_ptr(param.q_ptr + q_offset), + make_layout(shapeQtOt, make_stride(_1{}, param.q_r_stride, _1{}))); + + Tensor mLSE = make_tensor( + make_gmem_ptr(param.lse_ptr + lse_offset), + make_layout(Shape>{}, Stride<_1>{})); + Tensor mdPsum = make_tensor( + make_gmem_ptr(param.odo_ptr + lse_offset), + make_layout(Shape>{}, Stride<_1>{})); + + Tensor mdV = make_tensor( + make_gmem_ptr(param.dv_ptr + dv_offset), + make_layout(shapeKtV, make_stride(param.dv_r_stride, _1{}, _1{}))); + Tensor mdP = make_tensor( + make_gmem_ptr(param.pb_ptr + pb_offset), + make_layout(shapeSP, make_stride(block_n_dim, _1{}, _1{}))); + Tensor mdQaccum = make_tensor( + make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout(shapedQ, make_stride(param.dq_r_stride, _1{}, _1{}))); + Tensor mdK = make_tensor( + make_gmem_ptr(param.dk_ptr + dk_offset), + make_layout(shapeKtV, make_stride(param.dk_r_stride, _1{}, _1{}))); + + Tensor mS = make_tensor( + make_gmem_ptr(param.s_ptr + s_offset), + make_layout(shapeSP, make_stride(param.s_r_stride, _1{}, _1{}))); + Tensor mdPd = make_tensor( + make_gmem_ptr(param.dp_ptr + s_offset), + make_layout(shapeSP, make_stride(param.s_r_stride, _1{}, _1{}))); + + Shape tile_sdp = typename Trait::TileShapeSdP{}; + Shape tile_dkv = typename Trait::TileShapedKV{}; + Shape tile_dq = typename Trait::TileShapedQ{}; + + auto tileloadQ = typename Trait::TiledLoadQ{mQ}; + auto tileloadKt = typename Trait::TiledLoadKt{mKt}; + auto tileloaddO = typename Trait::TiledLoaddO{mdO}; + auto tileloadV = typename 
Trait::TiledLoadV{mV}; + auto tileloadPt = typename Trait::TiledLoadPt{mPt}; + auto tileloaddOt = + typename Trait::TiledLoaddOt{mdOt}; // load dO as operand B for dV=Pt*dO + auto tileloaddP = typename Trait::TiledLoaddP{mdP}; + auto tileloadK = typename Trait::TiledLoadK{mK}; + auto tileloaddQ = typename Trait::TiledLoaddQ{mdQaccum}; + auto tileloaddPt = typename Trait::TiledLoaddPt{mdPt}; + auto tileloadQt = typename Trait::TiledLoadQt{mQt}; + + auto tilesaveP = typename Trait::TiledSaveS{mP}; // to internal buffer + auto tilesavedV = typename Trait::TiledSavedV{mdV}; + auto tilesavedP = typename Trait::TiledSavedP{mdP}; + auto tilesavedQ = typename Trait::TiledSavedQ{mdQaccum}; + auto tilesavedK = typename Trait::TiledSavedK{mdK}; + + Tensor mQ_coord = cute::get_xe_tensor(shapeQ); + Tensor mdQ_coord = cute::get_xe_tensor(shapedQ); + Tensor mKtV_coord = cute::get_xe_tensor(shapeKtV); + Tensor mdO_coord = cute::get_xe_tensor(shapeO); + Tensor mQtdOt_coord = cute::get_xe_tensor(shapeQtOt); + Tensor mK_coord = cute::get_xe_tensor(shapeK); + + Tensor mSP_coord = cute::get_xe_tensor(shapeSP); + Tensor mPt_coord = cute::get_xe_tensor(shapePt); + + typename Trait::TiledMmaSdP tiled_mma_sdp; + typename Trait::TiledMmadKV tiled_mma_dkv; + typename Trait::TiledMmadQ tiled_mma_dq; + + auto thr_mma_sdp = tiled_mma_sdp.get_slice(first_thread_in_sg_idx); + auto thr_mma_dkv = tiled_mma_dkv.get_slice(first_thread_in_sg_idx); + auto thr_mma_dq = tiled_mma_dq.get_slice(first_thread_in_sg_idx); + + Tensor gQ = local_tile(mQ_coord, select<0, 2>(tile_sdp), make_coord(_, _, 0)); + Tensor gKtV = + local_tile(mKtV_coord, select<1, 2>(tile_sdp), make_coord(_, _, 0)); + Tensor gV = + local_tile(mKtV_coord, select<1, 2>(tile_sdp), make_coord(_, _, 0)); + Tensor gdO = + local_tile(mdO_coord, select<0, 2>(tile_sdp), make_coord(_, _, 0)); + Tensor gPt = local_tile( + mPt_coord, select<0, 2>(tile_dkv), make_coord(_, _, 0)); // load Pt + Tensor gdPa = local_tile( + mSP_coord, select<0, 2>(tile_dq), make_coord(_, _, 0)); // operand A dQ + Tensor gK = local_tile( + mK_coord, select<1, 2>(tile_dq), make_coord(_, _, 0)); // operand B dQ + Tensor gdPt = local_tile( + mPt_coord, select<0, 2>(tile_dkv), make_coord(_, _, 0)); // load dpt + Tensor gQtdOt = local_tile( + mQtdOt_coord, + select<1, 2>(tile_dkv), + make_coord(_, _, 0)); // load Q as operand B + Tensor gQtdOt2 = local_tile( + mQtdOt_coord, + select<1, 2>(tile_dkv), + make_coord(_, _, 0)); // load Q as operand B + + Tensor gSP = local_tile( + mSP_coord, select<0, 1>(tile_sdp), make_coord(_, _, 0)); // dump P + Tensor gdV = local_tile( + mKtV_coord, select<0, 1>(tile_dkv), make_coord(_, _, 0)); // dump dV + Tensor gdQ = local_tile( + mdQ_coord, select<0, 1>(tile_dq), make_coord(_, _, 0)); // dump dQ + Tensor gdK = local_tile( + mKtV_coord, select<0, 1>(tile_dkv), make_coord(_, _, 0)); // dump dK + + Tensor tSgQ = thr_mma_sdp.partition_A(gQ); + Tensor tSgKt = thr_mma_sdp.partition_B(gKtV); + Tensor tdPgdO = thr_mma_sdp.partition_A(gdO); + Tensor tdPgV = thr_mma_sdp.partition_B(gV); + Tensor tdVgPt = thr_mma_dkv.partition_A(gPt); + Tensor tdVgdOt = thr_mma_dkv.partition_B(gQtdOt2); + Tensor tdQgdP = thr_mma_dq.partition_A(gdPa); + Tensor tdQgK = thr_mma_dq.partition_B(gK); + Tensor tdKgdPt = thr_mma_dkv.partition_A(gdPt); + Tensor tdKgQt = thr_mma_dkv.partition_B(gQtdOt); + + Tensor tPgP = thr_mma_sdp.partition_C(gSP); // save P to internal buffer + Tensor tdVgdV = thr_mma_dkv.partition_C(gdV); // save to dv + Tensor tdQgdQ = thr_mma_dq.partition_C(gdQ); // save to dq + 
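To keep the tile plumbing below readable, this is the per-block math the m_block loop implements (the standard FlashAttention-2 backward recurrence, matching the gemm_ker / scale_apply_exp2 / softmax_backward calls that follow). L is the stored logsumexp row vector and D[m] = \sum_d dO[m,d]\,O[m,d] from the dot-do-o pass; "scale" is the softmax scale:

    S = Q K^\top
    P_{mn} = \exp(\mathrm{scale}\cdot S_{mn} - L_m)
    dP = dO\, V^\top
    dS_{mn} = P_{mn}\,(dP_{mn} - D_m)\cdot \mathrm{scale}
    dV \mathrel{+}= P^\top dO, \qquad dQ \mathrel{+}= dS\, K, \qquad dK \mathrel{+}= dS^\top Q

dV and dK accumulate across the m_block loop inside this column block; dQ accumulates into the float dqaccum buffer across column blocks (hence the load/store of tdQgdQ when n_block > 0) and is converted to the output dtype in a separate pass.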
Tensor tdKgdK = thr_mma_dkv.partition_C(gdK); // save to dk + + Tensor tSrQ = make_tensor( + make_fragment_layout(tileloadQ, tSgQ(_, _, _, 0, 0).shape())); + Tensor tSrKt = make_tensor( + make_fragment_layout(tileloadKt, tSgKt(_, _, _, 0, 0).shape())); + Tensor tdPrdO = make_tensor( + make_fragment_layout(tileloaddO, tdPgdO(_, _, _, 0, 0).shape())); + Tensor tdPrV = make_tensor( + make_fragment_layout(tileloadV, tdPgV(_, _, _, 0, 0).shape())); + Tensor tdVrPt = make_tensor( + make_fragment_layout(tileloadPt, tdVgPt(_, _, _, 0, 0).shape())); + Tensor tdVrdOt = make_tensor( + make_fragment_layout(tileloaddOt, tdVgdOt(_, _, _, 0, 0).shape())); + Tensor tdQrdP = make_tensor( + make_fragment_layout(tileloaddP, tdQgdP(_, _, _, 0, 0).shape())); + Tensor tdQrK = make_tensor( + make_fragment_layout(tileloadK, tdQgK(_, _, _, 0, 0).shape())); + Tensor tdKrdPt = make_tensor( + make_fragment_layout(tileloaddPt, tdKgdPt(_, _, _, 0, 0).shape())); + Tensor tdKrQt = make_tensor( + make_fragment_layout(tileloadQt, tdKgQt(_, _, _, 0, 0).shape())); + + ThrCopy thr_copy_q = tileloadQ.get_slice(compat::local_id::x()); + ThrCopy thr_copy_kt = tileloadKt.get_slice(compat::local_id::x()); + ThrCopy thr_copy_do = tileloaddO.get_slice(compat::local_id::x()); + ThrCopy thr_copy_v = tileloadV.get_slice(compat::local_id::x()); + ThrCopy thr_copy_pt = tileloadPt.get_slice(compat::local_id::x()); + ThrCopy thr_copy_dot = tileloaddOt.get_slice(compat::local_id::x()); + ThrCopy thr_copy_dp = tileloaddP.get_slice(compat::local_id::x()); + ThrCopy thr_copy_k = tileloadK.get_slice(compat::local_id::x()); + ThrCopy thr_copy_dpt = tileloaddPt.get_slice(compat::local_id::x()); + ThrCopy thr_copy_qt = tileloadQt.get_slice(compat::local_id::x()); + + // Retile registers for copies + Tensor tQrQ = thr_copy_q.retile_D(tSrQ); + Tensor tKtrKt = thr_copy_kt.retile_D(tSrKt); + Tensor tdOrdO = thr_copy_do.retile_D(tdPrdO); + Tensor tVrV = thr_copy_v.retile_D(tdPrV); + Tensor tPtrPt = thr_copy_pt.retile_D(tdVrPt); + Tensor tdOtrdOt = thr_copy_dot.retile_D(tdVrdOt); + Tensor tdPrdPa = thr_copy_dp.retile_D(tdQrdP); + Tensor tKrK = thr_copy_k.retile_D(tdQrK); + Tensor tdPtrdPt = thr_copy_dpt.retile_D(tdKrdPt); + Tensor tQtrQt = thr_copy_qt.retile_D(tdKrQt); + + // Retile global counting tensors for copies + Tensor tQgQ = thr_copy_q.retile_S(tSgQ); + Tensor tKtgKt = thr_copy_kt.retile_S(tSgKt); + Tensor tdOgdO = thr_copy_do.retile_S(tdPgdO); + Tensor tVgV = thr_copy_v.retile_S(tdPgV); + Tensor tPtgPt = thr_copy_pt.retile_S(tdVgPt); + Tensor tdOtgdOt = thr_copy_dot.retile_S(tdVgdOt); + Tensor tdPgdPa = thr_copy_dp.retile_S(tdQgdP); + Tensor tKgK = thr_copy_k.retile_S(tdQgK); + Tensor tdPtgdPt = thr_copy_dpt.retile_S(tdKgdPt); + Tensor tQtgQt = thr_copy_qt.retile_S(tdKgQt); + + Tensor tSrS = partition_fragment_C( + tiled_mma_sdp, + make_shape( + get<0>(tile_sdp), + get<1>(tile_sdp), + ceil_div(Int{}, get<0>(tile_sdp)), + ceil_div(Int{}, get<1>(tile_sdp)))); + Tensor tdPrdP = partition_fragment_C( + tiled_mma_sdp, + make_shape( + get<0>(tile_sdp), + get<1>(tile_sdp), + ceil_div(Int{}, get<0>(tile_sdp)), + ceil_div(Int{}, get<1>(tile_sdp)))); + Tensor tdVrdV = partition_fragment_C( + tiled_mma_dkv, + make_shape( + get<0>(tile_dkv), + get<1>(tile_dkv), + ceil_div(Int{}, get<0>(tile_dkv)), + ceil_div(Int{}, get<1>(tile_dkv)))); + Tensor tdQrdQ = partition_fragment_C( + tiled_mma_dq, + make_shape( + get<0>(tile_dq), + get<1>(tile_dq), + ceil_div(Int{}, get<0>(tile_dq)), + ceil_div(Int{}, get<1>(tile_dq)))); + Tensor tdKrdK = partition_fragment_C( + 
tiled_mma_dkv, + make_shape( + get<0>(tile_dkv), + get<1>(tile_dkv), + ceil_div(Int{}, get<0>(tile_dkv)), + ceil_div(Int{}, get<1>(tile_dkv)))); + // for lse read + Tensor caccS = make_identity_tensor( + Shape, Int>{}); // same buffer as accS + Tensor taccScS = thr_mma_sdp.partition_C(caccS); + static_assert(decltype(size<0>(taccScS))::value == 8); + Tensor taccScS_rc = logical_divide(taccScS, Shape<_1>{}); + Tensor taccScS_row = + logical_divide(taccScS, Shape<_1>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + // static_assert(size<0>(tSrS) * size<1>(tSrS) == size<0>(lse) && "row of acc + // and lse not match"); misc + + const int max_m_block = ceil_div(param.seq_len_q, kBlockM); + const int tail_m = param.seq_len_q % kBlockM; + + cutlass::NumericConverter converter; + // clear accumulator + clear(tdVrdV); + clear(tdKrdK); + for (int m_block = 0; m_block < max_m_block; ++m_block) { + const bool Is_even_M = not((m_block == max_m_block - 1) and (tail_m != 0)); + if (not Is_even_M) { + mQ = make_tensor( + make_gmem_ptr(mQ.data()), + make_layout( + make_shape(tail_m, Int{}, _1{}), + make_stride(param.q_r_stride, _1{}, _1{}))); + mdO = make_tensor( + make_gmem_ptr(mdO.data()), + make_layout( + make_shape(tail_m, Int{}, _1{}), + make_stride(param.o_r_stride, _1{}, _1{}))); + mdOt = make_tensor( + make_gmem_ptr(mdOt.data()), + make_layout( + make_shape(Int{}, tail_m, _1{}), + make_stride(_1{}, param.o_r_stride, _1{}))); + mdQaccum = make_tensor( + make_gmem_ptr(mdQaccum.data()), + make_layout(shapedQ, make_stride(param.dq_r_stride, _1{}, _1{}))); + mQt = make_tensor( + make_gmem_ptr(mQt.data()), + make_layout( + make_shape(Int{}, tail_m, _1{}), + make_stride(_1{}, param.q_r_stride, _1{}))); + + tileloadQ = typename Trait::TiledLoadQ{mQ}; + tileloaddO = typename Trait::TiledLoaddO{mdO}; + tileloaddOt = typename Trait::TiledLoaddOt{mdOt}; + tileloaddQ = typename Trait::TiledLoaddQ{mdQaccum}; + tileloadQt = typename Trait::TiledLoadQt{mQt}; + tilesavedQ = typename Trait::TiledSavedQ{mdQaccum}; + } + clear(tSrS); + // S=QKt + gemm_ker( + tSrS, + tSrQ, + tSrKt, + tQgQ, + tQrQ, + gQ, + tKtgKt, + tKtrKt, + gKtV, + tiled_mma_sdp, + tile_sdp, + tileloadQ, + tileloadKt); + Tensor scores = + make_tensor(tSrS.data(), convert_layout_acc_layout(tSrS.layout())); + + if constexpr (is_causal) + apply_mask_causal( + scores, taccScS_rc, m_block * kBlockM, n_block * kBlockN); + + if (Is_even_M) { + load_1colvec(lse, mLSE, taccScS_row); + } else { + load_1colvec(lse, mLSE, taccScS_row, tail_m); + } + + Tensor dP_sum = make_fragment_like(lse); + if (Is_even_M) + load_1colvec(dP_sum, mdPsum, taccScS_row); + else + load_1colvec(dP_sum, mdPsum, taccScS_row, tail_m); + + // P=softmax(S,lse) + scale_apply_exp2(scores, lse, param.scale_softmax_log2); + auto tSrSl = convert_type(tSrS); + mha_save(tilesaveP, tSrSl, tPgP); // save P to internal buffers + clear(tdPrdP); + // dP=dO*Vt + gemm_ker( + tdPrdP, + tdPrdO, + tdPrV, + tdOgdO, + tdOrdO, + gdO, + tVgV, + tVrV, + gKtV, + tiled_mma_sdp, + tile_sdp, + tileloaddO, + tileloadV); + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + // dS=P(dP-sum_row(P))*scale + softmax_backward(scores, dP_sum, dS, param.scale_softmax); + auto tdPrdPl = convert_type(tdPrdP); + if (n_block > 0) // TODO: need actual prefetch here. 
yk + copy(tileloaddQ, tdQgdQ, tdQrdQ); + + // dV=Pt*dO + gemm_ker( + tdVrdV, + tdVrPt, + tdVrdOt, + tPtgPt, + tPtrPt, + gPt, + tdOtgdOt, + tdOtrdOt, + gQtdOt, + tiled_mma_dkv, + tile_dkv, + tileloadPt, + tileloaddOt); + sycl::group_barrier(group); + + mha_save( + tilesavedP, tdPrdPl, tPgP); // save dP to buffer after P used by dV + sycl::group_barrier(group); + + clear(tdQrdQ); + if (n_block > 0) { + if (Is_even_M) + mha_load(tileloaddQ, tdQgdQ, tdQrdQ); + else + mha_load(tileloaddQ, tdQgdQ, tdQrdQ); + } + // dQ=dP*K + gemm_ker( + tdQrdQ, + tdQrdP, + tdQrK, + tdPgdPa, + tdPrdPa, + gdPa, + tKgK, + tKrK, + gK, + tiled_mma_dq, + tile_dq, + tileloaddP, + tileloadK); + if (Is_even_M) + mha_save(tilesavedQ, tdQrdQ, tdQgdQ); + else + mha_save(tilesavedQ, tdQrdQ, tdQgdQ); + // dK=dPt*Q + gemm_ker( + tdKrdK, + tdKrdPt, + tdKrQt, + tdPtgdPt, + tdPtrdPt, + gdPt, + tQtgQt, + tQtrQt, + gQtdOt, + tiled_mma_dkv, + tile_dkv, + tileloaddPt, + tileloadQt); + // update ptr/atom copy + mQ.data() = mQ.data() + int(kBlockM * param.q_r_stride); + mdO.data() = mdO.data() + int(kBlockM * param.o_r_stride); + mdOt.data() = mdOt.data() + int(kBlockM * param.o_r_stride); + mdQaccum.data() = mdQaccum.data() + int(kBlockM * param.dq_r_stride); + mQt.data() = mQt.data() + int(kBlockM * param.q_r_stride); + mLSE.data() = mLSE.data() + int(kBlockM); + mdPsum.data() = mdPsum.data() + int(kBlockM); + + tileloadQ = typename Trait::TiledLoadQ{mQ}; + tileloaddO = typename Trait::TiledLoaddO{mdO}; + tileloaddOt = typename Trait::TiledLoaddOt{mdOt}; + tileloaddQ = typename Trait::TiledLoaddQ{mdQaccum}; + tileloadQt = typename Trait::TiledLoadQt{mQt}; + tilesavedQ = typename Trait::TiledSavedQ{mdQaccum}; + } + auto tdVrdVl = convert_type(tdVrdV); + mha_save(tilesavedV, tdVrdVl, tdVgdV); + auto tdKrdKl = convert_type(tdKrdK); + mha_save(tilesavedK, tdKrdKl, tdKgdK); +} + +template +void mha_backward(T trait, Param param) { + const int bidb = BlockIdxZ(); + const int bidhq = BlockIdxY(); + const int bidhkv = bidhq / param.num_qh_per_kvh; + for (int n_block = 0; n_block < param.n_block; ++n_block) + dq_dk_dv_1colblock(trait, param, bidb, bidhq, bidhkv, n_block); + if (param.tail_n > 0) + dq_dk_dv_1colblock( + trait, param, bidb, bidhq, bidhkv, param.n_block, param.tail_n); +} + +template +void convert_dq( + T& trait, + Param& param, + int m_block, + int bidb, + int bidh) { + constexpr int kBlockM = T::kBlockM; + constexpr int kBlockN = T::kBlockN; + constexpr int kHeadDim = T::kHeadDim; + using DType = typename T::DType; + using VType = typename T::VType; + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + + auto bofst = Boffset(param); + const index_t dq_offset = bofst.dq_offset(bidb, bidh, m_block * kBlockM); + const index_t q_offset = bofst.q_offset(bidb, bidh, m_block * kBlockM); + using ShapeQ = Shape< + std::conditional_t, int>, + Int, + _1>; + ShapeQ shapeQ; + if constexpr (Is_even_M) { + shapeQ = make_shape(Int{}, Int{}, _1{}); + } else { + shapeQ = make_shape(param.tail_m, Int{}, _1{}); + } + + Tensor mdQaccum = make_tensor( + make_gmem_ptr(param.dqaccum_ptr + dq_offset), + make_layout( + Shape, Int>{}, + make_stride(param.dq_r_stride, _1{}, _1{}))); + Tensor mdQ = make_tensor( + make_gmem_ptr(param.dq_ptr + q_offset), + make_layout(shapeQ, make_stride(param.q_r_stride, _1{}, _1{}))); + + Shape tile_dq = typename T::TileShapedQ{}; + + auto tileloaddQ = typename T::TiledLoaddQ{mdQaccum}; + auto tilesavedQ = typename T::TiledSavedV{mdQ}; + + 
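The convert_dq pass exists because dQ is accumulated in a padded float buffer (dqaccum) across all key column blocks and only downcast to fp16/bf16 at the end. A sketch of its net effect, assuming matching element counts and ignoring the stride/padding differences the tiled copies above actually handle:

    // Reference semantics of the dQ conversion pass (illustrative only).
    template <class DType>
    void convert_dq_reference(const float* dqaccum, DType* dq, long n) {
      for (long i = 0; i < n; ++i) {
        dq[i] = static_cast<DType>(dqaccum[i]);  // fp32 accumulator -> output dtype
      }
    }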
typename T::TiledMmadQ tiled_mma_dq; + auto thr_mma_dq = tiled_mma_dq.get_slice(first_thread_in_sg_idx); + + Tensor mQ_coord = cute::get_xe_tensor(shapeQ); + Tensor gdQ = local_tile( + mQ_coord, select<0, 1>(tile_dq), make_coord(_, _, 0)); // dump dQ + + Tensor tdQgdQ = thr_mma_dq.partition_C(gdQ); // save to dq + Tensor tdQrdQaccum = partition_fragment_C( + tiled_mma_dq, + make_shape( + get<0>(tile_dq), + get<1>(tile_dq), + ceil_div(Int{}, get<0>(tile_dq)), + ceil_div(Int{}, get<1>(tile_dq)))); + + Tensor tdQrdQ = make_fragment_like(tdQrdQaccum); + if constexpr (Is_even_M) { + mha_load(tileloaddQ, tdQgdQ, tdQrdQaccum); + } else { + mha_load(tileloaddQ, tdQgdQ, tdQrdQaccum); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tdQrdQ); ++i) { + tdQrdQ(i) = static_cast(tdQrdQaccum(i)); + } + if constexpr (Is_even_M) { + mha_save(tilesavedQ, tdQrdQ, tdQgdQ); + } else { + mha_save(tilesavedQ, tdQrdQ, tdQgdQ); + } +} + +template +void mhd_convert_dq(T trait, Param param) { + // The block index for the M dimension. + const int m_block = BlockIdxX(); + // The block index for the batch. + const int bidb = BlockIdxZ(); + // The block index for the head. + const int bidh = BlockIdxY(); + if (param.tail_m > 0 and m_block == param.m_block - 1) { + convert_dq(trait, param, m_block, bidb, bidh); + } else { + convert_dq(trait, param, m_block, bidb, bidh); + } +} + +template +class MhaDotDoOName; + +template +class MhaBackwardName; + +template +class MhdConvertDqName; + +template < + typename T, + typename ProblemShape, + int kBlockM, + int kBlockN, + int kHeadDim, + int kNSGs, + int AtomLayoutMSdP, + int AtomLayoutNdKV, + int AtomLayoutMdQ, + bool is_causal, + bool is_bhsd> +void run_mha_bwd_specialized( + sycl::queue& queue, + ProblemShape& problem_shape, + const T* do_d, + const T* o_d, + const T* q_d, + const T* k_d, + const T* v_d, + const float* lse_d, + float* odo_d, + float* dqaccum_d, + T* dq_d, + T* dk_d, + T* dv_d, + T* s_d, + T* dp_d, + T* pbuff, + int seq_len_q_pad, + int seq_len_kv_pad, + float scale) { + auto trait = FAKernel< + T, + kHeadDim, + kBlockM, + kBlockN, + kNSGs, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + is_causal>{}; + + const int BATCH = get<0>(problem_shape); + const int NUM_HEAD_Q = get<1>(problem_shape); + const int NUM_HEAD_KV = get<2>(problem_shape); + const int SEQ_LEN_Q = get<3>(problem_shape); + const int SEQ_LEN_KV = get<4>(problem_shape); + const int N_BLOCK = SEQ_LEN_KV / kBlockN; + const int tail_n = SEQ_LEN_KV % kBlockN; + const int M_BLOCK = ceil_div(SEQ_LEN_Q, kBlockM); + const int tail_m = SEQ_LEN_Q % kBlockM; + auto param = Param( + do_d, + o_d, + q_d, + k_d, + v_d, + lse_d, + odo_d, + dqaccum_d, + dq_d, + dk_d, + dv_d, + s_d, + dp_d, + pbuff, + scale); + param.batch = BATCH; + param.num_head_q = NUM_HEAD_Q; + param.num_head_kv = NUM_HEAD_KV; + param.num_qh_per_kvh = NUM_HEAD_Q / NUM_HEAD_KV; + param.seq_len_q = SEQ_LEN_Q; + param.seq_len_kv = SEQ_LEN_KV; + param.head_dim = kHeadDim; + param.n_block = N_BLOCK; + param.tail_n = tail_n; + param.m_block = M_BLOCK; + param.tail_m = tail_m; + param.seq_len_kv_pad = seq_len_kv_pad; + param.seq_len_q_pad = seq_len_q_pad; + if constexpr (is_bhsd) { + setup_bhsd_stride(param); + } else { + setup_bshd_stride(param); + } + auto dimGrid0 = + compat::dim3(size(M_BLOCK), size(param.num_head_q), size(param.batch)); + auto dimBlock0 = + compat::dim3(size(kNSGs * trait.SubgroupSize), size(1), size(1)); + compat::experimental::launch_properties launch_props0{}; + compat::experimental::kernel_properties 
kernel_props0{ + sycl::ext::oneapi::experimental::sub_group_size}; + compat::experimental::launch_policy policy0{ + dimGrid0, dimBlock0, launch_props0, kernel_props0}; + compat::experimental:: + launch, MhaDotDoOName>( + policy0, queue, trait, param); + + auto dimGrid1 = + compat::dim3(size(1), size(param.num_head_q), size(param.batch)); + auto dimBlock1 = + compat::dim3(size(kNSGs * trait.SubgroupSize), size(1), size(1)); + compat::experimental::launch_properties launch_props1{}; + compat::experimental::kernel_properties kernel_props1{ + sycl::ext::oneapi::experimental::sub_group_size}; + compat::experimental::launch_policy policy1{ + dimGrid1, dimBlock1, launch_props1, kernel_props1}; + compat::experimental:: + launch, MhaBackwardName>( + policy1, queue, trait, param); + + auto dimGrid2 = + compat::dim3(size(M_BLOCK), size(param.num_head_q), size(param.batch)); + auto dimBlock2 = + compat::dim3(size(kNSGs * trait.SubgroupSize), size(1), size(1)); + compat::experimental::launch_properties launch_props2{}; + compat::experimental::kernel_properties kernel_props2{ + sycl::ext::oneapi::experimental::sub_group_size}; + compat::experimental::launch_policy policy2{ + dimGrid2, dimBlock2, launch_props2, kernel_props2}; + auto event2 = compat::experimental::launch< + mhd_convert_dq, + MhdConvertDqName>(policy2, queue, trait, param); +} + +template < + typename T, + typename ProblemShape, + int kMPad, + int kNPad, + bool is_causal, + bool is_bhsd> +void run_mha_bwd_( + sycl::queue& queue, + ProblemShape& problem_shape, + const T* do_d, + const T* o_d, + const T* q_d, + const T* k_d, + const T* v_d, + const float* lse_d, + float* odo_d, + float* dqaccum_d, + T* dq_d, + T* dk_d, + T* dv_d, + T* s_d, + T* dp_d, + T* pbuff, + int seq_len_q_pad, + int seq_len_kv_pad, + float scale) { + const int headdim = get<5>(problem_shape); +#define RUN_MHA_BWD_SPECIALIZED() \ + run_mha_bwd_specialized< \ + T, \ + ProblemShape, \ + kBlockM, \ + kBlockN, \ + kHeadDim, \ + kNSGs, \ + AtomLayoutMSdP, \ + AtomLayoutNdKV, \ + AtomLayoutMdQ, \ + is_causal, \ + is_bhsd>( \ + queue, \ + problem_shape, \ + do_d, \ + o_d, \ + q_d, \ + k_d, \ + v_d, \ + lse_d, \ + odo_d, \ + dqaccum_d, \ + dq_d, \ + dk_d, \ + dv_d, \ + s_d, \ + dp_d, \ + pbuff, \ + seq_len_q_pad, \ + seq_len_kv_pad, \ + scale); + + if (headdim == 64) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 32; + constexpr int kHeadDim = 64; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 4; + constexpr int AtomLayoutNdKV = 2; + constexpr int AtomLayoutMdQ = 2; + static_assert( + kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert( + kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + RUN_MHA_BWD_SPECIALIZED(); + } else if (headdim == 96) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 64; + constexpr int kHeadDim = 96; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 2; + constexpr int AtomLayoutNdKV = 4; + constexpr int AtomLayoutMdQ = 4; + static_assert( + kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert( + kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + RUN_MHA_BWD_SPECIALIZED(); + } else if (headdim == 128) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 32; + constexpr int kHeadDim = 128; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 4; + constexpr int AtomLayoutNdKV = 2; + constexpr int AtomLayoutMdQ = 2; + static_assert( + kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + 
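A worked example of the block bookkeeping above, using the headdim == 64 configuration (kBlockM = 64, kBlockN = 32) and hypothetical sequence lengths seq_len_q = seq_len_kv = 1000 (not values from the patch):

    #include <cassert>
    int main() {
      constexpr int kBlockM = 64, kBlockN = 32;
      const int seq_len_q = 1000, seq_len_kv = 1000;            // hypothetical lengths
      const int m_block = (seq_len_q + kBlockM - 1) / kBlockM;  // 16: grid.x of the dot-do-o and convert-dq kernels
      const int tail_m  = seq_len_q % kBlockM;                  // 40: rows handled by the last M block
      const int n_block = seq_len_kv / kBlockN;                 // 31: full column blocks per workgroup in mha_backward
      const int tail_n  = seq_len_kv % kBlockN;                 // 8: handled by the extra tail call
      assert(m_block == 16 && tail_m == 40 && n_block == 31 && tail_n == 8);
      return 0;
    }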
static_assert( + kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + RUN_MHA_BWD_SPECIALIZED(); + } else if (headdim == 192) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 32; + constexpr int kHeadDim = 192; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 4; + constexpr int AtomLayoutNdKV = 2; + constexpr int AtomLayoutMdQ = 2; + static_assert( + kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert( + kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + RUN_MHA_BWD_SPECIALIZED(); + } else if (headdim == 256) { + constexpr int kBlockM = 64; + constexpr int kBlockN = 32; + constexpr int kHeadDim = 256; + constexpr int kNSGs = 8; + constexpr int AtomLayoutMSdP = 4; + constexpr int AtomLayoutNdKV = 2; + constexpr int AtomLayoutMdQ = 2; + static_assert( + kBlockM <= kMPad, "kBlockM must be less than or equal to kMPad"); + static_assert( + kBlockN <= kNPad, "kBlockN must be less than or equal to kNPad"); + RUN_MHA_BWD_SPECIALIZED(); + } else { + TORCH_CHECK( + false, + "FlashAttentionBackwardXPU: unsupported head dimension: ", + headdim); + } +#undef RUN_MHA_BWD_SPECIALIZED +} + +template +void run_mha_bwd( + sycl::queue& queue, + ProblemShape& problem_shape, + const void* grad_out, + const void* out, + const void* query, + const void* key, + const void* value, + const void* logsumexp, + void* odo, + void* dqaccum, + void* grad_query, + void* grad_key, + void* grad_value, + void* s, + void* dp, + void* pbuff, + int seqlen_qo_pad, + int seqlen_kv_pad, + bool is_causal, + float scale, + at::ScalarType dtype, + sycltla::ATTN_TENSOR_LAYOUT layout) { + const bool is_bhsd = (layout == sycltla::ATTN_TENSOR_LAYOUT::BHSD); + FP16_SWITCH(dtype == at::kHalf, [&] { + BOOL_SWITCH(is_bhsd, IS_BSHD, [&] { + BOOL_SWITCH(is_causal, IS_CAUSAL, [&] { + run_mha_bwd_( + queue, + problem_shape, + static_cast(grad_out), + static_cast(out), + static_cast(query), + static_cast(key), + static_cast(value), + static_cast(logsumexp), + static_cast(odo), + static_cast(dqaccum), + static_cast(grad_query), + static_cast(grad_key), + static_cast(grad_value), + static_cast(s), + static_cast(dp), + static_cast(pbuff), + seqlen_qo_pad, + seqlen_kv_pad, + scale); + }); + }); + }); +} + +} // namespace cute + +namespace sycltla { + +std::tuple flash_attention_backward_sycltla( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cumulative_sequence_length_q, + const at::Tensor& cumulative_sequence_length_k, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_k, + const double dropout, + const bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + const float scale) { + TORCH_CHECK( + dropout == 0.0, + "FlashAttentionBackwardXPU does not only support dropout > 0.0 yet"); + + CHECK_DEVICE(query); + CHECK_DEVICE(key); + CHECK_DEVICE(value); + CHECK_DEVICE(out); + CHECK_DEVICE(grad_out); + CHECK_DEVICE(logsumexp); + + TORCH_CHECK( + !query.is_nested() && !key.is_nested() && !value.is_nested() && + !out.is_nested() && !grad_out.is_nested() && !logsumexp.is_nested(), + "FlashAttentionBackwardXPU only support dense inputs"); + + auto dtype = query.scalar_type(); + TORCH_CHECK( + dtype == at::kHalf || dtype == at::kBFloat16, + "FlashAttentionBackwardXPU only support fp16 and bf16 data type"); + TORCH_CHECK( + logsumexp.scalar_type() == at::kFloat, + "FlashAttentionBackwardXPU: 
logsumexp must have the dtype float32"); + TORCH_CHECK( + key.scalar_type() == dtype, + "FlashAttentionBackwardXPU: query and key must have the same dtype"); + TORCH_CHECK( + value.scalar_type() == dtype, + "FlashAttentionBackwardXPU: query and value must have the same dtype"); + TORCH_CHECK( + out.scalar_type() == dtype, + "FlashAttentionBackwardXPU: query and out must have the same dtype"); + + TORCH_CHECK( + query.dim() == 4 && key.dim() == 4 && value.dim() == 4 && + out.dim() == 4 && grad_out.dim() == 4 && logsumexp.dim() == 3, + "FlashAttentionBackwardXPU requires query, key, value, out, grad_out to be 4 dimensional and logsumexp to be 3 dimensional"); + + const int batch_size = query.sizes()[0]; + const int numhead_qo = query.sizes()[1]; + const int numhead_kv = key.sizes()[1]; + const int seqlen_qo = query.sizes()[2]; + const int seqlen_kv = key.sizes()[2]; + const int headsize_qk = query.sizes()[3]; + const int headsize_vo = value.sizes()[3]; + CHECK_SHAPE(query, batch_size, numhead_qo, seqlen_qo, headsize_qk); + CHECK_SHAPE(key, batch_size, numhead_kv, seqlen_kv, headsize_qk); + CHECK_SHAPE(value, batch_size, numhead_kv, seqlen_kv, headsize_vo); + CHECK_SHAPE(out, batch_size, numhead_qo, seqlen_qo, headsize_vo); + CHECK_SHAPE(grad_out, batch_size, numhead_qo, seqlen_qo, headsize_vo); + CHECK_SHAPE(logsumexp, batch_size, numhead_qo, seqlen_qo); + TORCH_CHECK( + numhead_qo % numhead_kv == 0, + "FlashAttentionBackwardXPU: number of heads in key/value must divide number of heads in query"); + TORCH_CHECK( + headsize_qk == headsize_vo, + "FlashAttentionBackwardXPU: headsize_qk must be equal to headsize_vo"); + + TORCH_CHECK( + query.stride(-1) == 1, + "FlashAttentionBackwardXPU: input tensor must have contiguous last dimension"); + TORCH_CHECK( + key.stride(-1) == 1, + "FlashAttentionBackwardXPU: input tensor must have contiguous last dimension"); + TORCH_CHECK( + value.stride(-1) == 1, + "FlashAttentionBackwardXPU: input tensor must have contiguous last dimension"); + TORCH_CHECK( + out.stride(-1) == 1, + "FlashAttentionBackwardXPU: out tensor must have contiguous last dimension"); + TORCH_CHECK( + grad_out.stride(-1) == 1, + "FlashAttentionBackwardXPU: dout tensor must have contiguous last dimension"); + TORCH_CHECK( + logsumexp.stride(-1) == 1, + "FlashAttentionBackwardXPU: logsumexp tensor must have contiguous last dimension"); + + ATTN_TENSOR_LAYOUT layout = get_attn_tensor_layout(query); + if (layout == ATTN_TENSOR_LAYOUT::UNSUPPORTED) { + TORCH_CHECK( + false, + "FlashAttentionBackwardXPU: only support BHSD or BSHD layout, got query with shape ", + query.sizes(), + ", stride ", + query.strides()); + } + layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(key)); + TORCH_CHECK( + ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout, + "FlashAttentionBackwardXPU: query and key must have the same layout, got query with layout ", + to_string(layout), + ", key with layout ", + to_string(get_attn_tensor_layout(key))); + layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(value)); + TORCH_CHECK( + ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout, + "FlashAttentionBackwardXPU: query and value must have the same layout, got query with layout ", + to_string(layout), + ", value with layout ", + to_string(get_attn_tensor_layout(value))); + layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(out)); + TORCH_CHECK( + ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout, + "FlashAttentionBackwardXPU: query and out must have the same layout, got query with layout ", + to_string(layout), + 
", out with layout ", + to_string(get_attn_tensor_layout(out))); + if (layout == ATTN_TENSOR_LAYOUT::BXD) { + layout = ATTN_TENSOR_LAYOUT::BHSD; + } + TORCH_CHECK(logsumexp.is_contiguous(), "logsumexp must have BHS layout"); + // grad_out is created by autograd, may not have standard layout + auto contiguous_grad_out = attn_tensor_to_layout(grad_out, layout); + + auto sycl_queue = at::xpu::getCurrentXPUStream().queue(); + auto device_architecture = + sycl_queue.get_device() + .get_info< + sycl::ext::oneapi::experimental::info::device::architecture>(); + constexpr auto supported_architectures = + std::array{ + sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc, + sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg, + sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21, + sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31}; + if (std::find( + supported_architectures.begin(), + supported_architectures.end(), + device_architecture) == supported_architectures.end()) { + TORCH_CHECK( + false, + "XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31."); + } + + auto grad_query = at::empty_like(query); + auto grad_key = at::empty_like(key); + auto grad_value = at::empty_like(value); + + auto opts = query.options(); + + at::Tensor grad_key_expanded, grad_value_expanded; + if (numhead_kv != numhead_qo) { // MQA / GQA + if (layout == ATTN_TENSOR_LAYOUT::BHSD) { // BHSD + grad_key_expanded = + at::empty({batch_size, numhead_qo, seqlen_kv, headsize_qk}, opts); + grad_value_expanded = + at::empty({batch_size, numhead_qo, seqlen_kv, headsize_vo}, opts); + } else { // BSHD + grad_key_expanded = + at::empty({batch_size, seqlen_kv, numhead_qo, headsize_qk}, opts) + .transpose(1, 2); + grad_value_expanded = + at::empty({batch_size, seqlen_kv, numhead_qo, headsize_vo}, opts) + .transpose(1, 2); + } + } else { + grad_key_expanded = grad_key; + grad_value_expanded = grad_value; + } + + constexpr int kMPad = 128; + constexpr int kNPad = 128; + int seqlen_qo_pad = (seqlen_qo + kMPad - 1) / kMPad * kMPad; + int seqlen_kv_pad = (seqlen_kv + kNPad - 1) / kNPad * kNPad; + auto tensor_s = + at::empty({batch_size, numhead_qo, seqlen_qo_pad, seqlen_kv_pad}, opts); + auto tensor_odo = at::empty_like(out, opts.dtype(at::kFloat)); + auto tensor_dqaccum = at::empty( + {batch_size, numhead_qo, seqlen_qo_pad, headsize_qk}, + opts.dtype(at::kFloat)); + auto tensor_dp = + at::empty({batch_size, numhead_qo, seqlen_qo_pad, seqlen_kv_pad}, opts); + auto tensor_pbuff = + at::empty({batch_size, numhead_qo, seqlen_kv_pad, kMPad}, opts); + + auto problem_shape = ProblemShapeRegular( + batch_size, + numhead_qo, + numhead_kv, + seqlen_qo, + seqlen_kv, + headsize_qk, + headsize_vo); + + cute::run_mha_bwd( + sycl_queue, + problem_shape, + contiguous_grad_out.data_ptr(), + out.data_ptr(), + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + logsumexp.data_ptr(), + tensor_odo.data_ptr(), + tensor_dqaccum.data_ptr(), + grad_query.data_ptr(), + grad_key_expanded.data_ptr(), + grad_value_expanded.data_ptr(), + tensor_s.data_ptr(), + tensor_dp.data_ptr(), + tensor_pbuff.data_ptr(), + seqlen_qo_pad, + seqlen_kv_pad, + is_causal, + scale, + dtype, + layout); + + if (numhead_kv != numhead_qo) { + at::sum_out( + grad_key, + at::reshape( + grad_key_expanded, + {batch_size, + numhead_kv, + numhead_qo / numhead_kv, + seqlen_kv, + headsize_qk}), + {2}); + at::sum_out( + grad_value, + 
at::reshape( + grad_value_expanded, + {batch_size, + numhead_kv, + numhead_qo / numhead_kv, + seqlen_kv, + headsize_vo}), + {2}); + } + + return std::make_tuple( + std::move(grad_query), std::move(grad_key), std::move(grad_value)); +} +} // namespace sycltla diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.h new file mode 100644 index 0000000000..c8dac7074d --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.h @@ -0,0 +1,457 @@ +#pragma once +#include +#include +#include +#include + +namespace cute { + +template < + class T_, + int kHeadDim_, + int kBlockM_, + int kBlockN_, + int kNSGs_, + int AtomLayoutMSdP_ = 2, + int AtomLayoutNdKV_ = 2, + int AtomLayoutMdQ_ = 2, + bool is_causal_ = false> +struct FAKernel { + /* + Q BATCH,NUM_HEAD_Q,SEQ_LEN_QO,HEAD_SIZE_QK + K BATCH,NUM_HEAD_KV,SEQ_LEN_KV,HEAD_SIZE_QK + V BATCH,NUM_HEAD_KV,SEQ_LEN_KV,HEAD_SIZE_VO + P BATCH,NUM_HEAD_Q,SEQ_LEN_QO,SEQ_LEN_KV + O BATCH,NUM_HEAD_Q,SEQ_LEN_QO,HEAD_SIZE_VO + */ + using DType = T_; + using VType = float; // accumulation + using MMA_Atom_ARCH = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; + static constexpr int kHeadDim = kHeadDim_; + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kNSGs = kNSGs_; + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static constexpr int AtomLayoutNdKV = AtomLayoutNdKV_; + static constexpr int AtomLayoutMdQ = AtomLayoutMdQ_; + static constexpr bool is_causal = is_causal_; + using SubgroupLayoutSdP = + Layout, Int, _1>>; + using SubgroupLayoutdKV = + Layout, Int, _1>>; + using SubgroupLayoutdQ = + Layout, Int, _1>>; + + using TileShapeSdP = + Tile, Int<16 * kNSGs / AtomLayoutMSdP>, _16>; + static_assert( + size<0>(TileShapeSdP{}) <= kBlockM && + "tile size M must be smaller than or equal to kBlockM"); + static_assert( + kBlockM % size<0>(TileShapeSdP{}) == 0 && + "kBlockM must be dividable by tile size M"); + static_assert( + size<1>(TileShapeSdP{}) <= kBlockN && + "tile size N must be smaller than or equal to kBlockN"); + static_assert( + kBlockN % size<1>(TileShapeSdP{}) == 0 && + "kBlockN must be dividable by tile size N "); + + using TileShapedKV = + Tile, Int<16 * kNSGs / AtomLayoutNdKV>, _16>; + static_assert( + size<0>(TileShapedKV{}) <= kBlockN && + "tile size M must be smaller than or equal to kBlockN"); + static_assert( + kBlockN % size<0>(TileShapedKV{}) == 0 && + "kBlockN must be dividable by tile size M"); + static_assert( + size<1>(TileShapedKV{}) <= kHeadDim && + "tile size N must be smaller than or equal to kHeadDim"); + static_assert( + kHeadDim % size<1>(TileShapedKV{}) == 0 && + "kHeadDim must be dividable by tile size N"); + + using TileShapedQ = + Tile, Int<16 * kNSGs / AtomLayoutMdQ>, _16>; + static_assert( + size<0>(TileShapedQ{}) <= kBlockM && + "tile size M must be smaller than or equal to kBlockM"); + static_assert( + kBlockM % size<0>(TileShapedQ{}) == 0 && + "kBlockM must dividable by tile size M"); + static_assert( + size<1>(TileShapedQ{}) <= kHeadDim && + "tile size N must be smaller than or equal to kHeadDim"); + static_assert( + kHeadDim % size<1>(TileShapedQ{}) == 0 && + "kHeadDim must be dividable by tile size N"); + + using TiledMmaSdP = typename TiledMMAHelper< + MMA_Atom_ARCH, + Layout, + SubgroupLayoutSdP>::TiledMMA; + + using TiledMmadKV = typename TiledMMAHelper< + MMA_Atom_ARCH, + Layout, + SubgroupLayoutdKV>::TiledMMA; + + using TiledMmadQ = 
typename TiledMMAHelper< + MMA_Atom_ARCH, + Layout, + SubgroupLayoutdQ>::TiledMMA; + static constexpr auto bP = Int<2>{}; // Pipeline + + using StrideR = cute::tuple>; + using StrideC = cute::tuple, long>; + + // for load Q and Kt in S=QKt + using TiledLoadQ = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + using TiledLoadKt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for load dO and Vt in dP=dO*Vt + using TiledLoaddO = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + + using TiledLoadV = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for load Pt and dO in dV=Pt*dO + using TiledLoadPt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 m-major + Layout>{})); // // Val layout 8x1 + using TiledLoaddOt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, // should + // be V + // here + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 16x1 + + // for load dP, K and dQ in dQ=dP*K + using TiledLoaddP = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // val layout 16x1 + using TiledLoadK = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 16x1 + + using TiledLoaddQ = decltype(make_tiled_copy( + Copy_Atom, VType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 8x1 + + // for load dPt, Q in dK=dPt*Q + using TiledLoaddPt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 k-major + Layout>{})); // Val layout 16x1 + using TiledLoadQt = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 16x1 + + // for save S in S=QKt and P + using TiledSaveS = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dP in dP=dO*Vt + using TiledSavedP = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dV in dV=Pt*dO + using TiledSavedV = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + // for save dQ in dQ=dP*K + using TiledSavedQ = decltype(make_tiled_copy( + Copy_Atom, VType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // val layout 8x1 + // for save dK=dPt*Q + using TiledSavedK = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + Layout>{}, // Thr layout 1x16 n-major + Layout>{})); // Val layout 8x1 + + static constexpr int SubgroupSize = 16; + static constexpr int smem_size = 0; + + FAKernel() {} +}; + +using index_t = uint64_t; + +template +struct Param { + Param( + const T* dO, + const T* o, + const T* q, + const T* k, + const T* v, + const float* lse, + float* odo, + float* dqaccum, + T* dq, + T* dk, + T* dv, + T* s, + T* dp, + T* pb, + const float softmax_scale) + : do_ptr(dO), + o_ptr(o), + q_ptr(q), + k_ptr(k), + v_ptr(v), + lse_ptr(lse), + odo_ptr(odo), + dqaccum_ptr(dqaccum), + dq_ptr(dq), + dk_ptr(dk), + dv_ptr(dv), + s_ptr(s), + dp_ptr(dp), + pb_ptr(pb), + scale_softmax(softmax_scale), + scale_softmax_log2(softmax_scale * M_LOG2E), + is_bhsd(true) 
{} + // read only + const T* do_ptr; + const T* o_ptr; + const T* q_ptr; + const T* k_ptr; + const T* v_ptr; + const float* lse_ptr; + const float scale_softmax; + const float scale_softmax_log2; + // write + float* odo_ptr; + float* dqaccum_ptr; + T* dq_ptr; + T* dk_ptr; + T* dv_ptr; + T* s_ptr; + T* dp_ptr; + T* pb_ptr; + + // const dimension + int batch; + int num_head_q; + int num_head_kv; + int seq_len_q; + int seq_len_q_pad; + int seq_len_kv; + int seq_len_kv_pad; + int head_dim; + int n_block; + int tail_n; + int m_block; + int tail_m; + int num_qh_per_kvh; + int q_r_stride; + int q_h_stride; + int q_b_stride; + + int k_r_stride; + int k_h_stride; + int k_b_stride; + + int dk_r_stride; + int dk_h_stride; + int dk_b_stride; + + int v_r_stride; + int v_h_stride; + int v_b_stride; + + int dv_r_stride; + int dv_h_stride; + int dv_b_stride; + + int o_r_stride; + int o_h_stride; + int o_b_stride; + + int s_r_stride; + int s_s_stride; + int s_b_stride; + + int dq_r_stride; + int dq_h_stride; + int dq_b_stride; + /* + * input output layout + * true batch, numhead, seqlen, headsize + * false batch, seqlen, numhead, headsize + */ + bool is_bhsd; +}; + +template +struct Boffset { + Boffset(Param& param_) : param(param_) {} + index_t q_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.q_b_stride + h_id * param.q_h_stride + + s_id * param.q_r_stride; + } + index_t k_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.k_b_stride + h_id * param.k_h_stride + + s_id * param.k_r_stride; + } + index_t v_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.v_b_stride + h_id * param.v_h_stride + + s_id * param.v_r_stride; + } + index_t dk_offset( + const index_t b_id, + const index_t h_id, + const index_t s_id) { + return b_id * param.dk_b_stride + h_id * param.dk_h_stride + + s_id * param.dk_r_stride; + } + index_t dv_offset( + const index_t b_id, + const index_t h_id, + const index_t s_id) { + return b_id * param.dv_b_stride + h_id * param.dv_h_stride + + s_id * param.dv_r_stride; + } + index_t ps_offset( + const index_t b_id, + const index_t h_id, + const index_t sq_id, + const index_t sk_id) { + return b_id * param.s_b_stride + h_id * param.s_s_stride + + sq_id * param.s_r_stride + sk_id; + } + index_t lse_offset( + const index_t b_id, + const index_t h_id, + const index_t s_id) { + return b_id * param.seq_len_q * param.num_head_q + h_id * param.seq_len_q + + s_id; + } + + index_t o_offset(const index_t b_id, const index_t h_id, const index_t s_id) { + return b_id * param.o_b_stride + h_id * param.o_h_stride + + s_id * param.o_r_stride; + } + + index_t dq_offset( + const index_t b_id, + const index_t h_id, + const index_t s_id) { + return b_id * param.dq_b_stride + h_id * param.dq_h_stride + + s_id * param.dq_r_stride; + } + Param& param; +}; + +template +void setup_bhsd_stride(Param& param) { + param.q_r_stride = param.head_dim; + param.q_h_stride = param.seq_len_q * param.head_dim; + param.q_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.dq_r_stride = param.head_dim; + // param.dq_h_stride = param.seq_len_q * param.head_dim; + // param.dq_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.k_r_stride = param.head_dim; + param.k_h_stride = param.seq_len_kv * param.head_dim; + param.k_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.dk_r_stride = param.head_dim; + param.dk_h_stride = param.seq_len_kv * 
param.head_dim; + param.dk_b_stride = param.num_head_q * param.seq_len_kv * param.head_dim; + + param.v_r_stride = param.head_dim; + param.v_h_stride = param.seq_len_kv * param.head_dim; + param.v_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.dv_r_stride = param.head_dim; + param.dv_h_stride = param.seq_len_kv * param.head_dim; + param.dv_b_stride = param.num_head_q * param.seq_len_kv * param.head_dim; + + param.o_r_stride = param.head_dim; + param.o_h_stride = param.seq_len_q * param.head_dim; + param.o_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.do_r_stride = param.head_dim; + // param.do_h_stride = param.seq_len_q * param.head_dim; + // param.do_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + param.s_r_stride = param.seq_len_kv_pad; + param.s_s_stride = param.seq_len_q_pad * param.seq_len_kv_pad; + param.s_b_stride = + param.num_head_q * param.seq_len_q_pad * param.seq_len_kv_pad; + + param.dq_r_stride = param.head_dim; + param.dq_h_stride = param.seq_len_q_pad * param.head_dim; + param.dq_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; +} + +template +void setup_bshd_stride(Param& param) { + param.q_r_stride = param.num_head_q * param.head_dim; + param.q_h_stride = param.head_dim; + param.q_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.dq_r_stride = param.head_dim; + // param.dq_h_stride = param.seq_len_q * param.head_dim; + // param.dq_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + param.k_r_stride = param.num_head_kv * param.head_dim; + param.k_h_stride = param.head_dim; + param.k_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.dk_r_stride = param.num_head_q * param.head_dim; + param.dk_h_stride = param.head_dim; + param.dk_b_stride = param.num_head_q * param.seq_len_kv * param.head_dim; + + param.v_r_stride = param.num_head_kv * param.head_dim; + param.v_h_stride = param.head_dim; + param.v_b_stride = param.num_head_kv * param.seq_len_kv * param.head_dim; + + param.dv_r_stride = param.num_head_q * param.head_dim; + param.dv_h_stride = param.head_dim; + param.dv_b_stride = param.num_head_q * param.seq_len_kv * param.head_dim; + + param.o_r_stride = param.num_head_q * param.head_dim; + param.o_h_stride = param.head_dim; + param.o_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + + // param.do_r_stride = param.head_dim; + // param.do_h_stride = param.seq_len_q * param.head_dim; + // param.do_b_stride = param.num_head_q * param.seq_len_q * param.head_dim; + param.s_r_stride = param.seq_len_kv_pad; + param.s_s_stride = param.seq_len_q_pad * param.seq_len_kv_pad; + param.s_b_stride = + param.num_head_q * param.seq_len_q_pad * param.seq_len_kv_pad; + + param.dq_r_stride = param.num_head_q * param.head_dim; + param.dq_h_stride = param.head_dim; + param.dq_b_stride = param.num_head_q * param.seq_len_q_pad * param.head_dim; +} + +} // namespace cute diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_common.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_common.h new file mode 100644 index 0000000000..8e18912507 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_common.h @@ -0,0 +1,50 @@ +#pragma once +#include + +#include + +#include +#include +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << 
"Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU") + +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK( \ + x.sizes() == at::IntArrayRef({__VA_ARGS__}), \ + #x " must have shape (" #__VA_ARGS__ ")") + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cute::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cute::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp new file mode 100644 index 0000000000..6bd76660c7 --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp @@ -0,0 +1,547 @@ +#include "mha_fwd.h" +#include "mha_common.h" + +// batch, numhead_qo,numhead_kv,seqlen_qo,seqlen_kv,headsize_qk,headsize_vo +using ProblemShapeRegular = cute::tuple; + +namespace cute { + +template +class MhaName; + +template +struct FA2Runner { + using StrideQ = typename FMHAPrefillKernel::StrideQ; + using StrideK = typename FMHAPrefillKernel::StrideK; + using StrideV = typename FMHAPrefillKernel::StrideV; + using StrideO = typename FMHAPrefillKernel::StrideO; + + using ElementQ = typename FMHAPrefillKernel::ElementQ; + using ElementK = typename FMHAPrefillKernel::ElementK; + using ElementV = typename FMHAPrefillKernel::ElementV; + using ElementAcc = typename FMHAPrefillKernel::ElementAccumulator; + + using CollectiveEpilogue = typename FMHAPrefillKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename FMHAPrefillKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + + // + // Methods + // + + // Note that the GemmUniversalAdapter currently doesn't support flash + // attention, which is why this secondary `run` function is required to launch + // the kernel. 
+ void run(sycl::queue& queue, typename FMHAPrefillKernel::Params params) { + dim3 const block = FMHAPrefillKernel::get_block_shape(); + dim3 const grid = FMHAPrefillKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAPrefillKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + +// Launch parameters depend on whether SYCL compiler supports work-group scratch +// memory extension +#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + using namespace compat::experimental; + auto event = launch< + cutlass::device_kernel, + MhaName>( + launch_policy{ + sycl_grid, + sycl_block, + local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size< + FMHAPrefillKernel::DispatchPolicy::SubgroupSize>}}, + queue, + params); +#else + compat::experimental::launch_properties launch_props{ + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size< + FMHAPrefillKernel::DispatchPolicy::SubgroupSize>}; + compat::experimental::launch_policy policy{ + sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = compat::experimental::launch< + cutlass::device_kernel, + MhaName>(policy, queue, params); +#endif + } + + void run( + sycl::queue& queue, + ProblemShapeType problem_size, + const cutlass::KernelHardwareInfo& hw_info, + const ElementQ* inputQ, + const ElementK* inputK, + const ElementV* inputV, + ElementOutput* output, + float* logsumexp, + float softmax_scale) { + auto + [batch, + num_heads_q, + num_heads_kv, + seq_len_qo, + seq_len_kv, + head_size_qk, + head_size_vo] = problem_size; + + stride_Q = cutlass::make_cute_packed_stride( + StrideQ{}, + cute::make_shape(seq_len_qo, head_size_qk, batch * num_heads_q)); + stride_K = cutlass::make_cute_packed_stride( + StrideK{}, + cute::make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv)); + stride_V = cutlass::make_cute_packed_stride( + StrideV{}, + cute::make_shape(head_size_vo, seq_len_kv, batch * num_heads_kv)); + stride_O = cutlass::make_cute_packed_stride( + StrideO{}, + cute::make_shape(seq_len_qo, head_size_vo, batch * num_heads_q)); + + typename FMHAPrefillKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {inputQ, stride_Q, inputK, stride_K, inputV, stride_V}, + {softmax_scale}, + {output, stride_O, logsumexp}, + hw_info, + softmax_scale}; + + // Define device-global scratch memory + size_t workspace_size = FMHAPrefillKernel::get_workspace_size(arguments); + at::Tensor workspace_tensor = at::empty( + {static_cast(workspace_size)}, + at::device(at::kXPU).dtype(at::kByte)); + + if (!FMHAPrefillKernel::can_implement(arguments)) { + TORCH_CHECK( + false, + "Invalid Problem Size", + batch, + "x", + num_heads_q, + "x", + seq_len_qo, + "x", + seq_len_kv, + "x", + head_size_qk, + "x", + head_size_vo); + return; + } + + // Initialize the workspace + CUTLASS_CHECK(FMHAPrefillKernel::initialize_workspace( + arguments, workspace_tensor.data_ptr())); + + // Convert host-side arguments to device-side arguments to be passed to the + // kernel + auto params = FMHAPrefillKernel::to_underlying_arguments( + arguments, workspace_tensor.data_ptr()); + + // Launch a SYCL kernel using scratch/shared memory + run(queue, params); + } +}; + +template < + typename T, + typename ProblemShape, + bool IS_CAUSAL, + typename TileShapeQK, + 
typename TileShapePV, + typename TileShapeOutPut, + typename SubgroupLayout, + int PipelineStages> +void run_mha_fwd_( + sycl::queue& queue, + ProblemShape& problem_shape, + const T* query, + const T* key, + const T* value, + T* out, + float* logsumexp, + float scale) { + cutlass::KernelHardwareInfo hw_info; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + + using ElementInputQ = T; + using ElementInputKV = T; + using ElementOutput = T; + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + + using MMAOperation = std::conditional_t< + std::is_same_v, + XE_8x16x16_F32BF16BF16F32_TT, + XE_8x16x16_F32F16F16F32_TT>; + using GmemTiledCopyQ = XE_2D_U16x8x32_LD_N; + using GmemTiledCopyK = + XE_2D_U16x16x16_LD_T; // _T designates a transposed block load operation + using GmemTiledCopyV = XE_2D_U16x16x32_LD_V; + using GmemTiledCopyStore = XE_2D_U16x8x16_ST_N; // Change to output BF16 + + using GEMMDispatchPolicy = + cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using CollectiveEpilogue = + cutlass::flash_attention::collective::FlashPrefillEpilogue< + EpilogueDispatchPolicy, + MMAOperation, + TileShapeOutPut, + SubgroupLayout, + ElementComputeEpilogue, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + GmemTiledCopyStore>; + using CollectiveSoftmaxEpilogue = + cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue< + IS_CAUSAL, + EpilogueDispatchPolicy, + ElementAccumulator>; + + using namespace cutlass::fmha::collective; + + using ProblemShapeType = ProblemShape; + + // Mainloop + using CollectiveMainloop = + cutlass::flash_attention::collective::FlashPrefillMma< + GEMMDispatchPolicy, + ProblemShapeType, + ElementInputQ, + cutlass::gemm::TagToStrideA_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + MMAOperation, + TileShapeQK, + TileShapePV, + SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + IS_CAUSAL>; + using FMHAPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefill< + ProblemShapeType, + CollectiveMainloop, + CollectiveSoftmaxEpilogue, + CollectiveEpilogue, + cutlass::flash_attention::IndividualScheduler>; + + FA2Runner runner; + runner.run( + queue, problem_shape, hw_info, query, key, value, out, logsumexp, scale); +} + +template +void run_mha_fwd_( + sycl::queue& queue, + ProblemShape& problem_shape, + const T* query, + const T* key, + const T* value, + T* out, + float* logsumexp, + float scale) { + const int headdim = get<5>(problem_shape); + +#define run_mha_fwd_specialized( \ + TileShapeQK_, \ + TileShapePV_, \ + TileShapeOutPut_, \ + SubgroupLayout_, \ + PipelineStages_) \ + run_mha_fwd_< \ + T, \ + ProblemShape, \ + IS_CAUSAL, \ + TileShapeQK_, \ + TileShapePV_, \ + TileShapeOutPut_, \ + SubgroupLayout_, \ + PipelineStages_>( \ + queue, problem_shape, query, key, value, out, logsumexp, scale); + + constexpr int PipelineStages = 2; + if (headdim == 64) { + using TileShapeQK = Shape<_128, _64, _64>; + using TileShapePV = Shape<_128, _32, _64>; + using TileShapeOutPut = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + run_mha_fwd_specialized( + TileShapeQK, + TileShapePV, + TileShapeOutPut, + SubgroupLayout, + PipelineStages); + } else if (headdim == 96) { + using TileShapeQK = Shape<_128, _64, _32>; + using 
TileShapePV = Shape<_128, _32, _64>; + using TileShapeOutPut = Shape<_128, _96, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + run_mha_fwd_specialized( + TileShapeQK, + TileShapePV, + TileShapeOutPut, + SubgroupLayout, + PipelineStages); + } else if (headdim == 128) { + using TileShapeQK = Shape<_256, _32, _64>; + using TileShapePV = Shape<_256, _32, _32>; + using TileShapeOutPut = Shape<_256, _128, _32>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + run_mha_fwd_specialized( + TileShapeQK, + TileShapePV, + TileShapeOutPut, + SubgroupLayout, + PipelineStages); + } else if (headdim == 192) { + using TileShapeQK = Shape<_256, _64, _64>; + using TileShapePV = Shape<_256, _32, _64>; + using TileShapeOutPut = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; + run_mha_fwd_specialized( + TileShapeQK, + TileShapePV, + TileShapeOutPut, + SubgroupLayout, + PipelineStages); + } else { + TORCH_CHECK( + false, "FlashAttentionForwardXPU only supports headdim 64, 96, 128, 192"); + } +} + +template +void run_mha_fwd( + sycl::queue& queue, + ProblemShape& problem_shape, + const void* query, + const void* key, + const void* value, + void* out, + void* logsumexp, + bool is_causal, + float scale, + at::ScalarType dtype) { + FP16_SWITCH(dtype == at::kHalf, [&] { + BOOL_SWITCH(is_causal, IS_CAUSAL, [&] { + run_mha_fwd_( + queue, + problem_shape, + static_cast(query), + static_cast(key), + static_cast(value), + static_cast(out), + static_cast(logsumexp), + scale); + }); + }); +} +} // namespace cute + +namespace sycltla { + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor> +flash_attention_forward_sycltla( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const double dropout, + const bool is_causal, + const float scale) { + TORCH_CHECK( + dropout == 0.0, + "FlashAttentionForwardXPU does not support dropout > 0.0 yet"); + + CHECK_DEVICE(query); + CHECK_DEVICE(key); + CHECK_DEVICE(value); + + TORCH_CHECK( + !query.is_nested() && !key.is_nested() && !value.is_nested(), + "FlashAttentionForwardXPU only supports dense inputs"); + + auto dtype = query.scalar_type(); + TORCH_CHECK( + dtype == at::kHalf || dtype == at::kBFloat16, + "FlashAttentionForwardXPU only supports fp16 and bf16 data types"); + TORCH_CHECK( + key.scalar_type() == dtype, + "FlashAttentionForwardXPU: query and key must have the same dtype"); + TORCH_CHECK( + value.scalar_type() == dtype, + "FlashAttentionForwardXPU: query and value must have the same dtype"); + + TORCH_CHECK( + query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "FlashAttentionForwardXPU requires query, key, value to be 4 dimensional"); + + const int batch_size = query.sizes()[0]; + const int numhead_qo = query.sizes()[1]; + const int numhead_kv = key.sizes()[1]; + const int seqlen_qo = query.sizes()[2]; + const int seqlen_kv = key.sizes()[2]; + const int headsize_qk = query.sizes()[3]; + const int headsize_vo = value.sizes()[3]; + + CHECK_SHAPE(query, batch_size, numhead_qo, seqlen_qo, headsize_qk); + CHECK_SHAPE(key, batch_size, numhead_kv, seqlen_kv, headsize_qk); + CHECK_SHAPE(value, batch_size, numhead_kv, seqlen_kv, headsize_vo); + + TORCH_CHECK( + numhead_qo % numhead_kv == 0, + "FlashAttentionForwardXPU: numhead_qo must be divisible by numhead_kv"); + + TORCH_CHECK( + query.stride(-1) == 1, + "FlashAttentionForwardXPU: input tensor must have contiguous last dimension"); + TORCH_CHECK( + key.stride(-1) == 1, + 
"FlashAttentionForwardXPU: input tensor must have contiguous last dimension"); + TORCH_CHECK( + value.stride(-1) == 1, + "FlashAttentionForwardXPU: input tensor must have contiguous last dimension"); + + ATTN_TENSOR_LAYOUT layout = get_attn_tensor_layout(query); + if (layout == ATTN_TENSOR_LAYOUT::UNSUPPORTED) { + TORCH_CHECK( + false, + "FlashAttentionForwardXPU: only support BHSD or BSHD layout, got query with shape ", + query.sizes(), + ", stride ", + query.strides()); + } + layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(key)); + TORCH_CHECK( + ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout, + "FlashAttentionBackwardXPU: query and key must have the same layout, got query with layout ", + to_string(layout), + ", key with layout ", + to_string(get_attn_tensor_layout(key))); + layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(value)); + TORCH_CHECK( + ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout, + "FlashAttentionBackwardXPU: query and value must have the same layout, got query with layout ", + to_string(layout), + ", value with layout ", + to_string(get_attn_tensor_layout(value))); + if (layout == ATTN_TENSOR_LAYOUT::BXD) { + layout = ATTN_TENSOR_LAYOUT::BSHD; + } + TORCH_CHECK( + layout == ATTN_TENSOR_LAYOUT::BSHD, + "FlashAttentionBackwardXPU: currently only support BSHD layout"); + + auto opts = query.options(); + at::Tensor out; + if (layout == ATTN_TENSOR_LAYOUT::BHSD) { + out = at::empty({batch_size, numhead_qo, seqlen_qo, headsize_vo}, opts); + } else if (layout == ATTN_TENSOR_LAYOUT::BSHD) { + out = at::empty({batch_size, seqlen_qo, numhead_qo, headsize_vo}, opts) + .permute({0, 2, 1, 3}); + } else { + TORCH_CHECK( + false, "FlashAttentionForwardXPU: only support BHSD or BSHD layout"); + } + + at::Tensor logsumexp = + at::empty({batch_size, numhead_qo, seqlen_qo}, opts.dtype(at::kFloat)); + + auto sycl_queue = at::xpu::getCurrentXPUStream().queue(); + auto device_architecture = + sycl_queue.get_device() + .get_info< + sycl::ext::oneapi::experimental::info::device::architecture>(); + constexpr auto supported_architectures = + std::array{ + sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc, + sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg, + sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21, + sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31}; + if (std::find( + supported_architectures.begin(), + supported_architectures.end(), + device_architecture) == supported_architectures.end()) { + TORCH_CHECK( + false, + "XPU device architecture does not support flash attention. 
Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31."); + } + + auto problem_shape = ProblemShapeRegular( + batch_size, + numhead_qo, + numhead_kv, + seqlen_qo, + seqlen_kv, + headsize_qk, + headsize_vo); + + cute::run_mha_fwd( + sycl_queue, + problem_shape, + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + out.data_ptr(), + logsumexp.data_ptr(), + is_causal, + scale, + dtype); + + return std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor>{ + out, + logsumexp, + /* cumulative_sequence_length_q */ at::Tensor(), + /* cumulative_sequence_length_k */ at::Tensor(), + /* max_seqlen_batch_q */ c10::SymInt(0), + /* max_seqlen_batch_k */ c10::SymInt(0), + /* philox_seed */ at::empty({}, at::dtype(at::kLong)), + /* philox_offset */ at::empty({}, at::dtype(at::kLong))}; +} + +} // namespace sycltla \ No newline at end of file diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.h new file mode 100644 index 0000000000..175883dd4b --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include +#include +#include +#include + +#include "collective/xe_flash_attn_prefill_mma_bshd.h" +#include "collective/xe_flash_attn_sdpa_fwd_bshd_epilogue.h" +#include "collective/xe_flash_attn_sdpa_fwd_bshd_softmax_epilogue.h" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "kernel/tile_scheduler_sdpa_fwd_bshd.h" +#include "kernel/xe_sdpa_fwd_bshd.h" \ No newline at end of file diff --git a/src/ATen/native/transformers/xpu/flash_attn/utils.h b/src/ATen/native/transformers/xpu/flash_attn/utils.h new file mode 100644 index 0000000000..4a42ffad8f --- /dev/null +++ b/src/ATen/native/transformers/xpu/flash_attn/utils.h @@ -0,0 +1,135 @@ +#pragma once + +#include +#include + +namespace sycltla { + +enum class ATTN_TENSOR_LAYOUT { + BHSD, // batchsize, headnum, seqlen, headdim + BSHD, // batchsize, seqlen, headnum, headdim + BXD, // in case headnum==1 or seqlen==1, which is compatible with BHSD/BSHD + UNSUPPORTED +}; + +inline std::string to_string(ATTN_TENSOR_LAYOUT layout) { + switch (layout) { + case ATTN_TENSOR_LAYOUT::BHSD: + return "BHSD"; + case ATTN_TENSOR_LAYOUT::BSHD: + return "BSHD"; + case ATTN_TENSOR_LAYOUT::BXD: + return "BXD"; + case ATTN_TENSOR_LAYOUT::UNSUPPORTED: + return "UNSUPPORTED"; + default: + return "UNKNOWN"; + } +} + +inline ATTN_TENSOR_LAYOUT get_attn_tensor_layout(const at::Tensor& t) { + // sdpa's tensor shape are in BHSD format + if (t.is_contiguous(at::MemoryFormat::Contiguous)) { + if (t.size(1) == 1 || t.size(2) == 1) { + return ATTN_TENSOR_LAYOUT::BXD; + } + return ATTN_TENSOR_LAYOUT::BHSD; + } else if (t.transpose(1, 2).is_contiguous(at::MemoryFormat::Contiguous)) { + if (t.size(1) == 1 || t.size(2) == 1) { + return ATTN_TENSOR_LAYOUT::BXD; + } + return ATTN_TENSOR_LAYOUT::BSHD; + } else { + return ATTN_TENSOR_LAYOUT::UNSUPPORTED; + } +} + +inline ATTN_TENSOR_LAYOUT fuse_attn_tensor_layout( + ATTN_TENSOR_LAYOUT layout1, + ATTN_TENSOR_LAYOUT layout2) { + if (layout1 == ATTN_TENSOR_LAYOUT::UNSUPPORTED || + layout2 == ATTN_TENSOR_LAYOUT::UNSUPPORTED) { + return ATTN_TENSOR_LAYOUT::UNSUPPORTED; + } + if (layout1 == layout2) { + return layout1; + } + // if one is BXD, return the other one + if (layout1 == ATTN_TENSOR_LAYOUT::BXD) { + return layout2; + } + if (layout2 == ATTN_TENSOR_LAYOUT::BXD) { + 
return layout1; + } + // otherwise, incompatible + return ATTN_TENSOR_LAYOUT::UNSUPPORTED; +} + +inline at::Tensor attn_tensor_to_layout( + const at::Tensor& t, + ATTN_TENSOR_LAYOUT target_layout) { + if (target_layout == ATTN_TENSOR_LAYOUT::UNSUPPORTED || + target_layout == ATTN_TENSOR_LAYOUT::BXD) { + TORCH_CHECK( + false, "FlashAttentionXPU: only support BHSD or BSHD as target layout"); + } + + ATTN_TENSOR_LAYOUT layout = get_attn_tensor_layout(t); + at::Tensor output = t; + if (layout == ATTN_TENSOR_LAYOUT::UNSUPPORTED || + layout == ATTN_TENSOR_LAYOUT::BXD || layout != target_layout) { + if (target_layout == ATTN_TENSOR_LAYOUT::BHSD) { + // convert to BHSD + output = t.contiguous(at::MemoryFormat::Contiguous); + layout = ATTN_TENSOR_LAYOUT::BHSD; + } else { + // convert to BSHD + output = t.permute({0, 2, 1, 3}) + .contiguous(at::MemoryFormat::Contiguous) + .permute({0, 2, 1, 3}); + layout = ATTN_TENSOR_LAYOUT::BSHD; + } + } + + return output; +} + +inline bool check_flash_attention_bshd_layout( + sdp::sdp_params const& params, + bool debug) { + sycltla::ATTN_TENSOR_LAYOUT layout = + sycltla::get_attn_tensor_layout(params.query); + if (layout == sycltla::ATTN_TENSOR_LAYOUT::UNSUPPORTED) { + if (debug) { + TORCH_WARN("FlashAttentionXPU requires query to be in BSHD layout."); + } + return false; + } + layout = fuse_attn_tensor_layout( + layout, sycltla::get_attn_tensor_layout(params.key)); + if (layout == sycltla::ATTN_TENSOR_LAYOUT::UNSUPPORTED) { + if (debug) { + TORCH_WARN("FlashAttentionXPU requires key to be in BSHD layout."); + } + return false; + } + layout = fuse_attn_tensor_layout( + layout, sycltla::get_attn_tensor_layout(params.value)); + if (layout == sycltla::ATTN_TENSOR_LAYOUT::UNSUPPORTED) { + if (debug) { + TORCH_WARN("FlashAttentionXPU requires value to be in BSHD layout."); + } + return false; + } + if (layout != sycltla::ATTN_TENSOR_LAYOUT::BSHD && + layout != sycltla::ATTN_TENSOR_LAYOUT::BXD) { + if (debug) { + TORCH_WARN( + "FlashAttentionXPU requires query, key, and value to be in BSHD layout."); + } + return false; + } + return true; +} + +} // namespace sycltla From 76aec4e5142dc93dd3029b6db4a1a5acdffc80bd Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Wed, 12 Nov 2025 21:52:10 -0800 Subject: [PATCH 2/7] install header --- src/ATen/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/CMakeLists.txt b/src/ATen/CMakeLists.txt index 961a5065b1..959c6162fb 100644 --- a/src/ATen/CMakeLists.txt +++ b/src/ATen/CMakeLists.txt @@ -40,7 +40,7 @@ install_xpu_headers("native/quantized/xpu/sycl") install_xpu_headers("native/sparse/xpu") install_xpu_headers("native/sparse/xpu/sycl") install_xpu_headers("native/transformers/xpu") -install_xpu_headers("native/transformers/xpu/sycl") +install_xpu_headers("native/transformers/xpu/flash_attn") if(xpu_ops_generated_headers) install(FILES ${xpu_ops_generated_headers} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/ops) From d33b6a5785ee211bcacda8674fef292591f226bd Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Thu, 13 Nov 2025 22:34:06 -0800 Subject: [PATCH 3/7] fix build warning --- .../xpu/flash_attn/sycltla/mha_bwd.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp index d25117d846..867e3520a5 100644 --- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp +++ 
b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp @@ -22,7 +22,7 @@ void compute_o_dot_do( const int bidh) { // The thread index. constexpr int kBlockM = T::kBlockM; - constexpr int kBlockN = T::kBlockN; + // constexpr int kBlockN = T::kBlockN; constexpr int kHeadDim = T::kHeadDim; constexpr int kNSGs = T::kNSGs; constexpr int SubgroupSize = T::SubgroupSize; @@ -31,7 +31,8 @@ void compute_o_dot_do( auto sg = compat::get_nd_item<1>().get_sub_group(); auto group = compat::get_nd_item<1>().get_group(); - auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; + // auto first_thread_in_sg_idx = sg.get_group_linear_id() * + // trait.SubgroupSize; auto bofst = Boffset(param); const index_t o_offset = bofst.o_offset(bidb, bidh, m_block * kBlockM); @@ -208,7 +209,7 @@ CUTLASS_DEVICE void apply_mask_causal( auto sg = compat::get_nd_item<1>().get_sub_group(); auto group = compat::get_nd_item<1>().get_group(); int sg_local_id = sg.get_local_id(); - int sg_group_id = sg.get_group_id(); + // int sg_group_id = sg.get_group_id(); Tensor rC_2d = make_tensor(rC.data(), convert_layout_2d_layout(rC.layout())); CUTLASS_PRAGMA_UNROLL for (int n = 0; n < size<1>(tensor); ++n) { @@ -370,8 +371,8 @@ void dq_dk_dv_1colblock( constexpr int kBlockM = Trait::kBlockM; constexpr int kBlockN = Trait::kBlockN; constexpr bool is_causal = Trait::is_causal; - constexpr int kNSGs = Trait::kNSGs; - constexpr int SubgroupSize = Trait::SubgroupSize; + // constexpr int kNSGs = Trait::kNSGs; + // constexpr int SubgroupSize = Trait::SubgroupSize; auto sg = compat::get_nd_item<1>().get_sub_group(); auto group = compat::get_nd_item<1>().get_group(); auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize; @@ -674,7 +675,8 @@ void dq_dk_dv_1colblock( const int max_m_block = ceil_div(param.seq_len_q, kBlockM); const int tail_m = param.seq_len_q % kBlockM; - cutlass::NumericConverter converter; + // cutlass::NumericConverter converter; + // clear accumulator clear(tdVrdV); clear(tdKrdK); @@ -878,7 +880,7 @@ void convert_dq( int bidb, int bidh) { constexpr int kBlockM = T::kBlockM; - constexpr int kBlockN = T::kBlockN; + // constexpr int kBlockN = T::kBlockN; constexpr int kHeadDim = T::kHeadDim; using DType = typename T::DType; using VType = typename T::VType; From 54a5ca54fe502f0473a0c87d459c59747f2c2d26 Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Thu, 13 Nov 2025 22:35:30 -0800 Subject: [PATCH 4/7] rebase forwardkernel --- .../sycltla/kernel/xe_sdpa_fwd_bshd.h | 652 ++++++------------ 1 file changed, 210 insertions(+), 442 deletions(-) diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h index f2f14eaffe..db08bd29e8 100644 --- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h @@ -286,472 +286,240 @@ class FMHAPrefill { // speculative decoding. 
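The bottom-right-aligned causal mask referred to in this comment can be pictured with a small standalone sketch (not part of this patch). It only mirrors the in-kernel check `col_idx > row_idx - first_non_masked_sequence || row_idx < first_non_masked_sequence`; all sizes are illustrative. With seq_len_qo = 6 and seq_len_kv = 4, query rows 0 and 1 attend to nothing, row 2 sees key 0, row 3 sees keys 0 and 1, and row 5 sees keys 0 through 3.

// Standalone illustration of the bottom-right-aligned causal visibility rule.
#include <cstdio>

bool is_visible(int row_idx, int col_idx, int seq_len_qo, int seq_len_kv) {
  // Positive when there are more query rows than key rows (e.g. speculative decoding).
  const int first_non_masked_sequence = seq_len_qo - seq_len_kv;
  return row_idx >= first_non_masked_sequence &&
         col_idx <= row_idx - first_non_masked_sequence;
}

int main() {
  // Prints a 6x4 band aligned to the bottom-right corner:
  // ....
  // ....
  // 1...
  // 11..
  // 111.
  // 1111
  for (int row = 0; row < 6; ++row) {
    for (int col = 0; col < 4; ++col) {
      std::printf("%c", is_visible(row, col, 6, 4) ? '1' : '.');
    }
    std::printf("\n");
  }
  return 0;
}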
In that case, the `is_causal` masking behavior // will be changed and we need to adjust the main loop to perform // appropriate calculations - if (seq_len_qo > seq_len_kv && CausalMask) { - int first_non_masked_sequence = seq_len_qo - seq_len_kv; - - int seq_coord = cute::min( - seq_len_qo, - (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % - seq_len_qo); - - // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) - // and check if it is still within bounds of the actual seq_len_qo - // (get<0>(sequence_length_shape)). - if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { - continue; - } - - // calculate the last seq_len_qo of this subblock - int last_seq_coord = seq_coord + QK_SG_M - 1; // 5 + int first_non_masked_sequence = seq_len_qo - seq_len_kv; + + int seq_coord = cute::min( + seq_len_qo, + (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % + seq_len_qo); + + // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) and + // check if it is still within bounds of the actual seq_len_qo + // (get<0>(sequence_length_shape)). + if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { + continue; + } - if (last_seq_coord < - first_non_masked_sequence) { // no need to perform calculation as - // those sections are masked - continue; - } + // calculate the last seq_len_qo of this subgroup + int last_seq_coord = seq_coord + QK_SG_M - 1; - // The main idea is to calculate the longest non-masked elements for - // this subgroup It is calculated by leveraging the property of bottom - // right mask - - // Calculate the longest length of the non-masked sequences for this - // subblock. The sequence is always the last one of subblock. - int longest_non_masked_length = cute::min( - seq_len_kv, - cute::max(0, last_seq_coord - first_non_masked_sequence + 1)); - - const int seq_len = cute::min(seq_len_kv, longest_non_masked_length); - - const int nblock_limit = cute::ceil_div(seq_len, QK_BLK_N); - - Tensor mQ_mkl = cute::get_xe_tensor( - make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) - Tensor mK_nkl = cute::get_xe_tensor( - make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l) - Tensor mV_nkl = cute::get_xe_tensor( - make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l) - Tensor mQ_mk = mQ_mkl(_, _, 0); - Tensor mK_nk = mK_nkl(_, _, 0); // (n,k) - Tensor mV_nk = mV_nkl(_, _, 0); - - auto gQ = local_tile( - mQ_mk, - TileShapeQK{}, - make_coord(blk_m_coord, _, _), - Step<_1, X, _1>{}); - auto gK = local_tile( - mK_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); - auto gV = local_tile( - mV_nk, - TileShapeOutput{}, - make_coord(_, blk_n_coord, _), - Step{}); - - auto mainloop_params = CollectiveMainloop::get_updated_copies( - params.mainloop, - params.problem_shape, - sequence_length_shape, - batch_coord, - q_head_coord); - // we limit the horisontal size to two subgroup, the empirical resutls - // show that reading the two cacheline side by side in gives better - // performance and anything after that does not have an effect on - // performance. 
// (64 here for float b float when possible and loop - // over to cover all the data needed) - auto tiled_prefetch_q = cute::prefetch_selector< - Shape, Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_q); - auto tiled_prefetch_k = cute::prefetch_selector< - Shape, Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_k); - auto tiled_prefetch_v = cute::prefetch_selector< - Shape< - Int, - Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_v); - auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); - auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); - auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); - auto pQgQ = thr_prefetch_Q.partition_S(gQ); - auto pKgK = thr_prefetch_K.partition_S(gK); - auto pVgV = thr_prefetch_V.partition_S(gV); - - for (int i = 0; i < size<3>(pQgQ); i++) { - prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); - } - for (int j = 0; j < size<4>(pKgK); j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < DispatchPolicy::Stages; i++) { - prefetch(tiled_prefetch_k, pKgK(_, _, _, i, j)); - } - } - - // Allocate the tiled_mma and the accumulators for the (M,N) - // workgroup_shape - Tensor out_reg = make_tensor(AccumeShape{}); - - // There are 16 workitem and 16 max per subgroup, each worktime containt - // 1 max and cumulatively, they calculate the max per subgroup - ElementAccumulator max_reg{-INFINITY}; - - // The sum reg each contains a 2d tesnor for 8 x 2 This is number of - // sequence lenght process per subgroup - Tensor sum_reg = - make_tensor(Shape, Int>{}); - - clear(sum_reg); - clear(out_reg); - - // Perform the collective scoped MMA - CollectiveMainloop collective_mma; - // when causal mask is true. It is not possible to set the scope - // of the barrier to workgroup level as the number n block is - // different for each subgroup due to triangular nature of causal based - // operation - static constexpr int barrier_scope = CausalMask ? 3 : 2; - // MAIN LOOP: loop over K and V, perform fused attention + online - // softmax - for (int nblock = 0; - nblock < nblock_limit - static_cast(CausalMask); - nblock++) { - barrier_arrive(barrier_scope); - // 1) Load K (performed inside mmaQK) - // 2) Create Tensor S - Tensor tSr = make_tensor( - Shape, Int, Int>{}); - clear(tSr); - - // 3) Perform GEMM S = Q*K - collective_mma.mmaQK( - tSr, - gQ, - gK(_, _, nblock, _), - tSr, - ceil_div(head_size_qk, QK_BLK_K), - mainloop_params); - - // we only need one block ahead, there is enough gap to prefetch it - // while doing softmax. because the gap between the two MMA is big, - // prefetching it the same way as cutlass K matrix does not make sense - for (int i = 0; i < size<1>(pVgV); i++) { - prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock)); - } + if (CausalMask && + last_seq_coord < + first_non_masked_sequence) { // no need to perform calculation as + // the whole subblock is masked + continue; + } - CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax(nblock == 0, tSr, max_reg, sum_reg, out_reg); + // The main idea is to calculate the longest non-masked elements for this + // subgroup It is calculated by leveraging the property of bottom right + // mask - collective_mma.template mmaPV( - out_reg, tSr, gV(_, _, nblock), out_reg, mainloop_params); + // Calculate the longest length of the non-masked sequences for this + // subgroup. The sequence is always the last sequence in that subblock. 
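The work-trimming arithmetic computed just below can be checked with a small host-side example (not part of this patch; the numbers and the QK_BLK_N = 64 tile width are assumptions for illustration): a subgroup whose last query row lies only 128 rows past the first non-masked row needs just 2 of the 16 K/V blocks it would otherwise visit.

// Illustrative check of the trimming formula used below; all values are hypothetical.
#include <algorithm>

constexpr int ceil_div_ex(int a, int b) { return (a + b - 1) / b; }

constexpr int seq_len_qo_ex = 4096;
constexpr int seq_len_kv_ex = 1024;
constexpr int first_non_masked_sequence_ex = seq_len_qo_ex - seq_len_kv_ex; // 3072
constexpr int last_seq_coord_ex = 3199; // last query row handled by this subgroup
constexpr int longest_non_masked_length_ex = std::min(
    seq_len_kv_ex, std::max(0, last_seq_coord_ex - first_non_masked_sequence_ex + 1));
static_assert(longest_non_masked_length_ex == 128);
static_assert(ceil_div_ex(longest_non_masked_length_ex, /*QK_BLK_N=*/64) == 2);
static_assert(ceil_div_ex(seq_len_kv_ex, 64) == 16); // blocks needed without the trim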
+ int longest_non_masked_length = cute::min( + seq_len_kv, + cute::max(0, last_seq_coord - first_non_masked_sequence + 1)); + int seq_len = seq_len_kv; - // Prefetch the next K tile - // there is no need to gaurd it with if statememt as prefetch will - // ignore out of bound reading - for (int j = 0; j < size<4>(pKgK); j++) { - prefetch( - tiled_prefetch_k, - pKgK(_, _, _, nblock + DispatchPolicy::Stages, j)); - } - barrier_wait(barrier_scope); - } - - if constexpr (CausalMask) { - // BAND Matrix - // 1) Load K (performed inside mmaQK) - // 2) Create Tensor S - Tensor tSr = make_tensor( - Shape, Int, Int>{}); - clear(tSr); - // 3) Perform GEMM S = Q*K - collective_mma.mmaQK( - tSr, - gQ, - gK(_, _, nblock_limit - 1, _), - tSr, - ceil_div(head_size_qk, QK_BLK_K), - mainloop_params); - // we only need one block ahead, there is enough gap to prefetch it - // while doing softmax. because the gap between the two MMA is big, - // prefetching it the same way as cutlass K matrix does not make sense - for (int i = 0; i < size<1>(pVgV); i++) { - prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_limit - 1)); - } - // mask the elements of each tile where j > i - const int item_id = thread_idx % SubgroupSize; - int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N; - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < FragsN; - n++, col_idx += get<1>(MmaAtomShape())) { // 4 - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < FragsM; m++) { // 2 - int row_idx = m * Vec + seq_coord; - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < Vec; row++, row_idx++) { // 8 - if (col_idx > row_idx - first_non_masked_sequence || - row_idx < first_non_masked_sequence) { - tSr(row, m, n) = ElementAccumulator{-INFINITY}; - } - } - } - } - - CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax((nblock_limit - 1) == 0, tSr, max_reg, sum_reg, out_reg); + if (CausalMask) { + seq_len = cute::min(seq_len_kv, longest_non_masked_length); + } - collective_mma.template mmaPV( - out_reg, - tSr, - gV(_, _, nblock_limit - 1), - out_reg, - mainloop_params); + int nblock_limit = cute::ceil_div(seq_len, QK_BLK_N); + + Tensor mQ_mkl = cute::get_xe_tensor( + make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + Tensor mK_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l) + Tensor mV_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l) + Tensor mQ_mk = mQ_mkl(_, _, 0); + Tensor mK_nk = mK_nkl(_, _, 0); // (n,k) + Tensor mV_nk = mV_nkl(_, _, 0); + + auto gQ = local_tile( + mQ_mk, + TileShapeQK{}, + make_coord(blk_m_coord, _, _), + Step<_1, X, _1>{}); + auto gK = local_tile( + mK_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gV = local_tile( + mV_nk, + TileShapeOutput{}, + make_coord(_, blk_n_coord, _), + Step{}); + + auto mainloop_params = CollectiveMainloop::get_updated_copies( + params.mainloop, + params.problem_shape, + sequence_length_shape, + batch_coord, + q_head_coord); + // we limit the horisontal size to two subgroup, the empirical resutls + // show that reading the two cacheline side by side in gives better + // performance and anything after that does not have an effect on + // performance. 
// (64 here for float b float when possible and loop over + // to cover all the data needed) + auto tiled_prefetch_q = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_q); + auto tiled_prefetch_k = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_k); + auto tiled_prefetch_v = cute::prefetch_selector< + Shape< + Int, + Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_v); + auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); + auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); + auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); + auto pQgQ = thr_prefetch_Q.partition_S(gQ); + auto pKgK = thr_prefetch_K.partition_S(gK); + auto pVgV = thr_prefetch_V.partition_S(gV); + + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + for (int j = 0; j < size<4>(pKgK); j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < DispatchPolicy::Stages; i++) { + prefetch(tiled_prefetch_k, pKgK(_, _, _, i, j)); } - - auto epilogue_params = - CollectiveEpilogue::template get_updated_copies( - params.epilogue, - params.problem_shape, - sequence_length_shape, - batch_coord, - q_head_coord); - CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; - auto blk_coord_mnkl = - make_coord(blk_m_coord, blk_n_coord, batch_coord, 0); - epilogue( - params.problem_shape, - sequence_length_shape, - blk_coord_mnkl, - out_reg, - max_reg, - sum_reg, - q_head_coord, - softmax_scale); } - // seq_len_kv == seq_len_qo - else { - const int seq_coord = cute::min( - seq_len_qo, - (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % - seq_len_qo); - - // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) - // and check if it is still within bounds of the actual seq_len_qo - // (get<0>(sequence_length_shape)). - if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { - continue; + + // Allocate the tiled_mma and the accumulators for the (M,N) + // workgroup_shape + Tensor out_reg = make_tensor(AccumeShape{}); + + // There are 16 workitem and 16 max per subgroup, each worktime containt 1 + // max and cumulatively, they calculate the max per subgroup + ElementAccumulator max_reg{-INFINITY}; + + // The sum reg each contains a 2d tesnor for 8 x 2 This is number of + // sequence lenght process per subgroup + Tensor sum_reg = + make_tensor(Shape, Int>{}); + + clear(sum_reg); + clear(out_reg); + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + // when causal mask is true. It is not possible to set the scope + // of the barrier to workgroup level as the number n block is + // different for each subgroup due to triangular nature of causal based + // operation + static constexpr int barrier_scope = CausalMask ? 3 : 2; + // MAIN LOOP: loop over K and V, perform fused attention + online softmax + for (int nblock = 0; nblock < nblock_limit - static_cast(CausalMask); + nblock++) { + barrier_arrive(barrier_scope); + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK( + tSr, + gQ, + gK(_, _, nblock, _), + tSr, + ceil_div(head_size_qk, QK_BLK_K), + mainloop_params); + + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. 
because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock)); } - auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024) - auto discard_seq_coord = seq_len_qo - offset; // 1024 - auto full_tile_offset = seq_len_kv - offset; // 0 + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax(nblock == 0, tSr, max_reg, sum_reg, out_reg); - const int seq_len = CausalMask ? full_tile_offset + - cute::min(seq_len_kv, seq_coord - discard_seq_coord) + QK_SG_M - : seq_len_kv; - const int nblock_limit = cute::ceil_div(seq_len, QK_BLK_N); - if (CausalMask && seq_coord < discard_seq_coord) { // 1024 =0 - continue; - } + collective_mma.template mmaPV( + out_reg, tSr, gV(_, _, nblock), out_reg, mainloop_params); - Tensor mQ_mkl = cute::get_xe_tensor( - make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) - Tensor mK_nkl = cute::get_xe_tensor( - make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l) - Tensor mV_nkl = cute::get_xe_tensor( - make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l) - Tensor mQ_mk = mQ_mkl(_, _, 0); - Tensor mK_nk = mK_nkl(_, _, 0); // (n,k) - Tensor mV_nk = mV_nkl(_, _, 0); - - auto gQ = local_tile( - mQ_mk, - TileShapeQK{}, - make_coord(blk_m_coord, _, _), - Step<_1, X, _1>{}); - auto gK = local_tile( - mK_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); - auto gV = local_tile( - mV_nk, - TileShapeOutput{}, - make_coord(_, blk_n_coord, _), - Step{}); - - auto mainloop_params = CollectiveMainloop::get_updated_copies( - params.mainloop, - params.problem_shape, - sequence_length_shape, - batch_coord, - q_head_coord); - // we limit the horisontal size to two subgroup, the empirical resutls - // show that reading the two cacheline side by side in gives better - // performance and anything after that does not have an effect on - // performance. 
// (64 here for float b float when possible and loop - // over to cover all the data needed) - auto tiled_prefetch_q = cute::prefetch_selector< - Shape, Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_q); - auto tiled_prefetch_k = cute::prefetch_selector< - Shape, Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_k); - auto tiled_prefetch_v = cute::prefetch_selector< - Shape< - Int, - Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_v); - auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); - auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); - auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); - auto pQgQ = thr_prefetch_Q.partition_S(gQ); - auto pKgK = thr_prefetch_K.partition_S(gK); - auto pVgV = thr_prefetch_V.partition_S(gV); - - for (int i = 0; i < size<3>(pQgQ); i++) { - prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); - } + // Prefetch the next K tile + // there is no need to gaurd it with if statememt as prefetch will + // ignore out of bound reading for (int j = 0; j < size<4>(pKgK); j++) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < DispatchPolicy::Stages; i++) { - prefetch(tiled_prefetch_k, pKgK(_, _, _, i, j)); - } + prefetch( + tiled_prefetch_k, + pKgK(_, _, _, nblock + DispatchPolicy::Stages, j)); } + barrier_wait(barrier_scope); + } - // Allocate the tiled_mma and the accumulators for the (M,N) - // workgroup_shape - Tensor out_reg = make_tensor(AccumeShape{}); - - // There are 16 workitem and 16 max per subgroup, each worktime containt - // 1 max and cumulatively, they calculate the max per subgroup - ElementAccumulator max_reg{-INFINITY}; - - // The sum reg each contains a 2d tesnor for 8 x 2 This is number of - // sequence lenght process per subgroup - Tensor sum_reg = - make_tensor(Shape, Int>{}); - - clear(sum_reg); - clear(out_reg); - - // Perform the collective scoped MMA - CollectiveMainloop collective_mma; - // when causal mask is true. It is not possible to set the scope - // of the barrier to workgroup level as the number n block is - // different for each subgroup due to triangular nature of causal based - // operation - static constexpr int barrier_scope = CausalMask ? 3 : 2; - // MAIN LOOP: loop over K and V, perform fused attention + online - // softmax - for (int nblock = 0; - nblock < nblock_limit - static_cast(CausalMask); - nblock++) { - barrier_arrive(barrier_scope); - // 1) Load K (performed inside mmaQK) - // 2) Create Tensor S - Tensor tSr = make_tensor( - Shape, Int, Int>{}); - clear(tSr); - - // 3) Perform GEMM S = Q*K - collective_mma.mmaQK( - tSr, - gQ, - gK(_, _, nblock, _), - tSr, - ceil_div(head_size_qk, QK_BLK_K), - mainloop_params); - - // we only need one block ahead, there is enough gap to prefetch it - // while doing softmax. 
because the gap between the two MMA is big, - // prefetching it the same way as cutlass K matrix does not make sense - for (int i = 0; i < size<1>(pVgV); i++) { - prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock)); - } - - CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax(nblock == 0, tSr, max_reg, sum_reg, out_reg); - - collective_mma.template mmaPV( - out_reg, tSr, gV(_, _, nblock), out_reg, mainloop_params); - - // Prefetch the next K tile - // there is no need to gaurd it with if statememt as prefetch will - // ignore out of bound reading - for (int j = 0; j < size<4>(pKgK); j++) { - prefetch( - tiled_prefetch_k, - pKgK(_, _, _, nblock + DispatchPolicy::Stages, j)); - } - barrier_wait(barrier_scope); + if constexpr (CausalMask) { + // BAND Matrix + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK( + tSr, + gQ, + gK(_, _, nblock_limit - 1, _), + tSr, + ceil_div(head_size_qk, QK_BLK_K), + mainloop_params); + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_limit - 1)); } - - if constexpr (CausalMask) { - // BAND Matrix - // 1) Load K (performed inside mmaQK) - // 2) Create Tensor S - Tensor tSr = make_tensor( - Shape, Int, Int>{}); - clear(tSr); - // 3) Perform GEMM S = Q*K - collective_mma.mmaQK( - tSr, - gQ, - gK(_, _, nblock_limit - 1, _), - tSr, - ceil_div(head_size_qk, QK_BLK_K), - mainloop_params); - // we only need one block ahead, there is enough gap to prefetch it - // while doing softmax. 
because the gap between the two MMA is big, - // prefetching it the same way as cutlass K matrix does not make sense - for (int i = 0; i < size<1>(pVgV); i++) { - prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_limit - 1)); - } - // mask the elements of each tile where j > i - const int item_id = thread_idx % SubgroupSize; - int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N; + // mask the elements of each tile using the bottom right masking + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < FragsN; - n++, col_idx += get<1>(MmaAtomShape())) { // 4 + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < FragsM; m++) { // 2 - int row_idx = m * Vec + seq_coord; - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < Vec; row++, row_idx++) { // 8 - if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { - tSr(row, m, n) = ElementAccumulator{-INFINITY}; - } + for (int row = 0; row < Vec; row++, row_idx++) { // 8 + if (row_idx < first_non_masked_sequence || + col_idx > row_idx - first_non_masked_sequence) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; } } } - - CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax((nblock_limit - 1) == 0, tSr, max_reg, sum_reg, out_reg); - - collective_mma.template mmaPV( - out_reg, - tSr, - gV(_, _, nblock_limit - 1), - out_reg, - mainloop_params); } - auto epilogue_params = - CollectiveEpilogue::template get_updated_copies( - params.epilogue, - params.problem_shape, - sequence_length_shape, - batch_coord, - q_head_coord); - CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; - auto blk_coord_mnkl = - make_coord(blk_m_coord, blk_n_coord, batch_coord, 0); - epilogue( - params.problem_shape, - sequence_length_shape, - blk_coord_mnkl, - // out_reg, max_reg, sum_reg, q_head_coord, 0.125); - out_reg, - max_reg, - sum_reg, - q_head_coord, - softmax_scale); + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax((nblock_limit - 1) == 0, tSr, max_reg, sum_reg, out_reg); + + collective_mma.template mmaPV( + out_reg, tSr, gV(_, _, nblock_limit - 1), out_reg, mainloop_params); } + auto epilogue_params = + CollectiveEpilogue::template get_updated_copies( + params.epilogue, + params.problem_shape, + sequence_length_shape, + batch_coord, + q_head_coord); + CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; + auto blk_coord_mnkl = + make_coord(blk_m_coord, blk_n_coord, batch_coord, 0); + epilogue( + params.problem_shape, + sequence_length_shape, + blk_coord_mnkl, + out_reg, + max_reg, + sum_reg, + q_head_coord, + softmax_scale); } } }; From 991ee97c46dc9787a2d716cd77e965e05edbc3ee Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Thu, 13 Nov 2025 23:56:30 -0800 Subject: [PATCH 5/7] fix CI build error --- cmake/SYCLTLA.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmake/SYCLTLA.cmake b/cmake/SYCLTLA.cmake index 3d44d82923..d70073bdc9 100644 --- a/cmake/SYCLTLA.cmake +++ b/cmake/SYCLTLA.cmake @@ -3,6 +3,8 @@ macro(replace_cmake_build_flags) set(CMAKE_CXX_FLAGS_BK "${CMAKE_CXX_FLAGS}") string(REPLACE "-Werror=format" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") string(REPLACE "-Werror=format" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + string(REPLACE "-Werror=unused-variable" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") + string(REPLACE 
"-Werror=unused-variable" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") endmacro() macro(restore_cmake_build_flags) From 89c6a49e712a5d79b365b950bd38ad0d8cd1aa5a Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Tue, 18 Nov 2025 00:27:20 -0800 Subject: [PATCH 6/7] rebase to latest --- .../sycltla/kernel/xe_sdpa_fwd_bshd.h | 34 +++++++++++++++++-- .../xpu/flash_attn/sycltla/mha_bwd.cpp | 16 ++++----- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h index db08bd29e8..a1f86f9af1 100644 --- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h @@ -164,6 +164,7 @@ class FMHAPrefill { Arguments const& args, void* workspace) { (void)workspace; + return { args.mode, args.problem_shape, @@ -438,6 +439,29 @@ class FMHAPrefill { prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock)); } + // Prevnt numerical errors when seq_len_kv is not fully divisible by + // QK_BLK_N + const int item_id = thread_idx % SubgroupSize; + if (seq_len_kv % QK_BLK_N != 0) { + int col_idx = item_id + nblock * QK_BLK_N; + int remainder = seq_len_kv % QK_BLK_N; + int cutoff = (seq_len_kv / QK_BLK_N) * QK_BLK_N + remainder; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++, row_idx++) { + if (col_idx >= cutoff) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + CollectiveSoftmaxEpilogue softmax(params.softmax); softmax(nblock == 0, tSr, max_reg, sum_reg, out_reg); @@ -479,6 +503,8 @@ class FMHAPrefill { // mask the elements of each tile using the bottom right masking const int item_id = thread_idx % SubgroupSize; int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N; + int remainder = seq_len_kv % QK_BLK_N; + int cutoff = (seq_len_kv / QK_BLK_N) * QK_BLK_N + remainder; CUTLASS_PRAGMA_UNROLL for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 @@ -487,8 +513,12 @@ class FMHAPrefill { int row_idx = m * Vec + seq_coord; CUTLASS_PRAGMA_UNROLL for (int row = 0; row < Vec; row++, row_idx++) { // 8 - if (row_idx < first_non_masked_sequence || - col_idx > row_idx - first_non_masked_sequence) { + if (row_idx < first_non_masked_sequence || // for the sequence + // that is fully masked + col_idx > row_idx - + first_non_masked_sequence || // for the bottom right + // triangular masking + col_idx >= cutoff) { // for seq_len_kv not fully divisible tSr(row, m, n) = ElementAccumulator{-INFINITY}; } } diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp index 867e3520a5..c8fe118466 100644 --- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp +++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp @@ -22,7 +22,7 @@ void compute_o_dot_do( const int bidh) { // The thread index. 
diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
index 867e3520a5..c8fe118466 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
@@ -22,7 +22,7 @@ void compute_o_dot_do(
     const int bidh) {
   // The thread index.
   constexpr int kBlockM = T::kBlockM;
-  // constexpr int kBlockN = T::kBlockN;
+  constexpr int kBlockN = T::kBlockN;
   constexpr int kHeadDim = T::kHeadDim;
   constexpr int kNSGs = T::kNSGs;
   constexpr int SubgroupSize = T::SubgroupSize;
@@ -31,8 +31,8 @@
   auto sg = compat::get_nd_item<1>().get_sub_group();
   auto group = compat::get_nd_item<1>().get_group();
-  // auto first_thread_in_sg_idx = sg.get_group_linear_id() *
-  // trait.SubgroupSize;
+  auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize;
+
   auto bofst = Boffset(param);
   const index_t o_offset = bofst.o_offset(bidb, bidh, m_block * kBlockM);
@@ -209,7 +209,7 @@ CUTLASS_DEVICE void apply_mask_causal(
   auto sg = compat::get_nd_item<1>().get_sub_group();
   auto group = compat::get_nd_item<1>().get_group();
   int sg_local_id = sg.get_local_id();
-  // int sg_group_id = sg.get_group_id();
+  int sg_group_id = sg.get_group_id();
   Tensor rC_2d = make_tensor(rC.data(), convert_layout_2d_layout(rC.layout()));
   CUTLASS_PRAGMA_UNROLL
   for (int n = 0; n < size<1>(tensor); ++n) {
@@ -371,8 +371,8 @@ void dq_dk_dv_1colblock(
   constexpr int kBlockM = Trait::kBlockM;
   constexpr int kBlockN = Trait::kBlockN;
   constexpr bool is_causal = Trait::is_causal;
-  // constexpr int kNSGs = Trait::kNSGs;
-  // constexpr int SubgroupSize = Trait::SubgroupSize;
+  constexpr int kNSGs = Trait::kNSGs;
+  constexpr int SubgroupSize = Trait::SubgroupSize;
   auto sg = compat::get_nd_item<1>().get_sub_group();
   auto group = compat::get_nd_item<1>().get_group();
   auto first_thread_in_sg_idx = sg.get_group_linear_id() * trait.SubgroupSize;
@@ -675,7 +675,7 @@ void dq_dk_dv_1colblock(
   const int max_m_block = ceil_div(param.seq_len_q, kBlockM);
   const int tail_m = param.seq_len_q % kBlockM;
-  // cutlass::NumericConverter converter;
+  cutlass::NumericConverter converter;
   // clear accumulator
   clear(tdVrdV);
@@ -880,7 +880,7 @@ void convert_dq(
     int bidb,
     int bidh) {
   constexpr int kBlockM = T::kBlockM;
-  // constexpr int kBlockN = T::kBlockN;
+  constexpr int kBlockN = T::kBlockN;
   constexpr int kHeadDim = T::kHeadDim;
   using DType = typename T::DType;
   using VType = typename T::VType;
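The next patch addresses two related constraints: the current forward kernel only accepts BSHD-contiguous inputs, and padding the head dimension up to a multiple of 64 leaves tensors BHSD-contiguous, so they are converted back before the kernel launch. The padding step itself is not shown in the hunks below; a minimal sketch of the usual pad-then-narrow pattern, where pad_headdim_to_multiple is a hypothetical helper rather than code from this series:

#include <ATen/ATen.h>

// Zero-pads only the innermost (head) dimension up to the next multiple.
static at::Tensor pad_headdim_to_multiple(const at::Tensor& t, int64_t multiple) {
  const int64_t head_dim = t.size(-1);
  const int64_t padded = ((head_dim + multiple - 1) / multiple) * multiple;
  if (padded == head_dim) {
    return t;
  }
  // {left, right} padding amounts for the last dimension only.
  return at::constant_pad_nd(t, {0, padded - head_dim}, 0);
}

// Sketch of the call site: run the kernel on padded q/k/v, then narrow the
// padded output back to the original head dimension.
//   auto q_p = pad_headdim_to_multiple(query, 64);
//   auto k_p = pad_headdim_to_multiple(key, 64);
//   auto v_p = pad_headdim_to_multiple(value, 64);
//   ... launch the kernel on q_p / k_p / v_p ...
//   out = out_padded.narrow(-1, 0, query.size(-1));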
From b61325e02122d5c030bfaf3935e271cb70d445c7 Mon Sep 17 00:00:00 2001
From: "fengqing.lu"
Date: Thu, 20 Nov 2025 02:29:39 -0800
Subject: [PATCH 7/7] pad input tensors if headdim is not multiple of 64

---
 .../xpu/flash_attn/sycltla/mha_bwd.cpp        | 16 ++++-------
 .../xpu/flash_attn/sycltla/mha_fwd.cpp        | 28 +++++++++++++++----
 .../transformers/xpu/flash_attn/utils.h       |  2 +-
 3 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
index c8fe118466..c818f86bb2 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
@@ -1406,19 +1406,15 @@ std::tuple flash_attention_backward_sycltla(
       to_string(layout),
       ", value with layout ",
       to_string(get_attn_tensor_layout(value)));
-  layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(out));
-  TORCH_CHECK(
-      ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout,
-      "FlashAttentionBackwardXPU: query and out must have the same layout, got query with layout ",
-      to_string(layout),
-      ", out with layout ",
-      to_string(get_attn_tensor_layout(out)));
   if (layout == ATTN_TENSOR_LAYOUT::BXD) {
     layout = ATTN_TENSOR_LAYOUT::BHSD;
   }
   TORCH_CHECK(logsumexp.is_contiguous(), "logsumexp must have BHS layout");
   // grad_out is created by autograd, may not have standard layout
-  auto contiguous_grad_out = attn_tensor_to_layout(grad_out, layout);
+  auto grad_out_ = attn_tensor_to_layout(grad_out, layout);
+  // TODO: This code block is a temporary WA. Remove it once fwd supports BHSD
+  // layouts.
+  auto out_ = attn_tensor_to_layout(out, layout);
 
   auto sycl_queue = at::xpu::getCurrentXPUStream().queue();
   auto device_architecture =
@@ -1493,8 +1489,8 @@
   cute::run_mha_bwd(
       sycl_queue,
       problem_shape,
-      contiguous_grad_out.data_ptr(),
-      out.data_ptr(),
+      grad_out_.data_ptr(),
+      out_.data_ptr(),
       query.data_ptr(),
       key.data_ptr(),
       value.data_ptr(),
diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
index 6bd76660c7..4ff7294ede 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
@@ -451,23 +451,39 @@ flash_attention_forward_sycltla(
   layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(key));
   TORCH_CHECK(
       ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout,
-      "FlashAttentionBackwardXPU: query and key must have the same layout, got query with layout ",
+      "FlashAttentionForwardXPU: query and key must have the same layout, got query with layout ",
       to_string(layout),
       ", key with layout ",
       to_string(get_attn_tensor_layout(key)));
   layout = fuse_attn_tensor_layout(layout, get_attn_tensor_layout(value));
   TORCH_CHECK(
       ATTN_TENSOR_LAYOUT::UNSUPPORTED != layout,
-      "FlashAttentionBackwardXPU: query and value must have the same layout, got query with layout ",
+      "FlashAttentionForwardXPU: query and value must have the same layout, got query with layout ",
       to_string(layout),
       ", value with layout ",
       to_string(get_attn_tensor_layout(value)));
   if (layout == ATTN_TENSOR_LAYOUT::BXD) {
     layout = ATTN_TENSOR_LAYOUT::BSHD;
   }
+
+  at::Tensor query_ = query, key_ = key, value_ = value;
+  {
+    // Currently fwd only supports BSHD layout.
+    // However, the input headdim may be padded when it is not a multiple of
+    // 64, and the pad op makes the padded tensor BHSD-contiguous.
+    // TODO: This code block is a temporary WA. Remove it after supporting
+    // BHSD layouts.
+    if (layout != ATTN_TENSOR_LAYOUT::BSHD) {
+      query_ = attn_tensor_to_layout(query, ATTN_TENSOR_LAYOUT::BSHD);
+      key_ = attn_tensor_to_layout(key, ATTN_TENSOR_LAYOUT::BSHD);
+      value_ = attn_tensor_to_layout(value, ATTN_TENSOR_LAYOUT::BSHD);
+      layout = ATTN_TENSOR_LAYOUT::BSHD;
+    }
+  }
+
   TORCH_CHECK(
       layout == ATTN_TENSOR_LAYOUT::BSHD,
-      "FlashAttentionBackwardXPU: currently only support BSHD layout");
+      "FlashAttentionForwardXPU: currently only support BSHD layout");
 
   auto opts = query.options();
   at::Tensor out;
@@ -516,9 +532,9 @@
   cute::run_mha_fwd(
       sycl_queue,
       problem_shape,
-      query.data_ptr(),
-      key.data_ptr(),
-      value.data_ptr(),
+      query_.data_ptr(),
+      key_.data_ptr(),
+      value_.data_ptr(),
       out.data_ptr(),
       logsumexp.data_ptr(),
       is_causal,
diff --git a/src/ATen/native/transformers/xpu/flash_attn/utils.h b/src/ATen/native/transformers/xpu/flash_attn/utils.h
index 4a42ffad8f..dd174aa0f0 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/utils.h
+++ b/src/ATen/native/transformers/xpu/flash_attn/utils.h
@@ -94,7 +94,7 @@ inline at::Tensor attn_tensor_to_layout(
   return output;
 }
 
-inline bool check_flash_attention_bshd_layout(
+inline bool check_flash_attention_layout(
     sdp::sdp_params const& params,
     bool debug) {
   sycltla::ATTN_TENSOR_LAYOUT layout =