From efb04327a7bcf5c88bb939835632de6e123e3667 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:22:45 -0800 Subject: [PATCH 01/15] corrected types for strides in triton FA (#274) (#276) Co-authored-by: Aleksandr Malyshev (cherry picked from commit 9a46e97c1e63cbb5223a10a86705063b00e55576) --- vllm/attention/backends/rocm_flash_attn.py | 3 +- vllm/attention/ops/triton_flash_attention.py | 40 ++++++++++---------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7d2d87176800..e5df445d8449 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -619,7 +619,8 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if key is not None and value is not None: + if key is not None and value is not None \ + and attn_type != AttentionType.ENCODER_DECODER: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index f94211116a74..2019ed184e5a 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -314,26 +314,26 @@ def attn_fwd( sm_scale, L, Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, cu_seqlens_q, cu_seqlens_k, dropout_p, From d291770df08a29e14b616d9ce1538b00ba09a432 Mon Sep 17 00:00:00 2001 From: dhonnappa-amd Date: Wed, 4 Dec 2024 15:46:32 -0600 Subject: [PATCH 02/15] Update test-template.j2 (#283) Adding build only k8s node and queue names update --- .buildkite/test-template.j2 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index e7b24268ba39..ce448836a827 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -19,7 +19,7 @@ steps: - exit_status: -10 # Agent was lost limit: 5 agents: - queue: amd + queue: amd-cpu {% for step in steps %} {% if step.mirror_hardwares and "amd" in step.mirror_hardwares %} @@ -27,7 +27,7 @@ steps: depends_on: - "amd-build" agents: - queue: amd + queue: amd_gpu commands: - bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" && ")) | safe }}" env: From 679a15cbdb64d1806286ecc0dd317768c58cdbf9 Mon Sep 17 00:00:00 2001 From: Hosang <156028780+hyoon1@users.noreply.github.com> Date: Mon, 9 Dec 2024 12:30:40 -0500 Subject: [PATCH 03/15] Fix max_seqlens_q/k initialization for Navi GPUs (#310) - max_seqlens_q/k variables were not correctly initialized for Navi GPUs leading to incorrect outputs. - ensure that the correct values are passed to the attn_fwd kernel based on the GPU type. 
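For reference, a minimal standalone sketch of the guard this patch introduces
(is_navi() and max_seqlens_q/k stand in for the names used in the diff below;
the stub and the values are illustrative only, not the actual implementation):

    def is_navi() -> bool:
        # stub for the GPU-arch check defined in triton_flash_attention.py
        return False

    max_seqlens_q, max_seqlens_k = 4096, 4096  # illustrative values

    # Keep the original Python-side values intact and only zero the arguments
    # handed to attn_fwd when running on Navi, instead of overwriting
    # max_seqlens_q/k in place as the old code did.
    arg_max_seqlens_q = 0 if is_navi() else max_seqlens_q
    arg_max_seqlens_k = 0 if is_navi() else max_seqlens_k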
--- vllm/attention/ops/triton_flash_attention.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index a49df831b46e..3671c2f91e3b 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -912,9 +912,8 @@ def check_and_convert(t, scale): p_descale = 1.0 / p_scale o_descale = 1.0 / o_scale - if is_navi(): - max_seqlens_q = 0 - max_seqlens_k = 0 + arg_max_seqlens_q = 0 if is_navi() else max_seqlens_q + arg_max_seqlens_k = 0 if is_navi() else max_seqlens_k attn_fwd[grid]( q, @@ -944,8 +943,8 @@ def check_and_convert(t, scale): HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, + MAX_SEQLENS_Q=arg_max_seqlens_q, + MAX_SEQLENS_K=arg_max_seqlens_k, IS_CAUSAL=causal, VARLEN=True, BLOCK_DMODEL=padded_d_model, From 22f9066285861cc7cdb49d5caad995582ae3cd36 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:22:28 -0500 Subject: [PATCH 04/15] Setting the value for the scpecilative decoding worker class on rocm platform (#313) Signed-off-by: Gregory Shtrasberg --- vllm/platforms/rocm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 00efa056f7ef..d2f7cd40e25b 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -150,6 +150,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: elif vllm_config.speculative_config: parallel_config.worker_cls = \ "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" From f57aa621f45b4868374c30f0e270f91be5befa16 Mon Sep 17 00:00:00 2001 From: sanyalington Date: Tue, 24 Dec 2024 15:34:00 +0000 Subject: [PATCH 05/15] fp8 cpa --- .../kernels/benchmark_paged_attention.py | 8 +- csrc/rocm/attention.cu | 1265 ++++++++++++++++- 2 files changed, 1220 insertions(+), 53 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 483584dd804e..5c4643dfe9b4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -10,7 +10,7 @@ create_kv_caches_with_random) NUM_BLOCKS = 1024 * 1024 -PARTITION_SIZE = 512 +PARTITION_SIZE = 256 @torch.inference_mode() @@ -161,6 +161,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: kv_cache_dtype, k_scale, v_scale, + None, + PARTITION_SIZE ) else: raise ValueError(f"Invalid version: {version}") @@ -174,13 +176,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: # Warmup. print("Warming up...") run_benchmark = run_cuda_benchmark - run_benchmark(num_iters=3, profile=False) + run_benchmark(num_iters=500, profile=False) # Benchmark. 
if do_profile: latency = run_benchmark(num_iters=1, profile=True) else: - latency = run_benchmark(num_iters=1000, profile=False) + latency = run_benchmark(num_iters=10000, profile=False) print(f"Kernel running time: {latency * 1000000:.3f} us") diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index efda714f53c6..5a7d048018f1 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -29,6 +29,11 @@ #define __HIP__MI300_MI250__ #endif +#if defined(__HIPCC__) && (defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ +#endif + #if defined(NDEBUG) #undef NDEBUG #include @@ -51,6 +56,9 @@ using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; typedef float16x4 _Half4; +using float16x2 = + __attribute__((__vector_size__(2 * sizeof(_Float16)))) _Float16; +typedef float16x2 _Half2; typedef struct _Half8 { _Half4 xy[2]; } _Half8; @@ -63,8 +71,13 @@ typedef struct _B16x8 { } _B16x8; using _B8x8 = uint2; +using _B8x4 = int32_t; //used in builtins using bit8_t = uint8_t; +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; + ////// Non temporal load stores /////// template @@ -92,6 +105,21 @@ __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, } } +template +__device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + template __device__ __forceinline__ float to_float(const T& inp) { if constexpr (std::is_same::value) { @@ -139,23 +167,49 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { __hip_bfloat16 b; } t16; _B16x4 ret; +#if 0 + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.f = (_Float16)inp[i]; + ret[i] = t16.u; + } + return ret; +#else if constexpr (std::is_same::value) { +#if 0 #pragma unroll for (int i = 0; i < 4; i++) { t16.f = (_Float16)inp[i]; ret[i] = t16.u; } return ret; +#else + union h2cvt { + __half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0],inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2],inp[3])); + return u.b16x4; +#endif } else if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < 4; i++) { - t16.b = __float2bfloat16(inp[i]); - ret[i] = t16.u; + union fcvt { + uint32_t u32; + float f32; + } u; + u.f32 = inp[i]; + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); //RNE with no nan/inf check + ret[i] = uint16_t(u.u32 >> 16); + //t16.b = __float2bfloat16(inp[i]); + //ret[i] = t16.u; } return ret; } else { static_assert(false, "unsupported 16b dtype"); } +#endif } template @@ -167,7 +221,7 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, __hip_bfloat16 b; } t1, t2, res; _B16x4 ret; - if constexpr (std::is_same::value) { +#if 0 #pragma unroll for (int i = 0; i < 4; i++) { t1.u = inp1[i]; @@ -176,18 +230,49 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, ret[i] = res.u; } return ret; - } else if constexpr (std::is_same::value) { +#else + if constexpr (std::is_same::value) { +#if 0 #pragma unroll for (int i = 0; i < 4; i++) { t1.u = inp1[i]; t2.u = inp2[i]; - res.b = t1.b + t2.b; + res.f = t1.f + t2.f; ret[i] = 
res.u; } return ret; +#else + union h2cvt { + _B16x4 b16x4; + __half2 h2[2]; + } u1,u2,s; + u1.b16x4 = inp1; + u2.b16x4 = inp2; + s.h2[0] = u1.h2[0] + u2.h2[0]; + s.h2[1] = u1.h2[1] + u2.h2[1]; + return s.b16x4; +#endif + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + union fcvt { + float f32; + uint32_t i32; + } u1,u2,s; + u1.i32 = uint32_t(inp1[i])<<16; + u2.i32 = uint32_t(inp2[i])<<16; + s.f32 = u1.f32 + u2.f32; + ret[i] = uint16_t(s.i32>>16); + //t1.u = inp1[i]; + //t2.u = inp2[i]; + //res.b = t1.b + t2.b; + //ret[i] = res.u; + } + return ret; } else { static_assert(false, "unsupported 16b dtype"); } +#endif } template @@ -210,8 +295,737 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, } } +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8_custom(const _B8x8 input, + const float scale) { + union { + floatx4 f32x4[2]; + vllm::Float8_ f32x8; + } tmpf8; + tmpf8.f32x8 = vllm::fp8::vec_conversion(*reinterpret_cast(&input)); + + tmpf8.f32x4[0] *= scale; + tmpf8.f32x4[1] *= scale; + + _B16x8 ret; + ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); + ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); + return ret; +} + +__device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); + floatx4 ret; + ret[0] = f0[0]; + ret[1] = f0[1]; + ret[2] = f1[0]; + ret[3] = f1[1]; + return ret; +} + +template +__device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { + _B16x4 ret; + if constexpr (std::is_same::value) { + union h2cvt { + _Half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(inp[0],inp[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(inp[2],inp[3]); + return u.b16x4; + } else if constexpr (std::is_same::value) { + for (int i = 0; i < 4; i++) { + union fcvt { + uint32_t i32; + float f32; + } u; + u.f32 = inp[i]; + ret[i] = uint16_t(u.i32 >> 16); + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 from_floatx4_trunc(const floatx4& inp) { + _B16x4 ret; + if constexpr (std::is_same::value) { + int32_t tmpf8; + tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false); + tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true); + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true); + union h2cvt { + _Half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(f0[0],f0[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(f1[0],f1[1]); + return u.b16x4; + } else if constexpr (std::is_same::value) { + int32_t tmpf8; + tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false); + tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true); + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true); + floatx4 tmpf; + tmpf[0] = f0[0]; + tmpf[1] = f0[1]; + tmpf[2] = f1[0]; + tmpf[3] = f1[1]; + for (int i = 0; i < 4; i++) { + union fcvt { + uint32_t i32; + float f32; + } u; + u.f32 = tmpf[i]; + ret[i] = uint16_t(u.i32 >> 16); + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + + + +template +__device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { +#if 0 + union { + floatx4 f32x4[2]; + vllm::Float8_ f32x8; + _B8x8 b8x8[2]; + } tmpf8; + tmpf8.f32x8 = 
vllm::fp8::vec_conversion(*reinterpret_cast(&input)); + //tmpf8.b8x8[0] = input; + //tmpf8.b8x8[1] = input; +#endif + union { + _B8x8 b8x8; + _B8x4 b8x4[2]; + } tmp; + tmp.b8x8 = input; + _B16x8 ret; + for (int i=0; i<2; i++) { + ret.xy[i] = from_floatx4_rtz( to_float_fp8x4(tmp.b8x4[i]) ); + } + //ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); + //ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); + return ret; +} /////////////////////////////////////// +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale, + const float* __restrict__ fp8_out_scale_ptr) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; //partition size set to 256 TODO move to template param + //const int partition_size = 256; //blockDim.x; //TODO this could be head_size or partition_size + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; //partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO,4); + + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + //shared_logits is used for multiple purposes + //__shared__ _B16x4 shared_logits[NWARPS][4][16][4 + 1]; + __shared__ _B16x4 shared_logits[NWARPS][4][16][4]; + + //for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes HeadElements in each lane, 4x16B HeadElements across 4 rows of warp + constexpr int ROWS_PER_WARP = WARP_SIZE / 16; //rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); //8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; //each fetch across a warp fetches these many elements + constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); //1 for 16bit types, 2 for 8bit types + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; //4xQKHE_16B across warp + + _B16x8 Qlocal[QKHELOOP][QK_SIZE_RATIO]; //note that 16 contiguous 
elements of Q should be fetched per lane for 8 bit cache types : QK_SIZE_RATIO changes for this + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + //constexpr int x = CONTIGUOUS_SCALAR_ELEMS_16B; //x is defined by vLLM as 16Bytes + + //constexpr int TLOOP1 = CONTIGUOUS_KV_ELEMS_16B_LOAD / 4; //mfma16x16x16 outputs 4 elements per lane: will be moved to match layout for V dwordx4 loads + //constexpr int TOKENS_PER_WARP1 = 16 * TLOOP1; //16 tokens across lanes * TLOOP factor + //constexpr int T_PAR_LOOP = T_PAR_SIZE / TOKENS_PER_WARP1 / NWARPS; + constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; //sub partition of tokens per warp for qk calculation + constexpr int TLOOP = TOKENS_PER_WARP / 16; //each mfma16x16x16 instruction processes 16 tokens + + _B16x8 Klocal[TLOOP][QKHELOOP]; //this could be B8x16 too + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + //for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps + //each mfma takes QH16xT16x16HE across warp + //repeat mfmas across QKHELOOP dimension + //output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens across 4 rowsx4 tokens per lane + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + //fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? 
kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + +#if 0 //fetch Q into registers + + const int local_qhead_idx = lane16id % GQA_RATIO; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + if (lane16id < GQA_RATIO) { + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const scalar_t* q_ptr2 = q_ptr + qkhe_depth * QKHE_PER_FETCH; + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + const scalar_t* q_fetch_ptr = q_ptr2 + qkratio * CONTIGUOUS_SCALAR_ELEMS_16B; + const _B16x8* q_fetch_ptr_16B = reinterpret_cast(q_fetch_ptr); + Qlocal[qkhe_depth][qkratio] = *q_fetch_ptr_16B; + } + } + } else { + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + Qlocal[qkhe_depth][qkratio].xy[0] = {0}; + Qlocal[qkhe_depth][qkratio].xy[1] = {0}; + } + } + } +#else //fetch Q in shared + const int local_qhead_idx = 4 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; //+ rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + if (local_qhead_idx < GQA_RATIO) { + const scalar_t* q_fetch_ptr = q_ptr + lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; //this works for head size 128 : 16 lanes x 8 elems = 128 elems + const _B16x8* q_fetch_ptr_16B = reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const int offset1 = lane16id/4; //16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; + shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; + } else { + for (int i=0; i<2; i++) { + const int head_elem = lane16id * 2 + i; //element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem /4/4; + shared_logits[offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + } + } + } + __syncthreads(); + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i=0; i<2; i++) { + Qlocal[qkhe_depth][qkratio].xy[i] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][2*qkratio + i]; + } + } + } +#endif + + constexpr int KX = 16 / sizeof(cache_t); + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const 
cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + + constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP;// 16 * T_PAR_SIZE / 256; + constexpr int VBLOCKS_PER_LANE = DIVIDE_ROUND_UP(VTOKENS_PER_LANE,BLOCK_SIZE); + constexpr int VTLOOP = NWARPS; //was * TOKENS_PER_WARP / ROWS_PER_WARP / VTOKENS_PER_LANE; + constexpr int VTLANELOOP = DIVIDE_ROUND_UP(VTOKENS_PER_LANE , CONTIGUOUS_KV_ELEMS_16B_LOAD); //optimized for 16B fetches; assumes minimum block size is 16 + constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + //fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) { + const int vlocal_token_idx = vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; //this could be B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + + //v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD / BLOCK_SIZE; + //const int token_depth = vtoken_depth * VBLOCKS_PER_LANE + vblock_depth; + const int64_t vblock_number = static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + + //__syncthreads(); //if using shared Q + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + scale2 *= k_scale; + } + + floatx4 dout[TLOOP]; +#if 1 //Q stored in registers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i=0; i<2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr(Klocal[token_depth][qkhe_depth].xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); + } + } + } else { //kv cache dtype fp8 + auto Ktmp = Klocal[token_depth][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i=0; i<2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr(Klocaltmp.xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], 
dout[token_depth]); + } + } + } + } + dout[token_depth] *= scale2; + } + +#else //Q in shared + _B16x4 tmpQ[QKHELOOP][2]; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + tmpQ[qkhe_depth][0] = shared_logits[qkhe_depth][rowid][lane16id][0]; + tmpQ[qkhe_depth][1] = shared_logits[qkhe_depth][rowid][lane16id][1]; + } + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + //for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i=0; i<2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr(Klocal[token_depth][qkhe_depth].xy[i], + tmpQ[qkhe_depth][i], //shared_logits[qkhe_depth][rowid][lane16id][i], + dout[token_depth]); + } + //} + } + dout[token_depth] *= scale; + } +#endif + +#if 0 //DEBUG ONLY qk * scale + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; + auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); + auto tmp = from_floatx4(dout[token_depth]); + *qkout_write_ptr = tmp; + } +#endif + + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + + const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i=0; i<4; i++) { + const float tmp = (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) { + qk_max = fmaxf(qk_max, __shfl_xor(qk_max,mask)); + } + + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i=0; i<4; i++) { + const float tmp = (local_token_idx + i < context_len) ? 
__expf(dout[token_depth][i] - qk_max) : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) { + exp_sum += __shfl_xor(exp_sum,mask); + } + + __syncthreads(); //sync before writing to shared mem + + float* shared_mem = reinterpret_cast(shared_logits); + if (laneid < 16) { + //shared_qk_max[warpid][lane16id] = qk_max; + //shared_exp_sum[warpid][lane16id] = exp_sum; + const int qk_max_offset = warpid*16 + lane16id; + shared_mem[qk_max_offset] = qk_max; + const int exp_sum_offset = NWARPS*16 + qk_max_offset; + shared_mem[exp_sum_offset] = exp_sum; + } + +#if 0 //DEBUG ONLY + //scalar_t* qkout_ptr = out + + // seq_idx * total_num_heads * T_PAR_SIZE + lane16id * T_PAR_SIZE; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + //auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; + //auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); + auto tmp = from_floatx4(dout[token_depth]); + shared_tokens[warpid][token_depth][lane16id][rowid] = tmp; + //*qkout_write_ptr = tmp; + } +#endif + __syncthreads(); + + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + for (int w=0; w(dout[token_depth]); + } else { + shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); + } + } + + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + (wg_start_head_idx + qhead_idx) * max_num_partitions + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + +#if 0 //DEBUG ONLY + scalar_t* qkout_ptr = out + + seq_idx * total_num_heads * T_PAR_SIZE + lane16id * T_PAR_SIZE; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; + auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); + //dout[token_depth] *= inv_sum_scale[warpid]; + //auto tmp = from_floatx4(dout[token_depth]); + auto tmp = shared_tokens[warpid][token_depth][lane16id][rowid]; + *qkout_write_ptr = tmp; + } +#endif +#if 0 + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for (int j=0; j<2; j++) { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i=0; i<2; i++) { + const int offset = 4*rowid + 2*j + i; + const int offset1 = offset % 4; + const int offset2 = offset / 4; + tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + } + } +#endif + _B16x4 outelems[VHELOOP]; + _B16x4 S_local[VTLOOP][2][2]; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + //for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int j=0; j<2; j++) { + for (int i=0; i<2; i++) { + const int offset = 4*rowid + 2*j + i; + const int offset1 = offset % 4; + const int offset2 = offset / 4; + S_local[vtoken_depth][j][i] = shared_logits[vtoken_depth][offset2][lane16id][offset1]; + } + } + 
//} + } + } + //v layout: 16he across lanes x 16 tokens per lane + + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx4 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int i=0; i<2; i++) { + //TODO generalize this for 8 bit dtypes: each lane needs 2*vfetch_depth + 2 _B16x4 K/token dimension elems; each row is multiplied by a factor of 4 + //layout: lane in depth dimension | row across -> + //0 4 8 12 + //1 5 9 13 + //2 6 10 14 + //3 7 11 15 + const int offset = rowid * VTLANELOOP * 2 + 2*vfetch_depth + i; + const int offset1 = offset % 4; //4 corresponds to ROWS_PER_WARP + const int offset2 = offset / 4; +#if 0 + //if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr(shared_logits[vtoken_depth][offset2][lane16id][offset1], + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out); +#else + //if output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); +#endif + } + } + } else { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for (int j=0; j<2; j++) { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i=0; i<2; i++) { + const int offset = 4*rowid + 2*j + i; + const int offset1 = offset % 4; + const int offset2 = offset / 4; + tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], + S_local[vtoken_depth][j][i], + tmp_out); + //shared_logits[vtoken_depth][offset2][lane16id][offset1], + //tmp_out); + } + } + } + + } + } + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + tmp_out *= v_scale; + } + outelems[vhe_depth] = from_floatx4(tmp_out); + } +#if 1 + __syncthreads(); + + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; //lane16 id head dimension; rowid head element dimension + } + + __syncthreads(); + + if (warpid == 0) { + _B16x8 vout[GQA_RATIO4]; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + const int head_elem_idx = lane16id * 8; + const int offset1 = (head_elem_idx / 16)%4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4)%4; + for (int i=0; i<2; i++) { + vout[h].xy[i] = shared_logits[offset1][offset2][local_head_idx][offset3+i]; + } + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + + seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + const int head_elem_idx = lane16id * 8; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + +#endif + +#if 0 + //if output format is 16 he across 16 lanes, 16 qheads spread across 4 rows + const int hsz_maxp_mult = HEAD_SIZE * 
max_num_partitions; + scalar_t* out_ptr = out + + seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; + + const int vhe_offset = warpid * 16 + lane16id; + + #pragma unroll + for (int i=0; i<4; i++) { + const int local_head_idx = 4*rowid + i; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + vhe_offset; + scalar_t* out_ptr3 = out_ptr2 + vhead_elem; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr3); + *out_ptr_b16 = outelems[vhe_depth][i]; + } + } + } +#endif +#if 0 + //if output format is 16 qheads across 16 lanes, 16 he spread across 4 rows + if (lane16id < GQA_RATIO) { + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + + seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; + const int local_head_idx = lane16id; + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + const int vhe_offset = warpid * 16 + rowid * 4; + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + vhe_offset; + scalar_t* out_ptr3 = out_ptr2 + vhead_elem; + _B16x4* out_ptr_B16x4 = reinterpret_cast<_B16x4*>(out_ptr3); + *out_ptr_B16x4 = outelems[vhe_depth]; + } + } +#endif +#if 0 //DEBUG ONLY + floatx4 partition_out[VHELOOP]; + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + partition_out[vhe_depth] = {0}; + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + partition_out[vhe_depth] += inv_sum_scale[vtoken_depth] * vout[vhe_depth][vtoken_depth]; + } + } +#endif +#if 0 //DEBUG ONLY + if (laneid < GQA_RATIO) { + auto* exp_sums_ptr = exp_sums + seq_idx * 8 * max_num_partitions + partition_idx; + floatx4 tmp = {0}; + //for (int t=0; t(from_floatx4(tmp), shared_tokens[warpid][lane4id][lane16id][rowid]); + + float2 tmpf = *reinterpret_cast(&tmp16); + *exp_sums_ptr = laneid%2 == 0 ? tmpf.x : tmpf.y; + } +#endif +} +///////////////////////////////////////////////////////////// // grid (num_seqs, num_partitions,num_heads/gqa_ratio) // block (partition size) template 16/2 // 8xtokens + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; _B16x8 Vlocal[VHELOOP][VTLOOP]; _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = {0}; @@ -312,8 +1129,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( static_cast(block_table[block_idx]); // fetch vphysical block numbers up front - constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; if constexpr (GQA_RATIO < 12) { @@ -370,6 +1185,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } +#if 1 float alibi_slope[QHLOOP]; if (alibi_slopes != nullptr) { #pragma unroll @@ -380,6 +1196,26 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( : 0.f; } } +#endif +#if 0 + float alibi_slope; + const int lane16id = laneid % 16; + if (alibi_slopes != nullptr) { + alibi_slope = (lane16id < GQA_RATIO) + ? 
alibi_slopes[wg_start_head_idx + lane16id] + : 0.f; + //#pragma unroll + // for (int h = 0; h < QHLOOP; h++) { + // for (int i=0; i<4; i++) { + // const int qhead_idx = h * 4 + i; + // alibi_slope[qhead_idx] = (qhead_idx < GQA_RATIO) + // ? alibi_slopes[wg_start_head_idx + qhead_idx] + // : 0.f; + // } + //} + //} + } +#endif // fetch vphysical block numbers up front if constexpr (GQA_RATIO >= 12) { @@ -392,6 +1228,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } +#if 1 //fetch vcache in normal case const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); @@ -416,7 +1253,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } - } else { + } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) +#endif +#if 1 //fetch vcache in fp8 case + else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block #pragma unroll @@ -435,23 +1275,73 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // iterate over all velems within block #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { - // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - const _B8x8 Vlocalb8 = v_ptrh8be[d]; - Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, v_scale); + Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + //const _B8x8 Vlocalb8 = v_ptrh8be[d]; + //Vlocal[h][b * BLOCK_SIZE / 8 + d] = + // scaled_convert_b8x8(Vlocalb8, v_scale); } } } } - +#endif +#if 0 //cvt kf8 to kf/bf16 up front if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], k_scale); + //scaled_convert_b8x8(Klocalb8[d], k_scale); + convert_b8x8_custom(Klocalb8[d]); } } +#endif + + /*Klocal[x] = scaled_convert_b8x8(Klocalb8[x], k_scale); \*/ + /*Klocal[x] = scaled_convert_b8x8_custom(Klocalb8[x], k_scale); \*/ +#define QK_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ + } \ + for (int h = 0; h < QHLOOP; h++) { \ + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], \ + Klocal[x].xy[0], dout[h]);\ + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], \ + Klocal[x].xy[1], dout[h]);\ + } + //#pragma unroll + //for (int h = 0; h < QHLOOP; h++) { + QK_mfma(0); + QK_mfma(1); + QK_mfma(2); + QK_mfma(3); + QK_mfma(4); + QK_mfma(5); + QK_mfma(6); + QK_mfma(7); + if constexpr (KHELOOP > 8) { + QK_mfma(8); + QK_mfma(9); + QK_mfma(10); + QK_mfma(11); + QK_mfma(12); + QK_mfma(13); + QK_mfma(14); + QK_mfma(15); + } + //} +#undef QK_mfma + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + scale2 *= k_scale; + } + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] *= scale2; + //if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // dout[h] *= k_scale; + //} + } +#if 0 #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], @@ -522,10 +1412,39 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } // KHELOOP>8 dout[h] *= scale; } +#endif + +#if 0 + if (alibi_slopes != nullptr) { + float alibi_slope_local[GQA_RATIO]; +#define DPP_BCAST_ASM(id) asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:id " : "=v"(alibi_slope_local[id]) : 
"v"(alibi_slope)); + //for (int head=0; head < 16; head++) { + //DPP_BCAST_ASM(0); + if constexpr(GQA_RATIO>0) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:0 " : "=v"(alibi_slope_local[0]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>1) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:1 " : "=v"(alibi_slope_local[1]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>2) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:2 " : "=v"(alibi_slope_local[2]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>3) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:3 " : "=v"(alibi_slope_local[3]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>4) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:4 " : "=v"(alibi_slope_local[4]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>5) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:5 " : "=v"(alibi_slope_local[5]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>6) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:6 " : "=v"(alibi_slope_local[6]) : "v"(alibi_slope));} + if constexpr(GQA_RATIO>7) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:7 " : "=v"(alibi_slope_local[7]) : "v"(alibi_slope));} + //} + + const int alibi_offset = global_token_idx - context_len + 1; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope_local[4*h+i] * alibi_offset; + } + } + } +#endif // transpose dout so that 4 token ids are in each lane, and 4 heads are across // 4 lanes #pragma unroll for (int h = 0; h < QHLOOP; h++) { +#if 1 floatx4 tmp = {0}; #pragma unroll for (int i = 0; i < 4; i++) { @@ -535,9 +1454,38 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } dout[h] = tmp; +#endif +#if 0 + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[1,0,3,2] " : "=v"(dout[h][1]) : "v"(dout[h][1]) ); + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[2,3,0,1] " : "=v"(dout[h][2]) : "v"(dout[h][2]) ); + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[3,2,1,0] " : "=v"(dout[h][3]) : "v"(dout[h][3]) ); + + bool mask = (lane4id % 2) == 1; + float tmp = dout[h][1]; + dout[h][1] = mask ? dout[h][0] : dout[h][1]; + dout[h][0] = mask ? tmp : dout[h][0]; + tmp = dout[h][3]; + dout[h][3] = mask ? dout[h][2] : dout[h][3]; + dout[h][2] = mask ? tmp : dout[h][2]; + + mask = (lane4id>>1) == 1; + tmp = dout[h][2]; + dout[h][2] = mask ? dout[h][0] : dout[h][2]; + dout[h][0] = mask ? tmp : dout[h][0]; + tmp = dout[h][3]; + dout[h][3] = mask ? dout[h][1] : dout[h][3]; + dout[h][1] = mask ? 
tmp : dout[h][1]; + + + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[1,0,3,2] " : "=v"(dout[h][1]) : "v"(dout[h][1]) ); + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[2,3,0,1] " : "=v"(dout[h][2]) : "v"(dout[h][2]) ); + asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[3,2,1,0] " : "=v"(dout[h][3]) : "v"(dout[h][3]) ); + +#endif } const int lane4_token_idx = 4 * (global_token_idx >> 2); +#if 1 //alibi after transpose const int alibi_offset = lane4_token_idx - context_len + 1; if (alibi_slopes != nullptr) { #pragma unroll @@ -548,6 +1496,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } +#endif + + const int bpermute_mask = 4*(16*((laneid>>2)%4) + lane4id); #pragma unroll for (int h = 0; h < QHLOOP; h++) { @@ -559,11 +1510,22 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( : qk_max[h]; } #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + for (int mask = WARP_SIZE / 2; mask >= 64; mask /= 2) { qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); } + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + + //asm("v_nop\n v_nop\n ds_bpermute_b32 %0, %1, %2 \n s_waitcnt lgkmcnt(0)" : "=v"(qk_max[h]) : "v"(bpermute_mask), "v"(qk_max[h]) ); + + //qk_max[h] = __builtin_amdgcn_ds_bpermute(bpermute_mask, qk_max[h]); + auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&qk_max[h])); + qk_max[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); } + float exp_sum[QHLOOP]; #pragma unroll for (int h = 0; h < QHLOOP; h++) { @@ -576,17 +1538,28 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( exp_sum[h] += dout[h][i]; } #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + for (int mask = WARP_SIZE / 2; mask >= 64; mask /= 2) { exp_sum[h] += __shfl_xor(exp_sum[h], mask); } + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + + //asm("v_nop\n v_nop\n ds_bpermute_b32 %0, %1, %2 \n s_waitcnt lgkmcnt(0)" : "=v"(exp_sum[h]) : "v"(bpermute_mask), "v"(exp_sum[h]) ); + //exp_sum[h] = __builtin_amdgcn_ds_bpermute(bpermute_mask, exp_sum[h]); + auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&exp_sum[h])); + exp_sum[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); } + if (laneid<4) { #pragma unroll for (int h = 0; h < QHLOOP; h++) { const int head_idx = 4 * h + lane4id; shared_qk_max[warpid][head_idx] = qk_max[h]; shared_exp_sum[warpid][head_idx] = exp_sum[h]; } + } } // warp within context __syncthreads(); @@ -630,7 +1603,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( logits[h] = from_floatx4(dout[h]); } - __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; if (warp_start_token_idx >= 
context_len) { // warp out of context #pragma unroll @@ -641,6 +1613,139 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } else { // warp in context +#if 0 //fetch v cache + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B8x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + //const _B8x8 Vlocalb8 = v_ptrh8be[d]; + //Vlocal[h][b * BLOCK_SIZE / 8 + d] = + // scaled_convert_b8x8(Vlocalb8, v_scale); + } + } + } + } +#endif +#if 0 //cvt vf8 ->f16/bf16 up front + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + for (int vh = 0; vh < VHELOOP; vh++) { + for (int b=0; b < VTLOOP; b++) { + //Vlocal[vh][b] = scaled_convert_b8x8(Vlocalb8[vh][b], v_scale); + Vlocal[vh][b] = convert_b8x8_custom(Vlocalb8[vh][b]); + } + } + } +#endif + + /*Vlocal[vh][x] = scaled_convert_b8x8(Vlocalb8[vh][x], v_scale);\*/ + /*Vlocal[vh][x] = scaled_convert_b8x8_custom(Vlocalb8[vh][x], v_scale);\*/ + #define SV_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {\ + Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]);\ + }\ + for (int qh = 0; qh < QHLOOP; qh++) { \ + acc[qh] = gcn_mfma_instr(logits[qh], Vlocal[vh][x].xy[0], \ + acc[qh]); \ + acc[qh] = gcn_mfma_instr(logits[qh], Vlocal[vh][x].xy[1], \ + acc[qh]); \ + } +#if 0 + floatx4 acc[QHLOOP][VHELOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { + acc[qh][vh] = {0}; + } + } +#endif + //#pragma unroll + // for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + //#pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc[QHLOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + acc[qh] = {0}; + } + // iterate over tokens + SV_mfma(0); + SV_mfma(1); + SV_mfma(2); + SV_mfma(3); + SV_mfma(4); + SV_mfma(5); + 
SV_mfma(6); + SV_mfma(7); +#if 0 + SV_mfma(8); + SV_mfma(9); + SV_mfma(10); + SV_mfma(11); + SV_mfma(12); + SV_mfma(13); + SV_mfma(14); + SV_mfma(15); +#endif + for (int qh = 0; qh < QHLOOP; qh++) { + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + acc[qh] *= v_scale; + } + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); + } + } + //} + +#if 0 + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh][vh]); + } + } +#endif + +#undef SV_mfma +#if 0 // iterate across heads #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { @@ -684,6 +1789,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); } } +#endif } // warp in context __syncthreads(); @@ -787,12 +1893,13 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const int seq_idx = blockIdx.y; const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); +#if 0 //disable this as mfma16 kernel does not support this optimization yet if (num_partitions == 1) { // if num_partitions==1, main kernel will write to out directly, no work in // reduction kernel return; } - +#endif constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -973,6 +2080,33 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale, + const float* __restrict__ fp8_out_scale_ptr) { + UNREACHABLE_CODE +} + template \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale, v_scale, fp8_out_scale_ptr); + #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ paged_attention_ll4mi_QKV_kernel \ @@ -1036,7 +2180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( context_lens_ptr, max_num_partitions, fp8_out_scale_ptr); template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, 
torch::Tensor& query, torch::Tensor& key_cache, @@ -1076,65 +2220,82 @@ void paged_attention_custom_launcher( OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + constexpr int PARTITION_SIZE = 256; const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); - constexpr int NTHR = PARTITION_SIZE; + constexpr int NTHR = 256; //PARTITION_SIZE; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (gqa_ratio) { case 1: - LAUNCH_CUSTOM_ATTENTION(1); + //LAUNCH_CUSTOM_ATTENTION(1); + LAUNCH_CUSTOM_ATTENTION_MFMA16(1); break; case 2: - LAUNCH_CUSTOM_ATTENTION(2); + //LAUNCH_CUSTOM_ATTENTION(2); + LAUNCH_CUSTOM_ATTENTION_MFMA16(2); break; case 3: - LAUNCH_CUSTOM_ATTENTION(3); + //LAUNCH_CUSTOM_ATTENTION(3); + LAUNCH_CUSTOM_ATTENTION_MFMA16(3); break; case 4: - LAUNCH_CUSTOM_ATTENTION(4); + //LAUNCH_CUSTOM_ATTENTION(4); + LAUNCH_CUSTOM_ATTENTION_MFMA16(4); break; case 5: - LAUNCH_CUSTOM_ATTENTION(5); + //LAUNCH_CUSTOM_ATTENTION(5); + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; case 6: - LAUNCH_CUSTOM_ATTENTION(6); + //LAUNCH_CUSTOM_ATTENTION(6); + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); break; case 7: - LAUNCH_CUSTOM_ATTENTION(7); + //LAUNCH_CUSTOM_ATTENTION(7); + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); break; case 8: - LAUNCH_CUSTOM_ATTENTION(8); + //LAUNCH_CUSTOM_ATTENTION(8); + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break; case 9: - LAUNCH_CUSTOM_ATTENTION(9); + //LAUNCH_CUSTOM_ATTENTION(9); + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); break; case 10: - LAUNCH_CUSTOM_ATTENTION(10); + //LAUNCH_CUSTOM_ATTENTION(10); + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); break; case 11: - LAUNCH_CUSTOM_ATTENTION(11); + //LAUNCH_CUSTOM_ATTENTION(11); + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); break; case 12: - LAUNCH_CUSTOM_ATTENTION(12); + //LAUNCH_CUSTOM_ATTENTION(12); + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); break; case 13: - LAUNCH_CUSTOM_ATTENTION(13); + //LAUNCH_CUSTOM_ATTENTION(13); + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); break; case 14: - LAUNCH_CUSTOM_ATTENTION(14); + //LAUNCH_CUSTOM_ATTENTION(14); + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break; case 15: - LAUNCH_CUSTOM_ATTENTION(15); + //LAUNCH_CUSTOM_ATTENTION(15); + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break; case 16: - LAUNCH_CUSTOM_ATTENTION(16); + //LAUNCH_CUSTOM_ATTENTION(16); + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break; default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); @@ -1146,11 +2307,14 @@ void paged_attention_custom_launcher( // note there are cases with graphing where max_context_len is the max // supported by graphing, not the actual max among all the sequences: in that // case reduction kernel will still run but return immediately - if (max_context_len > PARTITION_SIZE) { + + //below optimization is not yet implemented in mfma16 kernel + //if (max_context_len > PARTITION_SIZE) { dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_block(head_size); const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); // support upto 8*64*256=128K context length +#if 1 switch (npar_loops) { case 1: LAUNCH_CUSTOM_REDUCTION(1); @@ -1180,7 +2344,8 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break; } - } +#endif + //} //if max_context_len > 
partition_size } #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ @@ -1197,14 +2362,12 @@ void paged_attention_custom_launcher( case 256: \ CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \ break; \ - case 512: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 512); \ - break; \ default: \ TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \ break; \ } - +/* +*/ #if defined(__HIPCC__) && defined(__gfx90a__) #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ if (fp8_out_scale) { \ @@ -1226,19 +2389,17 @@ void paged_attention_custom_launcher( case 16: \ CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ - case 32: \ - CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ - break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } - +/* + case 32: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + break; \ +*/ #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ switch (head_size) { \ - case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ - break; \ case 128: \ CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ break; \ @@ -1246,7 +2407,11 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported head size: ", head_size); \ break; \ } - +/* + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ + break; \ +*/ void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] From 4145bae5edaa3204872372ad921a7fe47b2d5926 Mon Sep 17 00:00:00 2001 From: sanyalington Date: Thu, 26 Dec 2024 10:18:23 +0000 Subject: [PATCH 06/15] load nt optional for kv --- csrc/rocm/attention.cu | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 5a7d048018f1..0ddb751d135e 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -90,6 +90,27 @@ __device__ __forceinline__ void store(T value, T* addr) { addr[0] = value; } + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + //auto dat0 = *(addr_alias); + //auto dat1 = *(addr_alias+1); + //auto dat2 = *(addr_alias+2); + //auto dat3 = *(addr_alias+3); + auto res = make_float4(dat0,dat1,dat2,dat3); + return *reinterpret_cast<_B16x8*>(&res); +} + + template __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, const _B16x4& inpB, @@ -600,6 +621,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + //Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); } } @@ -643,6 +665,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const cache_t* v_fetch_ptr = v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + 
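            // The commented-out alternative below uses load_ntmprl_16Byte(), which
            // composes the 16-byte V fragment from four __builtin_nontemporal_load()
            // calls so streamed K/V data is less likely to pollute the cache. It is
            // left optional here; a later change in this series gates it behind an
            // NT_KV_LOAD flag, noting a benefit mainly at very high batch sizes.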
//Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = load_ntmprl_16Byte(v_fetch_ptr_16B); } } } From 7bb4a481dedd4b75413ca315387eb0ae4a93daee Mon Sep 17 00:00:00 2001 From: shsanyal Date: Tue, 14 Jan 2025 18:28:20 +0000 Subject: [PATCH 07/15] enable alibi; fix gfx90a compile --- csrc/rocm/attention.cu | 74 +++++++++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 0ddb751d135e..45b2fb114091 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -29,11 +29,6 @@ #define __HIP__MI300_MI250__ #endif -#if defined(__HIPCC__) && (defined(__gfx940__) || \ - defined(__gfx941__) || defined(__gfx942__)) - #define __HIP__MI300__ -#endif - #if defined(NDEBUG) #undef NDEBUG #include @@ -335,6 +330,10 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8_custom(const _B8x8 input, } __device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { +#if defined(__gfx90a__) + float4 f32x4 = vllm::fp8::vec_conversion(*reinterpret_cast(&inp)); + return *reinterpret_cast(&f32x4); +#else //MI3xx+ optimized builtins const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); floatx4 ret; @@ -343,6 +342,7 @@ __device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { ret[2] = f1[0]; ret[3] = f1[1]; return ret; +#endif } template @@ -444,7 +444,7 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { // block (partition size) template __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -625,6 +625,12 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } + float alibi_slope; + if constexpr(ALIBI_ENABLED) { + const int alibi_head_idx = wg_start_head_idx + lane16id; + alibi_slope = (lane16id < GQA_RATIO) ? alibi_slopes[alibi_head_idx] : 0.f; + } + constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP;// 16 * T_PAR_SIZE / 256; constexpr int VBLOCKS_PER_LANE = DIVIDE_ROUND_UP(VTOKENS_PER_LANE,BLOCK_SIZE); constexpr int VTLOOP = NWARPS; //was * TOKENS_PER_WARP / ROWS_PER_WARP / VTOKENS_PER_LANE; @@ -735,11 +741,23 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } #endif + const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + + if constexpr(ALIBI_ENABLED) { + //const int alibi_head_idx = wg_start_head_idx + lane16id; + //const float alibi_slope = (lane16id < GQA_RATIO) ? 
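        // ALiBi adds a per-head linear position bias to the raw QK logits,
        //   bias(token) = slope_head * (token_idx - context_len + 1),
        // so the most recent token gets zero bias and, for positive slopes, earlier
        // tokens get an increasingly negative bias. The loop below applies this to
        // the 4-token row fragment owned by each lane.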
alibi_slopes[alibi_head_idx] : 0.f; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + const int alibi_offset = local_token_idx - context_len + 1; + for (int i=0; i<4; i++) { + dout[token_depth][i] += alibi_slope * (alibi_offset + i); + } + } + } + float qk_max = -FLT_MAX; float exp_sum = 0.0f; - const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i=0; i<4; i++) { @@ -2105,7 +2123,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -2177,7 +2195,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ paged_attention_ll4mi_QKV_mfma16_kernel \ + HEAD_SIZE, NTHR, ALIBI_ENABLED, GQA_RATIO> \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ @@ -2203,7 +2221,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( context_lens_ptr, max_num_partitions, fp8_out_scale_ptr); template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, bool ALIBI_ENABLED> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -2372,25 +2390,32 @@ void paged_attention_custom_launcher( } #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ - PSIZE) \ + PSIZE, ALIBI_ENABLED) \ paged_attention_custom_launcher( \ + PSIZE, ALIBI_ENABLED>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, context_lens, max_context_len, \ alibi_slopes, k_scale, v_scale, fp8_out_scale); +#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT, PSIZE) \ + if (alibi_slopes) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, true); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, false); \ + } + #define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ OUTT) \ switch (partition_size) { \ case 256: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \ break; \ default: \ TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \ break; \ } -/* -*/ + #if defined(__HIPCC__) && defined(__gfx90a__) #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ if (fp8_out_scale) { \ @@ -2412,17 +2437,19 @@ void paged_attention_custom_launcher( case 16: \ CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -/* - case 32: \ - CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ - break; \ -*/ + #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ switch (head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ + break; \ case 128: \ 
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ break; \ @@ -2430,11 +2457,6 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported head size: ", head_size); \ break; \ } -/* - case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ - break; \ -*/ void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] From bca50774ca573341180d2f01123af29f2f769833 Mon Sep 17 00:00:00 2001 From: shsanyal Date: Wed, 15 Jan 2025 18:00:57 +0000 Subject: [PATCH 08/15] checkpoint with head size 64 supported --- csrc/rocm/attention.cu | 304 +++++++---------------------------------- 1 file changed, 46 insertions(+), 258 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 45b2fb114091..6b7ecf0665d4 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -571,10 +571,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int local_qhead_idx = 4 * warpid + rowid; const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; const int64_t seq_idx64 = static_cast(seq_idx); - const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; //+ rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; - if (local_qhead_idx < GQA_RATIO) { - const scalar_t* q_fetch_ptr = q_ptr + lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; //this works for head size 128 : 16 lanes x 8 elems = 128 elems + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ( (local_qhead_idx < GQA_RATIO) && (qhead_element(q_fetch_ptr); _B16x8 tmp = *q_fetch_ptr_16B; if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { @@ -683,7 +684,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } floatx4 dout[TLOOP]; -#if 1 //Q stored in registers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { dout[token_depth] = {0}; for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { @@ -710,28 +711,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 dout[token_depth] *= scale2; } -#else //Q in shared - _B16x4 tmpQ[QKHELOOP][2]; - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - tmpQ[qkhe_depth][0] = shared_logits[qkhe_depth][rowid][lane16id][0]; - tmpQ[qkhe_depth][1] = shared_logits[qkhe_depth][rowid][lane16id][1]; - } - - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - dout[token_depth] = {0}; - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - //for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - for (int i=0; i<2; i++) { - dout[token_depth] = gcn_mfma16x16x16_instr(Klocal[token_depth][qkhe_depth].xy[i], - tmpQ[qkhe_depth][i], //shared_logits[qkhe_depth][rowid][lane16id][i], - dout[token_depth]); - } - //} - } - dout[token_depth] *= scale; - } -#endif - #if 0 //DEBUG ONLY qk * scale for (int token_depth = 0; token_depth < TLOOP; token_depth++) { auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; @@ -835,6 +814,8 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); } else { + _B16x4* shared_logits_ptr = reinterpret_cast<_B16x4*>(shared_logits); + //shared_logits_ptr[warpid*4*16*4 + token_depth*16*4 + lane16id*4 + 
rowid] = from_floatx4_rtz(dout[token_depth]); shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); } } @@ -883,21 +864,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } #endif _B16x4 outelems[VHELOOP]; - _B16x4 S_local[VTLOOP][2][2]; - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - //for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - for (int j=0; j<2; j++) { - for (int i=0; i<2; i++) { - const int offset = 4*rowid + 2*j + i; - const int offset1 = offset % 4; - const int offset2 = offset / 4; - S_local[vtoken_depth][j][i] = shared_logits[vtoken_depth][offset2][lane16id][offset1]; - } - } - //} - } - } //v layout: 16he across lanes x 16 tokens per lane for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { @@ -940,11 +906,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int offset = 4*rowid + 2*j + i; const int offset1 = offset % 4; const int offset2 = offset / 4; + _B16x4* shared_logits_ptr = reinterpret_cast<_B16x4*>(shared_logits); tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], - S_local[vtoken_depth][j][i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + //shared_logits_ptr[vtoken_depth*4*16*4 + offset2*16*4 + lane16id*4 + offset1], tmp_out); - //shared_logits[vtoken_depth][offset2][lane16id][offset1], - //tmp_out); } } } @@ -967,10 +933,12 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 __syncthreads(); if (warpid == 0) { - _B16x8 vout[GQA_RATIO4]; + _B16x8 vout[GQA_RATIO4]; + //each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { for (int h = 0; h < GQA_RATIO4; h++) { const int local_head_idx = 4 * h + rowid; - const int head_elem_idx = lane16id * 8; const int offset1 = (head_elem_idx / 16)%4; const int offset2 = head_elem_idx / 16 / NWARPS; const int offset3 = (head_elem_idx / 4)%4; @@ -987,12 +955,12 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 if (local_head_idx < GQA_RATIO) { const int out_head_idx = wg_start_head_idx + local_head_idx; scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; - const int head_elem_idx = lane16id * 8; scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); *out_ptr_B16x8 = vout[h]; } } + } } #endif @@ -1071,7 +1039,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 // block (partition size) template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -1110,8 +1078,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( return; } constexpr int QHLOOP = - DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, - // total qheads =8, so qhloop is 2 + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads constexpr int GQA_RATIO4 = 4 * QHLOOP; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; @@ -1172,14 +1139,11 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // fetch vphysical block numbers up front const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; - if constexpr (GQA_RATIO < 12) 
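  // The clamp in the loop below (vblock_idx <= last_ctx_block ? vblock_idx :
  // last_ctx_block) is presumably there so lanes whose tokens fall past context_len
  // still read a valid block-table entry; those tokens are masked out of the softmax
  // later, so the values fetched for them do not affect the result.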
{ - #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { + for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; const int vblock_idx_ctx = (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; vphysical_blocks[b] = block_table[vblock_idx_ctx]; - } } // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems @@ -1228,8 +1192,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( #if 1 float alibi_slope[QHLOOP]; - if (alibi_slopes != nullptr) { - #pragma unroll + if constexpr(ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { const int qhead_idx = h * 4 + lane4id; alibi_slope[h] = (qhead_idx < GQA_RATIO) @@ -1258,17 +1221,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } #endif - // fetch vphysical block numbers up front - if constexpr (GQA_RATIO >= 12) { - #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - const int vblock_idx = warp_start_block_idx + b; - const int vblock_idx_ctx = - (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; - vphysical_blocks[b] = block_table[vblock_idx_ctx]; - } - } - #if 1 //fetch vcache in normal case const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { @@ -1349,8 +1301,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( Klocal[x].xy[1], dout[h]);\ } - //#pragma unroll - //for (int h = 0; h < QHLOOP; h++) { QK_mfma(0); QK_mfma(1); QK_mfma(2); @@ -1369,7 +1319,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( QK_mfma(14); QK_mfma(15); } - //} #undef QK_mfma float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { @@ -1378,83 +1327,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] *= scale2; - //if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - // dout[h] *= k_scale; - //} } -#if 0 - #pragma unroll - for (int h = 0; h < QHLOOP; h++) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[0].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[0].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[1].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[1].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[2].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[2].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[3].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[3].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[4].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[4].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[5].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[5].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[6].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[6].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[7].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[7].xy[1], dout[h]); - if constexpr (KHELOOP > 8) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[8].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[8].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[9].xy[0], dout[h]); - dout[h] = 
gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[9].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[10].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[10].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[11].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[11].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[12].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[12].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[13].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[13].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[14].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[14].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[15].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[15].xy[1], dout[h]); - } // KHELOOP>8 - dout[h] *= scale; - } -#endif - #if 0 if (alibi_slopes != nullptr) { float alibi_slope_local[GQA_RATIO]; @@ -1527,11 +1400,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int lane4_token_idx = 4 * (global_token_idx >> 2); #if 1 //alibi after transpose - const int alibi_offset = lane4_token_idx - context_len + 1; - if (alibi_slopes != nullptr) { - #pragma unroll + if constexpr(ALIBI_ENABLED) { + const int alibi_offset = lane4_token_idx - context_len + 1; for (int h = 0; h < QHLOOP; h++) { - #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] += alibi_slope[h] * (alibi_offset + i); } @@ -1641,7 +1512,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( _B16x4 logits[QHLOOP]; #pragma unroll for (int h = 0; h < QHLOOP; h++) { - logits[h] = from_floatx4(dout[h]); + logits[h] = from_floatx4_rtz(dout[h]); } @@ -1786,51 +1657,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( #endif #undef SV_mfma -#if 0 - // iterate across heads - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - #pragma unroll - for (int vh = 0; vh < VHELOOP; vh++) { - floatx4 acc = {0}; - // iterate over tokens - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[1], acc); - vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); - } - } -#endif } // warp in context __syncthreads(); @@ -1856,58 +1682,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } - if (context_len > partition_size) { - scalar_t* out_ptr = out + - 
seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - const int out_num_partitions = max_num_partitions; - bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll - for (int vh = 0; vh < VHELOOP; vh++) { - const int head_size_elem = vh * WARP_SIZE + laneid; - #pragma unroll - for (int i = 0; i < 4; i++) { - const int head_idx = 4 * qh + i; - if (head_idx < GQA_RATIO) { - out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * - HEAD_SIZE + - head_size_elem] = vout[qh][vh][i]; - } - } - } - } - } // context_len > partition_size - else { - bit8_t* final_out_ptr_b8; - bit16_t* final_out_ptr_b16; - if constexpr (std::is_same::value) { - final_out_ptr_b8 = final_out + seq_idx * num_heads * HEAD_SIZE; - } else { - OUTT* out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; - final_out_ptr_b16 = reinterpret_cast(out_ptr); - } - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll - for (int vh = 0; vh < VHELOOP; vh++) { - const int head_size_elem = vh * WARP_SIZE + laneid; - #pragma unroll - for (int i = 0; i < 4; i++) { - const int head_idx = 4 * qh + i; - if (head_idx < GQA_RATIO) { - if constexpr (std::is_same::value) { - const float tmpf = - out_scale * to_float_b16(vout[qh][vh][i]); - const OUTT tmp = hip_fp8(tmpf).data; - final_out_ptr_b8[(wg_start_head_idx + head_idx) * HEAD_SIZE + - head_size_elem] = tmp; - } else { - final_out_ptr_b16[(wg_start_head_idx + head_idx) * HEAD_SIZE + - head_size_elem] = vout[qh][vh][i]; - } - } + scalar_t* out_ptr = out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + const int out_num_partitions = max_num_partitions; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { + const int head_size_elem = vh * WARP_SIZE + laneid; + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; } } } @@ -2150,7 +1938,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -2205,7 +1993,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ paged_attention_ll4mi_QKV_kernel \ + HEAD_SIZE, NTHR, ALIBI_ENABLED, GQA_RATIO> \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ @@ -2275,20 +2063,20 @@ void paged_attention_custom_launcher( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (gqa_ratio) { case 1: - //LAUNCH_CUSTOM_ATTENTION(1); - LAUNCH_CUSTOM_ATTENTION_MFMA16(1); + LAUNCH_CUSTOM_ATTENTION(1); + //LAUNCH_CUSTOM_ATTENTION_MFMA16(1); break; case 2: - //LAUNCH_CUSTOM_ATTENTION(2); - LAUNCH_CUSTOM_ATTENTION_MFMA16(2); + LAUNCH_CUSTOM_ATTENTION(2); + //LAUNCH_CUSTOM_ATTENTION_MFMA16(2); break; case 3: - //LAUNCH_CUSTOM_ATTENTION(3); - LAUNCH_CUSTOM_ATTENTION_MFMA16(3); + LAUNCH_CUSTOM_ATTENTION(3); + //LAUNCH_CUSTOM_ATTENTION_MFMA16(3); break; case 4: - //LAUNCH_CUSTOM_ATTENTION(4); - LAUNCH_CUSTOM_ATTENTION_MFMA16(4); + LAUNCH_CUSTOM_ATTENTION(4); + //LAUNCH_CUSTOM_ATTENTION_MFMA16(4); break; case 5: 
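      // Each case below maps the runtime gqa_ratio (num_heads / num_kv_heads,
      // handled for values 1..16) onto a kernel instantiated with GQA_RATIO as a
      // compile-time template parameter, so per-head loop bounds and register tiling
      // are fixed at compile time. This change flips ratios 1-4 back to the original
      // mfma4 kernel, while larger ratios appear to keep the new mfma16 path.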
//LAUNCH_CUSTOM_ATTENTION(5); From e2e7d19cf9bc54a9ca5aca2a67240c2bdcdfa51b Mon Sep 17 00:00:00 2001 From: shsanyal Date: Wed, 15 Jan 2025 19:36:28 +0000 Subject: [PATCH 09/15] block size 32 fix --- csrc/rocm/attention.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 6b7ecf0665d4..2c4078bb9526 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -633,7 +633,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP;// 16 * T_PAR_SIZE / 256; - constexpr int VBLOCKS_PER_LANE = DIVIDE_ROUND_UP(VTOKENS_PER_LANE,BLOCK_SIZE); + constexpr int VBLOCKS_PER_LANE = 1; //was DIVIDE_ROUND_UP(VTOKENS_PER_LANE,BLOCK_SIZE); constexpr int VTLOOP = NWARPS; //was * TOKENS_PER_WARP / ROWS_PER_WARP / VTOKENS_PER_LANE; constexpr int VTLANELOOP = DIVIDE_ROUND_UP(VTOKENS_PER_LANE , CONTIGUOUS_KV_ELEMS_16B_LOAD); //optimized for 16B fetches; assumes minimum block size is 16 constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; @@ -655,7 +655,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; //this could be B8x16 too - const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + ((rowid * VTOKENS_PER_LANE)%BLOCK_SIZE); //v fetches are 16head elems across lanes x 16 tokens per lane for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { @@ -664,7 +664,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - const int vblock_depth = vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD / BLOCK_SIZE; + const int vblock_depth = 0; //was vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD / BLOCK_SIZE; //const int token_depth = vtoken_depth * VBLOCKS_PER_LANE + vblock_depth; const int64_t vblock_number = static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); From 2a669e44a470c33bfefdbdce5bba8e31f644df29 Mon Sep 17 00:00:00 2001 From: shsanyal Date: Thu, 16 Jan 2025 07:57:59 +0000 Subject: [PATCH 10/15] clean up --- csrc/rocm/attention.cu | 67 ++++++++++-------------------------------- 1 file changed, 16 insertions(+), 51 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 2c4078bb9526..7f3984d84e80 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -440,8 +440,8 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { return ret; } /////////////////////////////////////// -// grid (num_seqs, num_partitions,num_heads/gqa_ratio) -// block (partition size) +// grid (num_seqs, num_partitions,num_kv_heads) +// block (256) template (seq_idx); @@ -602,7 +596,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } #endif - constexpr int KX = 16 / sizeof(cache_t); + constexpr int KX = 16 / sizeof(cache_t); //vLLM defines x as 16 Bytes of kv cache elements const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; @@ -632,9 +626,9 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 alibi_slope = (lane16id < GQA_RATIO) ? 
alibi_slopes[alibi_head_idx] : 0.f; } - constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP;// 16 * T_PAR_SIZE / 256; - constexpr int VBLOCKS_PER_LANE = 1; //was DIVIDE_ROUND_UP(VTOKENS_PER_LANE,BLOCK_SIZE); - constexpr int VTLOOP = NWARPS; //was * TOKENS_PER_WARP / ROWS_PER_WARP / VTOKENS_PER_LANE; + constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP; //16 tokens per lane + constexpr int VBLOCKS_PER_LANE = 1; //assumes block size >=16, each lane can correspond to 1 block only + constexpr int VTLOOP = NWARPS; //corresponds to tokens across warps constexpr int VTLANELOOP = DIVIDE_ROUND_UP(VTOKENS_PER_LANE , CONTIGUOUS_KV_ELEMS_16B_LOAD); //optimized for 16B fetches; assumes minimum block size is 16 constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; @@ -664,8 +658,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - const int vblock_depth = 0; //was vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD / BLOCK_SIZE; - //const int token_depth = vtoken_depth * VBLOCKS_PER_LANE + vblock_depth; + const int vblock_depth = 0; const int64_t vblock_number = static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); @@ -677,9 +670,10 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } - //__syncthreads(); //if using shared Q + //__syncthreads(); //if using shared Q (deprecated) float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + //multiply by k_scale if fp8 kv cache scale2 *= k_scale; } @@ -723,8 +717,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; if constexpr(ALIBI_ENABLED) { - //const int alibi_head_idx = wg_start_head_idx + lane16id; - //const float alibi_slope = (lane16id < GQA_RATIO) ? 
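// Standalone sketch (illustrative only, not code from this patch, assuming the usual
// math headers are available) of the max-subtracted softmax accumulation each warp
// performs on its QK logits before writing qk_max and exp_sum to shared memory:
static inline float softmax_denominator(const float* logits, int n, float* out_max) {
  float qk_max = -FLT_MAX;                  // running maximum for numerical stability
  for (int i = 0; i < n; ++i) qk_max = fmaxf(qk_max, logits[i]);
  float exp_sum = 0.0f;                     // sum of exp(x - max) stays in range
  for (int i = 0; i < n; ++i) exp_sum += expf(logits[i] - qk_max);
  *out_max = qk_max;
  return exp_sum;                           // each weight is exp(x - max) / exp_sum
}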
alibi_slopes[alibi_head_idx] : 0.f; for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; const int alibi_offset = local_token_idx - context_len + 1; @@ -767,8 +759,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 float* shared_mem = reinterpret_cast(shared_logits); if (laneid < 16) { - //shared_qk_max[warpid][lane16id] = qk_max; - //shared_exp_sum[warpid][lane16id] = exp_sum; const int qk_max_offset = warpid*16 + lane16id; shared_mem[qk_max_offset] = qk_max; const int exp_sum_offset = NWARPS*16 + qk_max_offset; @@ -793,28 +783,25 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 float partition_exp_sum = 0.0f; for (int w=0; w(dout[token_depth]); } else { - _B16x4* shared_logits_ptr = reinterpret_cast<_B16x4*>(shared_logits); + //_B16x4* shared_logits_ptr = reinterpret_cast<_B16x4*>(shared_logits); //shared_logits_ptr[warpid*4*16*4 + token_depth*16*4 + lane16id*4 + rowid] = from_floatx4_rtz(dout[token_depth]); shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); } @@ -840,28 +827,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 auto tmp = shared_tokens[warpid][token_depth][lane16id][rowid]; *qkout_write_ptr = tmp; } -#endif -#if 0 - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; - _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); - for (int j=0; j<2; j++) { - _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; - _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); - for (int i=0; i<2; i++) { - const int offset = 4*rowid + 2*j + i; - const int offset1 = offset % 4; - const int offset2 = offset / 4; - tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], - shared_logits[vtoken_depth][offset2][lane16id][offset1], - tmp_out); - } - } - } - } #endif _B16x4 outelems[VHELOOP]; //v layout: 16he across lanes x 16 tokens per lane @@ -874,7 +839,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { for (int i=0; i<2; i++) { - //TODO generalize this for 8 bit dtypes: each lane needs 2*vfetch_depth + 2 _B16x4 K/token dimension elems; each row is multiplied by a factor of 4 + //generalize this for 8 bit dtypes: each lane needs 2*vfetch_depth + 2 _B16x4 K/token dimension elems; each row is multiplied by a factor of 4 //layout: lane in depth dimension | row across -> //0 4 8 12 //1 5 9 13 @@ -906,7 +871,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int offset = 4*rowid + 2*j + i; const int offset1 = offset % 4; const int offset2 = offset / 4; - _B16x4* shared_logits_ptr = reinterpret_cast<_B16x4*>(shared_logits); + //_B16x4* shared_logits_ptr = reinterpret_cast<_B16x4*>(shared_logits); tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], shared_logits[vtoken_depth][offset2][lane16id][offset1], //shared_logits_ptr[vtoken_depth*4*16*4 + offset2*16*4 + lane16id*4 + offset1], From 51623be13fee8d65dfbdadb05d1aa73732772b8e Mon Sep 17 00:00:00 2001 From: shsanyal Date: Thu, 16 Jan 2025 18:14:56 +0000 Subject: [PATCH 11/15] further clean up 
and comments --- csrc/rocm/attention.cu | 671 ++++++----------------------------------- 1 file changed, 91 insertions(+), 580 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 7f3984d84e80..c3de8efb07c4 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -73,19 +73,7 @@ typedef struct _B8x16 { _B8x8 xy[2]; } _B8x16; -////// Non temporal load stores /////// - -template -__device__ __forceinline__ T load(T* addr) { - return addr[0]; -} - -template -__device__ __forceinline__ void store(T value, T* addr) { - addr[0] = value; -} - - +////// Non temporal loads /////// template __device__ __forceinline__ T loadnt(T* addr) { return __builtin_nontemporal_load(addr); @@ -97,17 +85,14 @@ __device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) { auto dat1 = loadnt(addr_alias + 1); auto dat2 = loadnt(addr_alias + 2); auto dat3 = loadnt(addr_alias + 3); - //auto dat0 = *(addr_alias); - //auto dat1 = *(addr_alias+1); - //auto dat2 = *(addr_alias+2); - //auto dat3 = *(addr_alias+3); auto res = make_float4(dat0,dat1,dat2,dat3); return *reinterpret_cast<_B16x8*>(&res); } +/////////////////////////////////// template -__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, +__device__ __forceinline__ floatx4 gcn_mfma4x4x4_instr(const _B16x4& inpA, const _B16x4& inpB, const floatx4& inpC) { if constexpr (std::is_same::value) { @@ -183,23 +168,7 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { __hip_bfloat16 b; } t16; _B16x4 ret; -#if 0 - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; -#else if constexpr (std::is_same::value) { -#if 0 - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; -#else union h2cvt { __half2 h2[2]; _B16x4 b16x4; @@ -207,25 +176,20 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { u.h2[0] = __float22half2_rn(make_float2(inp[0],inp[1])); u.h2[1] = __float22half2_rn(make_float2(inp[2],inp[3])); return u.b16x4; -#endif } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { union fcvt { uint32_t u32; float f32; } u; u.f32 = inp[i]; - u.u32 += 0x7fff + ((u.u32 >> 16) & 1); //RNE with no nan/inf check + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); //BF16 RNE with no nan/inf check ret[i] = uint16_t(u.u32 >> 16); - //t16.b = __float2bfloat16(inp[i]); - //ret[i] = t16.u; } return ret; } else { static_assert(false, "unsupported 16b dtype"); } -#endif } template @@ -237,27 +201,7 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, __hip_bfloat16 b; } t1, t2, res; _B16x4 ret; -#if 0 - #pragma unroll - for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; -#else if constexpr (std::is_same::value) { -#if 0 - #pragma unroll - for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; -#else union h2cvt { _B16x4 b16x4; __half2 h2[2]; @@ -267,9 +211,7 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, s.h2[0] = u1.h2[0] + u2.h2[0]; s.h2[1] = u1.h2[1] + u2.h2[1]; return s.b16x4; -#endif } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { union fcvt { float f32; @@ -279,16 +221,11 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, u2.i32 = uint32_t(inp2[i])<<16; s.f32 = u1.f32 + u2.f32; ret[i] = uint16_t(s.i32>>16); - //t1.u = 
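// Sketch (illustrative, not from the patch): the bf16 helpers above treat a bf16
// value as the high 16 bits of an f32 and convert back either by truncation (addx4)
// or with the round-to-nearest-even bias used in from_floatx4:
static inline uint16_t f32_to_bf16_rne(float f) {
  uint32_t u;
  __builtin_memcpy(&u, &f, sizeof(u));      // reinterpret the float's bits
  u += 0x7fff + ((u >> 16) & 1);            // RNE bias; no NaN/Inf special-casing
  return (uint16_t)(u >> 16);               // keep the high half as bf16
}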
inp1[i]; - //t2.u = inp2[i]; - //res.b = t1.b + t2.b; - //ret[i] = res.u; } return ret; } else { static_assert(false, "unsupported 16b dtype"); } -#endif } template @@ -371,61 +308,8 @@ __device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { } } -template -__device__ __forceinline__ _B16x4 from_floatx4_trunc(const floatx4& inp) { - _B16x4 ret; - if constexpr (std::is_same::value) { - int32_t tmpf8; - tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false); - tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true); - const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false); - const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true); - union h2cvt { - _Half2 h2[2]; - _B16x4 b16x4; - } u; - u.h2[0] = __builtin_amdgcn_cvt_pkrtz(f0[0],f0[1]); - u.h2[1] = __builtin_amdgcn_cvt_pkrtz(f1[0],f1[1]); - return u.b16x4; - } else if constexpr (std::is_same::value) { - int32_t tmpf8; - tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[0], inp[1], tmpf8, false); - tmpf8 = __builtin_amdgcn_cvt_pk_fp8_f32(inp[2], inp[3], tmpf8, true); - const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, false); - const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(tmpf8, true); - floatx4 tmpf; - tmpf[0] = f0[0]; - tmpf[1] = f0[1]; - tmpf[2] = f1[0]; - tmpf[3] = f1[1]; - for (int i = 0; i < 4; i++) { - union fcvt { - uint32_t i32; - float f32; - } u; - u.f32 = tmpf[i]; - ret[i] = uint16_t(u.i32 >> 16); - } - return ret; - } else { - static_assert(false, "unsupported 16b dtype"); - } -} - - - template __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { -#if 0 - union { - floatx4 f32x4[2]; - vllm::Float8_ f32x8; - _B8x8 b8x8[2]; - } tmpf8; - tmpf8.f32x8 = vllm::fp8::vec_conversion(*reinterpret_cast(&input)); - //tmpf8.b8x8[0] = input; - //tmpf8.b8x8[1] = input; -#endif union { _B8x8 b8x8; _B8x4 b8x4[2]; @@ -435,10 +319,9 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { for (int i=0; i<2; i++) { ret.xy[i] = from_floatx4_rtz( to_float_fp8x4(tmp.b8x4[i]) ); } - //ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); - //ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); return ret; } + /////////////////////////////////////// // grid (num_seqs, num_partitions,num_kv_heads) // block (256) @@ -537,31 +420,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; } -#if 0 //fetch Q into registers (deprecated) - - const int local_qhead_idx = lane16id % GQA_RATIO; - const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; - const int64_t seq_idx64 = static_cast(seq_idx); - const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; - - if (lane16id < GQA_RATIO) { - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - const scalar_t* q_ptr2 = q_ptr + qkhe_depth * QKHE_PER_FETCH; - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - const scalar_t* q_fetch_ptr = q_ptr2 + qkratio * CONTIGUOUS_SCALAR_ELEMS_16B; - const _B16x8* q_fetch_ptr_16B = reinterpret_cast(q_fetch_ptr); - Qlocal[qkhe_depth][qkratio] = *q_fetch_ptr_16B; - } - } - } else { - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - Qlocal[qkhe_depth][qkratio].xy[0] = {0}; - Qlocal[qkhe_depth][qkratio].xy[1] = {0}; - } - } - } -#else //fetch Q in shared and then write to registers + //fetch Q in shared across warps and then write to registers 
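  // (Layout note, inferred from the code below: each warp stages 4 query heads, one
  // per 4-lane row, and each of the 16 lane groups in a row fetches 16 contiguous
  // bytes, i.e. 8 fp16/bf16 elements, of that head. The qhead_element < HEAD_SIZE
  // guard keeps the fetch in bounds for head sizes smaller than 128, such as 64; the
  // staged values are then redistributed from shared memory into per-lane Qlocal
  // registers for the 16x16x16 MFMAs.)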
const int local_qhead_idx = 4 * warpid + rowid; const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; const int64_t seq_idx64 = static_cast(seq_idx); @@ -594,13 +453,15 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } } -#endif + + //set to true to enable non temporal kv loads: has some benefit in very high batch size cases + constexpr bool NT_KV_LOAD = false; constexpr int KX = 16 / sizeof(cache_t); //vLLM defines x as 16 Bytes of kv cache elements const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; - + //fetch K values for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int64_t kblock_number = static_cast(kphysical_block_number[token_depth]); const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; @@ -615,8 +476,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int offset2 = head_elem % KX; const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); - Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; - //Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); + if constexpr(NT_KV_LOAD) { + Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); + } else { + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } } } @@ -664,13 +528,16 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const cache_t* v_fetch_ptr = v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); - Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; - //Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = load_ntmprl_16Byte(v_fetch_ptr_16B); + if constexpr(NT_KV_LOAD) { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = load_ntmprl_16Byte(v_fetch_ptr_16B); + } else { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } } } } - - //__syncthreads(); //if using shared Q (deprecated) + + //calculate post qk mfma scale float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { //multiply by k_scale if fp8 kv cache @@ -678,7 +545,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } floatx4 dout[TLOOP]; - + //qk mfma for (int token_depth = 0; token_depth < TLOOP; token_depth++) { dout[token_depth] = {0}; for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { @@ -705,17 +572,9 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 dout[token_depth] *= scale2; } -#if 0 //DEBUG ONLY qk * scale - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; - auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); - auto tmp = from_floatx4(dout[token_depth]); - *qkout_write_ptr = tmp; - } -#endif - const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; - + + //apply alibi if constexpr(ALIBI_ENABLED) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; @@ -725,7 +584,8 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } } - + + //calculate qk_max and exp_sum per warp and write to shared memory float qk_max = -FLT_MAX; float exp_sum = 0.0f; @@ -765,19 +625,9 @@ 
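// Standalone sketch (illustrative only) of how per-warp (max, exp_sum) pairs are
// combined into partition-level statistics, mirroring the rescaling done in the
// hunk below before inv_sum_scale is formed:
static inline void combine_softmax_stats(const float* warp_max, const float* warp_sum,
                                         int nwarps, float* part_max, float* part_sum) {
  float m = -FLT_MAX;
  for (int w = 0; w < nwarps; ++w) m = fmaxf(m, warp_max[w]);
  float s = 0.0f;
  for (int w = 0; w < nwarps; ++w)
    s += warp_sum[w] * expf(warp_max[w] - m);  // rescale each warp's sum to the global max
  *part_max = m;
  *part_sum = s;
}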
__global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 shared_mem[exp_sum_offset] = exp_sum; } -#if 0 //DEBUG ONLY - //scalar_t* qkout_ptr = out + - // seq_idx * total_num_heads * T_PAR_SIZE + lane16id * T_PAR_SIZE; - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - //auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; - //auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); - auto tmp = from_floatx4(dout[token_depth]); - shared_tokens[warpid][token_depth][lane16id][rowid] = tmp; - //*qkout_write_ptr = tmp; - } -#endif __syncthreads(); - + + //calculate partition qk_max and exp_sum float partition_qk_max = -FLT_MAX; float warp_qk_max_exp[NWARPS]; float partition_exp_sum = 0.0f; @@ -795,18 +645,14 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const float inv_sum_scale = __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; __syncthreads(); - + + //write logits to shared mem for (int token_depth = 0; token_depth < TLOOP; token_depth++) { dout[token_depth] *= inv_sum_scale; - if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { - shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); - } else { - //_B16x4* shared_logits_ptr = reinterpret_cast<_B16x4*>(shared_logits); - //shared_logits_ptr[warpid*4*16*4 + token_depth*16*4 + lane16id*4 + rowid] = from_floatx4_rtz(dout[token_depth]); - shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); - } + //use rtz conversion for performance, with no visible impact on accuracy + shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); } - + //write out partition max_logits and exp_sum if (threadIdx.x < GQA_RATIO) { const int qhead_idx = lane16id; const int offset = seq_idx * total_num_heads * max_num_partitions + (wg_start_head_idx + qhead_idx) * max_num_partitions + partition_idx; @@ -816,21 +662,9 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 __syncthreads(); -#if 0 //DEBUG ONLY - scalar_t* qkout_ptr = out + - seq_idx * total_num_heads * T_PAR_SIZE + lane16id * T_PAR_SIZE; - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; - auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); - //dout[token_depth] *= inv_sum_scale[warpid]; - //auto tmp = from_floatx4(dout[token_depth]); - auto tmp = shared_tokens[warpid][token_depth][lane16id][rowid]; - *qkout_write_ptr = tmp; - } -#endif _B16x4 outelems[VHELOOP]; + //Softmax V mfma //v layout: 16he across lanes x 16 tokens per lane - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { floatx4 tmp_out = {0}; @@ -848,16 +682,10 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int offset = rowid * VTLANELOOP * 2 + 2*vfetch_depth + i; const int offset1 = offset % 4; //4 corresponds to ROWS_PER_WARP const int offset2 = offset / 4; -#if 0 - //if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows - tmp_out = gcn_mfma16x16x16_instr(shared_logits[vtoken_depth][offset2][lane16id][offset1], - Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out); -#else - //if output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows + //output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows tmp_out = 
gcn_mfma16x16x16_instr(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], shared_logits[vtoken_depth][offset2][lane16id][offset1], tmp_out); -#endif } } } else { @@ -871,10 +699,9 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int offset = 4*rowid + 2*j + i; const int offset1 = offset % 4; const int offset2 = offset / 4; - //_B16x4* shared_logits_ptr = reinterpret_cast<_B16x4*>(shared_logits); + //output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], shared_logits[vtoken_depth][offset2][lane16id][offset1], - //shared_logits_ptr[vtoken_depth*4*16*4 + offset2*16*4 + lane16id*4 + offset1], tmp_out); } } @@ -882,21 +709,24 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } + //apply post Softmax V mfma v_scale if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { tmp_out *= v_scale; } outelems[vhe_depth] = from_floatx4(tmp_out); } -#if 1 __syncthreads(); - + + //store Softmax-V mfma output to shared mem for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; //lane16 id head dimension; rowid head element dimension + //lane16 id head dimension; rowid head element dimension + shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; } __syncthreads(); - + + //write to tmp_out with coalesced writes if (warpid == 0) { _B16x8 vout[GQA_RATIO4]; //each lane writes out 16Bytes of tmp_out along head elem dimension @@ -927,86 +757,16 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } } - -#endif - -#if 0 - //if output format is 16 he across 16 lanes, 16 qheads spread across 4 rows - const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; - scalar_t* out_ptr = out + - seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; - - const int vhe_offset = warpid * 16 + lane16id; - - #pragma unroll - for (int i=0; i<4; i++) { - const int local_head_idx = 4*rowid + i; - if (local_head_idx < GQA_RATIO) { - const int out_head_idx = wg_start_head_idx + local_head_idx; - scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - const int vhead_elem = vhe_depth * NWARPS * 16 + vhe_offset; - scalar_t* out_ptr3 = out_ptr2 + vhead_elem; - bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr3); - *out_ptr_b16 = outelems[vhe_depth][i]; - } - } - } -#endif -#if 0 - //if output format is 16 qheads across 16 lanes, 16 he spread across 4 rows - if (lane16id < GQA_RATIO) { - const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; - scalar_t* out_ptr = out + - seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; - const int local_head_idx = lane16id; - const int out_head_idx = wg_start_head_idx + local_head_idx; - scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; - const int vhe_offset = warpid * 16 + rowid * 4; - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - const int vhead_elem = vhe_depth * NWARPS * 16 + vhe_offset; - scalar_t* out_ptr3 = out_ptr2 + vhead_elem; - _B16x4* out_ptr_B16x4 = reinterpret_cast<_B16x4*>(out_ptr3); - *out_ptr_B16x4 = outelems[vhe_depth]; - } - } -#endif -#if 0 //DEBUG ONLY - floatx4 partition_out[VHELOOP]; - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - partition_out[vhe_depth] = {0}; - for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - 
partition_out[vhe_depth] += inv_sum_scale[vtoken_depth] * vout[vhe_depth][vtoken_depth]; - } - } -#endif -#if 0 //DEBUG ONLY - if (laneid < GQA_RATIO) { - auto* exp_sums_ptr = exp_sums + seq_idx * 8 * max_num_partitions + partition_idx; - floatx4 tmp = {0}; - //for (int t=0; t(from_floatx4(tmp), shared_tokens[warpid][lane4id][lane16id][rowid]); - - float2 tmpf = *reinterpret_cast(&tmp16); - *exp_sums_ptr = laneid%2 == 0 ? tmpf.x : tmpf.y; - } -#endif } + ///////////////////////////////////////////////////////////// -// grid (num_seqs, num_partitions,num_heads/gqa_ratio) -// block (partition size) +// grid (num_seqs, num_partitions, num_kv_heads) +// block (256 : partition size) template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -1052,9 +812,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; _B8x8 Klocalb8[KHELOOP]; + // v head_size dimension is distributed across warp constexpr int VHELOOP = HEAD_SIZE / - WARP_SIZE; // v head_size dimension is distributed across lanes + WARP_SIZE; constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 // 8xtokens constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; @@ -1063,8 +824,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; - #pragma unroll + for (int h = 0; h < QHLOOP; h++) { dout[h] = {0}; qk_max[h] = -FLT_MAX; @@ -1092,17 +854,16 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int local_token_idx = threadIdx.x; const int global_token_idx = partition_start_token_idx + local_token_idx; + // fetch block number for k const int block_idx = (global_token_idx < context_len) ? 
global_token_idx / BLOCK_SIZE : last_ctx_block; - // fetch block number for q and k // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride const int64_t physical_block_number = static_cast(block_table[block_idx]); // fetch vphysical block numbers up front - const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; @@ -1116,7 +877,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); const int qhead_elemh8 = laneid / 4; - #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { const int qhead_idx = h * 4 + lane4id; Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; @@ -1138,14 +899,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // is already cast as _H8 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); - #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; } } else { + //vllm defines X as 16 Bytes of elements of cache_t constexpr int X = 16 / sizeof(cache_t); const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; - #pragma unroll for (int d = 0; d < KHELOOP; d++) { const int head_elem = d * 8; const int offset1 = head_elem / X; @@ -1155,7 +915,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } -#if 1 float alibi_slope[QHLOOP]; if constexpr(ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { @@ -1165,33 +924,12 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( : 0.f; } } -#endif -#if 0 - float alibi_slope; - const int lane16id = laneid % 16; - if (alibi_slopes != nullptr) { - alibi_slope = (lane16id < GQA_RATIO) - ? alibi_slopes[wg_start_head_idx + lane16id] - : 0.f; - //#pragma unroll - // for (int h = 0; h < QHLOOP; h++) { - // for (int i=0; i<4; i++) { - // const int qhead_idx = h * 4 + i; - // alibi_slope[qhead_idx] = (qhead_idx < GQA_RATIO) - // ? 
alibi_slopes[wg_start_head_idx + qhead_idx] - // : 0.f; - // } - //} - //} - } -#endif -#if 1 //fetch vcache in normal case const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + //fetch vcache in kv cache auto case if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -1200,24 +938,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B16x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) -#endif -#if 1 //fetch vcache in fp8 case + //fetch vcache in fp8 case else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -1226,46 +960,28 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B8x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - //const _B8x8 Vlocalb8 = v_ptrh8be[d]; - //Vlocal[h][b * BLOCK_SIZE / 8 + d] = - // scaled_convert_b8x8(Vlocalb8, v_scale); } } } } -#endif -#if 0 //cvt kf8 to kf/bf16 up front - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = - //scaled_convert_b8x8(Klocalb8[d], k_scale); - convert_b8x8_custom(Klocalb8[d]); - } - } -#endif - /*Klocal[x] = scaled_convert_b8x8(Klocalb8[x], k_scale); \*/ - /*Klocal[x] = scaled_convert_b8x8_custom(Klocalb8[x], k_scale); \*/ #define QK_mfma(x) \ if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ } \ for (int h = 0; h < QHLOOP; h++) { \ - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], \ + dout[h] = gcn_mfma4x4x4_instr(Qlocal[h].xy[0], \ Klocal[x].xy[0], dout[h]);\ - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], \ + dout[h] = gcn_mfma4x4x4_instr(Qlocal[h].xy[1], \ Klocal[x].xy[1], dout[h]);\ } - + //QK mfma QK_mfma(0); QK_mfma(1); QK_mfma(2); @@ -1274,6 +990,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( QK_mfma(5); QK_mfma(6); QK_mfma(7); + //below only needed for head size 128 if constexpr (KHELOOP > 8) { QK_mfma(8); QK_mfma(9); @@ -1285,86 +1002,29 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( QK_mfma(15); } #undef QK_mfma + float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { scale2 *= k_scale; 
} - #pragma unroll + for (int h = 0; h < QHLOOP; h++) { dout[h] *= scale2; } -#if 0 - if (alibi_slopes != nullptr) { - float alibi_slope_local[GQA_RATIO]; -#define DPP_BCAST_ASM(id) asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:id " : "=v"(alibi_slope_local[id]) : "v"(alibi_slope)); - //for (int head=0; head < 16; head++) { - //DPP_BCAST_ASM(0); - if constexpr(GQA_RATIO>0) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:0 " : "=v"(alibi_slope_local[0]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>1) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:1 " : "=v"(alibi_slope_local[1]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>2) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:2 " : "=v"(alibi_slope_local[2]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>3) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:3 " : "=v"(alibi_slope_local[3]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>4) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:4 " : "=v"(alibi_slope_local[4]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>5) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:5 " : "=v"(alibi_slope_local[5]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>6) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:6 " : "=v"(alibi_slope_local[6]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>7) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:7 " : "=v"(alibi_slope_local[7]) : "v"(alibi_slope));} - //} - const int alibi_offset = global_token_idx - context_len + 1; - #pragma unroll - for (int h = 0; h < QHLOOP; h++) { - #pragma unroll - for (int i = 0; i < 4; i++) { - dout[h][i] += alibi_slope_local[4*h+i] * alibi_offset; - } - } - } -#endif // transpose dout so that 4 token ids are in each lane, and 4 heads are across // 4 lanes - #pragma unroll for (int h = 0; h < QHLOOP; h++) { -#if 1 floatx4 tmp = {0}; - #pragma unroll for (int i = 0; i < 4; i++) { const float B = (lane4id == i) ? 1.0f : 0.0f; - // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); - // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } dout[h] = tmp; -#endif -#if 0 - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[1,0,3,2] " : "=v"(dout[h][1]) : "v"(dout[h][1]) ); - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[2,3,0,1] " : "=v"(dout[h][2]) : "v"(dout[h][2]) ); - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[3,2,1,0] " : "=v"(dout[h][3]) : "v"(dout[h][3]) ); - - bool mask = (lane4id % 2) == 1; - float tmp = dout[h][1]; - dout[h][1] = mask ? dout[h][0] : dout[h][1]; - dout[h][0] = mask ? tmp : dout[h][0]; - tmp = dout[h][3]; - dout[h][3] = mask ? dout[h][2] : dout[h][3]; - dout[h][2] = mask ? tmp : dout[h][2]; - - mask = (lane4id>>1) == 1; - tmp = dout[h][2]; - dout[h][2] = mask ? dout[h][0] : dout[h][2]; - dout[h][0] = mask ? tmp : dout[h][0]; - tmp = dout[h][3]; - dout[h][3] = mask ? dout[h][1] : dout[h][3]; - dout[h][1] = mask ? 
tmp : dout[h][1]; - - - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[1,0,3,2] " : "=v"(dout[h][1]) : "v"(dout[h][1]) ); - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[2,3,0,1] " : "=v"(dout[h][2]) : "v"(dout[h][2]) ); - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[3,2,1,0] " : "=v"(dout[h][3]) : "v"(dout[h][3]) ); - -#endif } const int lane4_token_idx = 4 * (global_token_idx >> 2); -#if 1 //alibi after transpose + if constexpr(ALIBI_ENABLED) { const int alibi_offset = lane4_token_idx - context_len + 1; for (int h = 0; h < QHLOOP; h++) { @@ -1373,29 +1033,24 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } -#endif const int bpermute_mask = 4*(16*((laneid>>2)%4) + lane4id); - #pragma unroll for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; - #pragma unroll for (int i = 0; i < 4; i++) { qk_max[h] = (lane4_token_idx + i < context_len) ? fmaxf(qk_max[h], dout[h][i]) : qk_max[h]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 64; mask /= 2) { - qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); - } + + //for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + //} + //faster version of above code with dpp asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); - //asm("v_nop\n v_nop\n ds_bpermute_b32 %0, %1, %2 \n s_waitcnt lgkmcnt(0)" : "=v"(qk_max[h]) : "v"(bpermute_mask), "v"(qk_max[h]) ); - - //qk_max[h] = __builtin_amdgcn_ds_bpermute(bpermute_mask, qk_max[h]); auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&qk_max[h])); qk_max[h] = *reinterpret_cast(&tmp); asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); @@ -1404,25 +1059,21 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( float exp_sum[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; - #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] = (lane4_token_idx + i < context_len) ? 
__expf(dout[h][i] - qk_max[h]) : 0.0f; exp_sum[h] += dout[h][i]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 64; mask /= 2) { - exp_sum[h] += __shfl_xor(exp_sum[h], mask); - } + //for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // exp_sum[h] += __shfl_xor(exp_sum[h], mask); + //} + //faster version of above code with dpp asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); - //asm("v_nop\n v_nop\n ds_bpermute_b32 %0, %1, %2 \n s_waitcnt lgkmcnt(0)" : "=v"(exp_sum[h]) : "v"(bpermute_mask), "v"(exp_sum[h]) ); - //exp_sum[h] = __builtin_amdgcn_ds_bpermute(bpermute_mask, exp_sum[h]); auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&exp_sum[h])); exp_sum[h] = *reinterpret_cast(&tmp); asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); @@ -1430,7 +1081,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } if (laneid<4) { - #pragma unroll for (int h = 0; h < QHLOOP; h++) { const int head_idx = 4 * h + lane4id; shared_qk_max[warpid][head_idx] = qk_max[h]; @@ -1446,18 +1096,15 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; const int head_idx = 4 * h + lane4id; - #pragma unroll for (int w = 0; w < NWARPS; w++) { warp_qk_max[w] = shared_qk_max[w][head_idx]; global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); } float global_exp_sum = 0.0f; - #pragma unroll for (int w = 0; w < NWARPS; w++) { global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); @@ -1475,111 +1122,29 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { logits[h] = from_floatx4_rtz(dout[h]); } if (warp_start_token_idx >= context_len) { // warp out of context - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; } } } else { // warp in context -#if 0 //fetch v cache - const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { - const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); - // iterate over each v block - #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride - const int64_t vphysical_block_number = - static_cast(vphysical_blocks[b]); - const _B16x8* v_ptrh8b = - v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; - // iterate over each head elem (within head_size) - #pragma unroll - for (int h = 0; h < VHELOOP; h++) { - const int head_size_elem = h * WARP_SIZE + laneid; - const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; - // iterate over all velems within block - #pragma unroll - for (int d = 0; d < BLOCK_SIZE / 8; d++) { - Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - } - } - 
} - } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) - - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); - // iterate over each v block - #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride - const int64_t vphysical_block_number = - static_cast(vphysical_blocks[b]); - const _B8x8* v_ptrh8b = - v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; - // iterate over each head elem (within head_size) - #pragma unroll - for (int h = 0; h < VHELOOP; h++) { - const int head_size_elem = h * WARP_SIZE + laneid; - const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; - // iterate over all velems within block - #pragma unroll - for (int d = 0; d < BLOCK_SIZE / 8; d++) { - Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - //const _B8x8 Vlocalb8 = v_ptrh8be[d]; - //Vlocal[h][b * BLOCK_SIZE / 8 + d] = - // scaled_convert_b8x8(Vlocalb8, v_scale); - } - } - } - } -#endif -#if 0 //cvt vf8 ->f16/bf16 up front - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - for (int vh = 0; vh < VHELOOP; vh++) { - for (int b=0; b < VTLOOP; b++) { - //Vlocal[vh][b] = scaled_convert_b8x8(Vlocalb8[vh][b], v_scale); - Vlocal[vh][b] = convert_b8x8_custom(Vlocalb8[vh][b]); - } - } - } -#endif - - /*Vlocal[vh][x] = scaled_convert_b8x8(Vlocalb8[vh][x], v_scale);\*/ - /*Vlocal[vh][x] = scaled_convert_b8x8_custom(Vlocalb8[vh][x], v_scale);\*/ #define SV_mfma(x) \ if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {\ Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]);\ }\ for (int qh = 0; qh < QHLOOP; qh++) { \ - acc[qh] = gcn_mfma_instr(logits[qh], Vlocal[vh][x].xy[0], \ + acc[qh] = gcn_mfma4x4x4_instr(logits[qh], Vlocal[vh][x].xy[0], \ acc[qh]); \ - acc[qh] = gcn_mfma_instr(logits[qh], Vlocal[vh][x].xy[1], \ + acc[qh] = gcn_mfma4x4x4_instr(logits[qh], Vlocal[vh][x].xy[1], \ acc[qh]); \ } -#if 0 - floatx4 acc[QHLOOP][VHELOOP]; - for (int qh = 0; qh < QHLOOP; qh++) { - for (int vh = 0; vh < VHELOOP; vh++) { - acc[qh][vh] = {0}; - } - } -#endif - //#pragma unroll - // for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - //#pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { floatx4 acc[QHLOOP]; for (int qh = 0; qh < QHLOOP; qh++) { @@ -1594,16 +1159,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( SV_mfma(5); SV_mfma(6); SV_mfma(7); -#if 0 - SV_mfma(8); - SV_mfma(9); - SV_mfma(10); - SV_mfma(11); - SV_mfma(12); - SV_mfma(13); - SV_mfma(14); - SV_mfma(15); -#endif + for (int qh = 0; qh < QHLOOP; qh++) { if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { acc[qh] *= v_scale; @@ -1611,15 +1167,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); } } - //} - -#if 0 - for (int qh = 0; qh < QHLOOP; qh++) { - for (int vh = 0; vh < VHELOOP; vh++) { - vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh][vh]); - } - } -#endif #undef SV_mfma } // warp in context @@ -1627,19 +1174,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( __syncthreads(); if (warpid == 0) { - // const float out_scale = (fp8_out_scale_ptr != nullptr) ? - // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f; const float out_scale = (fp8_out_scale_ptr != nullptr) ? 
1.0f / (*fp8_out_scale_ptr) : 1.0f; _B16x4 vout[QHLOOP][VHELOOP]; // iterate across heads - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { // iterate over each v head elem (within head_size) - #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout[qh][vh] = {0}; - #pragma unroll for (int w = 0; w < NWARPS; w++) { vout[qh][vh] = addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); @@ -1687,13 +1229,6 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const int seq_idx = blockIdx.y; const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); -#if 0 //disable this as mfma16 kernel does not support this optimization yet - if (num_partitions == 1) { - // if num_partitions==1, main kernel will write to out directly, no work in - // reduction kernel - return; - } -#endif constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -1858,8 +1393,6 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); - // const float out_scale = (fp8_out_scale_ptr != nullptr) ? - // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f; const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; acc *= inv_global_exp_sum; @@ -1905,7 +1438,7 @@ template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -1956,8 +1489,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ k_scale, v_scale, fp8_out_scale_ptr); -#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ @@ -2014,6 +1547,9 @@ void paged_attention_custom_launcher( OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + //partition size is fixed at 256 since both mfma4 and mfma16 kernels support it + //mfma4 kernel also supports partition size 512 constexpr int PARTITION_SIZE = 256; const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); @@ -2021,74 +1557,60 @@ void paged_attention_custom_launcher( assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); - constexpr int NTHR = 256; //PARTITION_SIZE; + constexpr int NTHR = 256; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + //mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 switch (gqa_ratio) { case 1: - LAUNCH_CUSTOM_ATTENTION(1); - //LAUNCH_CUSTOM_ATTENTION_MFMA16(1); + LAUNCH_CUSTOM_ATTENTION_MFMA4(1); break; case 2: - LAUNCH_CUSTOM_ATTENTION(2); - //LAUNCH_CUSTOM_ATTENTION_MFMA16(2); + LAUNCH_CUSTOM_ATTENTION_MFMA4(2); break; case 3: - LAUNCH_CUSTOM_ATTENTION(3); - //LAUNCH_CUSTOM_ATTENTION_MFMA16(3); + LAUNCH_CUSTOM_ATTENTION_MFMA4(3); break; case 4: - LAUNCH_CUSTOM_ATTENTION(4); - //LAUNCH_CUSTOM_ATTENTION_MFMA16(4); + LAUNCH_CUSTOM_ATTENTION_MFMA4(4); break; case 
5: - //LAUNCH_CUSTOM_ATTENTION(5); LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; case 6: - //LAUNCH_CUSTOM_ATTENTION(6); LAUNCH_CUSTOM_ATTENTION_MFMA16(6); break; case 7: - //LAUNCH_CUSTOM_ATTENTION(7); LAUNCH_CUSTOM_ATTENTION_MFMA16(7); break; case 8: - //LAUNCH_CUSTOM_ATTENTION(8); LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break; case 9: - //LAUNCH_CUSTOM_ATTENTION(9); LAUNCH_CUSTOM_ATTENTION_MFMA16(9); break; case 10: - //LAUNCH_CUSTOM_ATTENTION(10); LAUNCH_CUSTOM_ATTENTION_MFMA16(10); break; case 11: - //LAUNCH_CUSTOM_ATTENTION(11); LAUNCH_CUSTOM_ATTENTION_MFMA16(11); break; case 12: - //LAUNCH_CUSTOM_ATTENTION(12); LAUNCH_CUSTOM_ATTENTION_MFMA16(12); break; case 13: - //LAUNCH_CUSTOM_ATTENTION(13); LAUNCH_CUSTOM_ATTENTION_MFMA16(13); break; case 14: - //LAUNCH_CUSTOM_ATTENTION(14); LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break; case 15: - //LAUNCH_CUSTOM_ATTENTION(15); LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break; case 16: - //LAUNCH_CUSTOM_ATTENTION(16); LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break; default: @@ -2096,19 +1618,10 @@ void paged_attention_custom_launcher( break; } - // reduction kernel is only required if max_context_len > partition size, - // otherwise main kernel writes directly to final output - // note there are cases with graphing where max_context_len is the max - // supported by graphing, not the actual max among all the sequences: in that - // case reduction kernel will still run but return immediately - - //below optimization is not yet implemented in mfma16 kernel - //if (max_context_len > PARTITION_SIZE) { dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_block(head_size); const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); - // support upto 8*64*256=128K context length -#if 1 + //reduction kernel supports upto 8 NPAR *64*256 = 128K context length switch (npar_loops) { case 1: LAUNCH_CUSTOM_REDUCTION(1); @@ -2138,8 +1651,6 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break; } -#endif - //} //if max_context_len > partition_size } #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ From e52eb1f233f1d786b7354fdbbbc3778e257179dd Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 17 Jan 2025 07:01:36 +0000 Subject: [PATCH 12/15] kernel bug fixes and code cleaning. adjusted attention kernel unit test. 
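A minimal sketch of the two test-side conventions this change adjusts (the sizes are
illustrative and nothing here launches a kernel): the per-tensor k/v scales become
single-element float32 tensors, since the kernels now read them on device through
const float* k_scale_ptr / v_scale_ptr instead of taking host floats by value, and the
custom ROCm kernel's partition size is fixed at 256, which drives the shapes of the
intermediate buffers the test allocates.

    import torch

    PARTITION_SIZE_ROCM = 256                      # fixed partition size of the custom ROCm kernel
    num_seqs, num_heads, head_size = 17, 40, 128   # example sizes only
    block_size, max_seq_len = 16, 8192             # example sizes only

    # Scales are passed as tensors so the kernel can dereference them on device.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32)

    # Partition bookkeeping for the intermediate buffers fed to the split kernel.
    assert PARTITION_SIZE_ROCM % block_size == 0
    num_partitions = (max_seq_len + PARTITION_SIZE_ROCM - 1) // PARTITION_SIZE_ROCM
    tmp_output = torch.empty(num_seqs, num_heads, num_partitions, head_size)
    exp_sums = torch.empty(num_seqs, num_heads, num_partitions, dtype=torch.float32)
    max_logits = torch.empty_like(exp_sums)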
--- csrc/rocm/attention.cu | 475 ++------------------------------ tests/kernels/test_attention.py | 47 ++-- 2 files changed, 48 insertions(+), 474 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 920a5da40458..dd45b15694ec 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -97,10 +97,6 @@ __device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) { auto dat1 = loadnt(addr_alias + 1); auto dat2 = loadnt(addr_alias + 2); auto dat3 = loadnt(addr_alias + 3); - //auto dat0 = *(addr_alias); - //auto dat1 = *(addr_alias+1); - //auto dat2 = *(addr_alias+2); - //auto dat3 = *(addr_alias+3); auto res = make_float4(dat0,dat1,dat2,dat3); return *reinterpret_cast<_B16x8*>(&res); } @@ -183,23 +179,7 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { __hip_bfloat16 b; } t16; _B16x4 ret; -#if 0 - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; -#else if constexpr (std::is_same::value) { -#if 0 - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; -#else union h2cvt { __half2 h2[2]; _B16x4 b16x4; @@ -207,7 +187,6 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { u.h2[0] = __float22half2_rn(make_float2(inp[0],inp[1])); u.h2[1] = __float22half2_rn(make_float2(inp[2],inp[3])); return u.b16x4; -#endif } else if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < 4; i++) { @@ -218,14 +197,11 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { u.f32 = inp[i]; u.u32 += 0x7fff + ((u.u32 >> 16) & 1); //RNE with no nan/inf check ret[i] = uint16_t(u.u32 >> 16); - //t16.b = __float2bfloat16(inp[i]); - //ret[i] = t16.u; } return ret; } else { static_assert(false, "unsupported 16b dtype"); } -#endif } template @@ -237,27 +213,7 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, __hip_bfloat16 b; } t1, t2, res; _B16x4 ret; -#if 0 - #pragma unroll - for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; -#else if constexpr (std::is_same::value) { -#if 0 - #pragma unroll - for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; -#else union h2cvt { _B16x4 b16x4; __half2 h2[2]; @@ -267,7 +223,6 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, s.h2[0] = u1.h2[0] + u2.h2[0]; s.h2[1] = u1.h2[1] + u2.h2[1]; return s.b16x4; -#endif } else if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < 4; i++) { @@ -279,16 +234,11 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, u2.i32 = uint32_t(inp2[i])<<16; s.f32 = u1.f32 + u2.f32; ret[i] = uint16_t(s.i32>>16); - //t1.u = inp1[i]; - //t2.u = inp2[i]; - //res.b = t1.b + t2.b; - //ret[i] = res.u; } return ret; } else { static_assert(false, "unsupported 16b dtype"); } -#endif } template @@ -416,16 +366,6 @@ __device__ __forceinline__ _B16x4 from_floatx4_trunc(const floatx4& inp) { template __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { -#if 0 - union { - floatx4 f32x4[2]; - vllm::Float8_ f32x8; - _B8x8 b8x8[2]; - } tmpf8; - tmpf8.f32x8 = vllm::fp8::vec_conversion(*reinterpret_cast(&input)); - //tmpf8.b8x8[0] = input; - //tmpf8.b8x8[1] = input; -#endif union { _B8x8 b8x8; _B8x4 b8x4[2]; @@ -435,10 +375,9 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { for (int i=0; i<2; i++) { 
ret.xy[i] = from_floatx4_rtz( to_float_fp8x4(tmp.b8x4[i]) ); } - //ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); - //ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); return ret; } + /////////////////////////////////////// // grid (num_seqs, num_partitions,num_heads/gqa_ratio) // block (partition size) @@ -464,7 +403,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale, + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr, const float* __restrict__ fp8_out_scale_ptr) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; @@ -543,31 +482,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; } -#if 0 //fetch Q into registers - - const int local_qhead_idx = lane16id % GQA_RATIO; - const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; - const int64_t seq_idx64 = static_cast(seq_idx); - const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; - - if (lane16id < GQA_RATIO) { - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - const scalar_t* q_ptr2 = q_ptr + qkhe_depth * QKHE_PER_FETCH; - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - const scalar_t* q_fetch_ptr = q_ptr2 + qkratio * CONTIGUOUS_SCALAR_ELEMS_16B; - const _B16x8* q_fetch_ptr_16B = reinterpret_cast(q_fetch_ptr); - Qlocal[qkhe_depth][qkratio] = *q_fetch_ptr_16B; - } - } - } else { - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - Qlocal[qkhe_depth][qkratio].xy[0] = {0}; - Qlocal[qkhe_depth][qkratio].xy[1] = {0}; - } - } - } -#else //fetch Q in shared + //fetch Q in shared const int local_qhead_idx = 4 * warpid + rowid; const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; const int64_t seq_idx64 = static_cast(seq_idx); @@ -600,7 +515,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } } -#endif constexpr int KX = 16 / sizeof(cache_t); const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; @@ -622,7 +536,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; - //Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); } } @@ -672,15 +585,13 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const cache_t* v_fetch_ptr = v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; - //Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = load_ntmprl_16Byte(v_fetch_ptr_16B); } } } - //__syncthreads(); //if using shared Q float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - scale2 *= k_scale; + scale2 *= *k_scale_ptr; } floatx4 dout[TLOOP]; @@ -711,15 +622,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 dout[token_depth] *= scale2; } -#if 0 //DEBUG ONLY qk * scale - for 
(int token_depth = 0; token_depth < TLOOP; token_depth++) { - auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; - auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); - auto tmp = from_floatx4(dout[token_depth]); - *qkout_write_ptr = tmp; - } -#endif - const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; if constexpr(ALIBI_ENABLED) { @@ -749,7 +651,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 qk_max = fmaxf(qk_max, __shfl_xor(qk_max,mask)); } - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { const int local_token_idx = qkout_token_idx + token_depth * 16; for (int i=0; i<4; i++) { @@ -774,18 +675,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int exp_sum_offset = NWARPS*16 + qk_max_offset; shared_mem[exp_sum_offset] = exp_sum; } - -#if 0 //DEBUG ONLY - //scalar_t* qkout_ptr = out + - // seq_idx * total_num_heads * T_PAR_SIZE + lane16id * T_PAR_SIZE; - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - //auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; - //auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); - auto tmp = from_floatx4(dout[token_depth]); - shared_tokens[warpid][token_depth][lane16id][rowid] = tmp; - //*qkout_write_ptr = tmp; - } -#endif __syncthreads(); float partition_qk_max = -FLT_MAX; @@ -829,40 +718,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 __syncthreads(); -#if 0 //DEBUG ONLY - scalar_t* qkout_ptr = out + - seq_idx * total_num_heads * T_PAR_SIZE + lane16id * T_PAR_SIZE; - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4; - auto qkout_write_ptr = reinterpret_cast<_B16x4 *>(qkout_ptr2); - //dout[token_depth] *= inv_sum_scale[warpid]; - //auto tmp = from_floatx4(dout[token_depth]); - auto tmp = shared_tokens[warpid][token_depth][lane16id][rowid]; - *qkout_write_ptr = tmp; - } -#endif -#if 0 - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; - _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); - for (int j=0; j<2; j++) { - _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; - _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); - for (int i=0; i<2; i++) { - const int offset = 4*rowid + 2*j + i; - const int offset1 = offset % 4; - const int offset2 = offset / 4; - tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], - shared_logits[vtoken_depth][offset2][lane16id][offset1], - tmp_out); - } - } - } - } -#endif _B16x4 outelems[VHELOOP]; //v layout: 16he across lanes x 16 tokens per lane @@ -883,16 +738,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int offset = rowid * VTLANELOOP * 2 + 2*vfetch_depth + i; const int offset1 = offset % 4; //4 corresponds to ROWS_PER_WARP const int offset2 = offset / 4; -#if 0 - //if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows - tmp_out = gcn_mfma16x16x16_instr(shared_logits[vtoken_depth][offset2][lane16id][offset1], - Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out); -#else + //if output format is 16 qheads across 16 lanes, 
16 head elems spread across 4 rows tmp_out = gcn_mfma16x16x16_instr(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], shared_logits[vtoken_depth][offset2][lane16id][offset1], tmp_out); -#endif } } } else { @@ -918,12 +768,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - tmp_out *= v_scale; + tmp_out *= *v_scale_ptr; } outelems[vhe_depth] = from_floatx4(tmp_out); } -#if 1 __syncthreads(); for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { @@ -962,77 +811,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } } - -#endif - -#if 0 - //if output format is 16 he across 16 lanes, 16 qheads spread across 4 rows - const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; - scalar_t* out_ptr = out + - seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; - - const int vhe_offset = warpid * 16 + lane16id; - - #pragma unroll - for (int i=0; i<4; i++) { - const int local_head_idx = 4*rowid + i; - if (local_head_idx < GQA_RATIO) { - const int out_head_idx = wg_start_head_idx + local_head_idx; - scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - const int vhead_elem = vhe_depth * NWARPS * 16 + vhe_offset; - scalar_t* out_ptr3 = out_ptr2 + vhead_elem; - bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr3); - *out_ptr_b16 = outelems[vhe_depth][i]; - } - } - } -#endif -#if 0 - //if output format is 16 qheads across 16 lanes, 16 he spread across 4 rows - if (lane16id < GQA_RATIO) { - const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; - scalar_t* out_ptr = out + - seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; - const int local_head_idx = lane16id; - const int out_head_idx = wg_start_head_idx + local_head_idx; - scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; - const int vhe_offset = warpid * 16 + rowid * 4; - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - const int vhead_elem = vhe_depth * NWARPS * 16 + vhe_offset; - scalar_t* out_ptr3 = out_ptr2 + vhead_elem; - _B16x4* out_ptr_B16x4 = reinterpret_cast<_B16x4*>(out_ptr3); - *out_ptr_B16x4 = outelems[vhe_depth]; - } - } -#endif -#if 0 //DEBUG ONLY - floatx4 partition_out[VHELOOP]; - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - partition_out[vhe_depth] = {0}; - for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - partition_out[vhe_depth] += inv_sum_scale[vtoken_depth] * vout[vhe_depth][vtoken_depth]; - } - } -#endif -#if 0 //DEBUG ONLY - if (laneid < GQA_RATIO) { - auto* exp_sums_ptr = exp_sums + seq_idx * 8 * max_num_partitions + partition_idx; - floatx4 tmp = {0}; - //for (int t=0; t(from_floatx4(tmp), shared_tokens[warpid][lane4id][lane16id][rowid]); - - float2 tmpf = *reinterpret_cast(&tmp16); - *exp_sums_ptr = laneid%2 == 0 ? 
tmpf.x : tmpf.y; - } -#endif } ///////////////////////////////////////////////////////////// // grid (num_seqs, num_partitions,num_heads/gqa_ratio) @@ -1190,7 +968,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } -#if 1 float alibi_slope[QHLOOP]; if constexpr(ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { @@ -1200,28 +977,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( : 0.f; } } -#endif -#if 0 - float alibi_slope; - const int lane16id = laneid % 16; - if (alibi_slopes != nullptr) { - alibi_slope = (lane16id < GQA_RATIO) - ? alibi_slopes[wg_start_head_idx + lane16id] - : 0.f; - //#pragma unroll - // for (int h = 0; h < QHLOOP; h++) { - // for (int i=0; i<4; i++) { - // const int qhead_idx = h * 4 + i; - // alibi_slope[qhead_idx] = (qhead_idx < GQA_RATIO) - // ? alibi_slopes[wg_start_head_idx + qhead_idx] - // : 0.f; - // } - //} - //} - } -#endif -#if 1 //fetch vcache in normal case + //fetch vcache in normal case const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); @@ -1247,8 +1004,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) -#endif -#if 1 //fetch vcache in fp8 case + + //fetch vcache in fp8 case else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block @@ -1269,24 +1026,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - //const _B8x8 Vlocalb8 = v_ptrh8be[d]; - //Vlocal[h][b * BLOCK_SIZE / 8 + d] = - // scaled_convert_b8x8(Vlocalb8, v_scale); } } } } -#endif -#if 0 //cvt kf8 to kf/bf16 up front - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = - //scaled_convert_b8x8(Klocalb8[d], k_scale); - convert_b8x8_custom(Klocalb8[d]); - } - } -#endif /*Klocal[x] = scaled_convert_b8x8(Klocalb8[x], k_scale); \*/ /*Klocal[x] = scaled_convert_b8x8_custom(Klocalb8[x], k_scale); \*/ @@ -1322,43 +1065,16 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( #undef QK_mfma float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - scale2 *= k_scale; + scale2 *= *k_scale_ptr; } #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] *= scale2; } -#if 0 - if (alibi_slopes != nullptr) { - float alibi_slope_local[GQA_RATIO]; -#define DPP_BCAST_ASM(id) asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:id " : "=v"(alibi_slope_local[id]) : "v"(alibi_slope)); - //for (int head=0; head < 16; head++) { - //DPP_BCAST_ASM(0); - if constexpr(GQA_RATIO>0) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:0 " : "=v"(alibi_slope_local[0]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>1) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:1 " : "=v"(alibi_slope_local[1]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>2) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:2 " : "=v"(alibi_slope_local[2]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>3) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:3 " : "=v"(alibi_slope_local[3]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>4) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 
row_newbcast:4 " : "=v"(alibi_slope_local[4]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>5) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:5 " : "=v"(alibi_slope_local[5]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>6) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:6 " : "=v"(alibi_slope_local[6]) : "v"(alibi_slope));} - if constexpr(GQA_RATIO>7) { asm("s_nop 0\n\tv_mov_b32_dpp %0, %1 row_newbcast:7 " : "=v"(alibi_slope_local[7]) : "v"(alibi_slope));} - //} - - const int alibi_offset = global_token_idx - context_len + 1; - #pragma unroll - for (int h = 0; h < QHLOOP; h++) { - #pragma unroll - for (int i = 0; i < 4; i++) { - dout[h][i] += alibi_slope_local[4*h+i] * alibi_offset; - } - } - } -#endif // transpose dout so that 4 token ids are in each lane, and 4 heads are across // 4 lanes #pragma unroll for (int h = 0; h < QHLOOP; h++) { -#if 1 floatx4 tmp = {0}; #pragma unroll for (int i = 0; i < 4; i++) { @@ -1368,38 +1084,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } dout[h] = tmp; -#endif -#if 0 - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[1,0,3,2] " : "=v"(dout[h][1]) : "v"(dout[h][1]) ); - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[2,3,0,1] " : "=v"(dout[h][2]) : "v"(dout[h][2]) ); - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[3,2,1,0] " : "=v"(dout[h][3]) : "v"(dout[h][3]) ); - - bool mask = (lane4id % 2) == 1; - float tmp = dout[h][1]; - dout[h][1] = mask ? dout[h][0] : dout[h][1]; - dout[h][0] = mask ? tmp : dout[h][0]; - tmp = dout[h][3]; - dout[h][3] = mask ? dout[h][2] : dout[h][3]; - dout[h][2] = mask ? tmp : dout[h][2]; - - mask = (lane4id>>1) == 1; - tmp = dout[h][2]; - dout[h][2] = mask ? dout[h][0] : dout[h][2]; - dout[h][0] = mask ? tmp : dout[h][0]; - tmp = dout[h][3]; - dout[h][3] = mask ? dout[h][1] : dout[h][3]; - dout[h][1] = mask ? 
tmp : dout[h][1]; - - - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[1,0,3,2] " : "=v"(dout[h][1]) : "v"(dout[h][1]) ); - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[2,3,0,1] " : "=v"(dout[h][2]) : "v"(dout[h][2]) ); - asm("s_nop 0\n\t v_mov_b32_dpp %0, %1 quad_perm:[3,2,1,0] " : "=v"(dout[h][3]) : "v"(dout[h][3]) ); - -#endif } const int lane4_token_idx = 4 * (global_token_idx >> 2); -#if 1 //alibi after transpose + //alibi after transpose if constexpr(ALIBI_ENABLED) { const int alibi_offset = lane4_token_idx - context_len + 1; for (int h = 0; h < QHLOOP; h++) { @@ -1408,7 +1096,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } -#endif const int bpermute_mask = 4*(16*((laneid>>2)%4) + lane4id); @@ -1428,9 +1115,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); - //asm("v_nop\n v_nop\n ds_bpermute_b32 %0, %1, %2 \n s_waitcnt lgkmcnt(0)" : "=v"(qk_max[h]) : "v"(bpermute_mask), "v"(qk_max[h]) ); - - //qk_max[h] = __builtin_amdgcn_ds_bpermute(bpermute_mask, qk_max[h]); auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&qk_max[h])); qk_max[h] = *reinterpret_cast(&tmp); asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); @@ -1456,8 +1140,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); - //asm("v_nop\n v_nop\n ds_bpermute_b32 %0, %1, %2 \n s_waitcnt lgkmcnt(0)" : "=v"(exp_sum[h]) : "v"(bpermute_mask), "v"(exp_sum[h]) ); - //exp_sum[h] = __builtin_amdgcn_ds_bpermute(bpermute_mask, exp_sum[h]); auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&exp_sum[h])); exp_sum[h] = *reinterpret_cast(&tmp); asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); @@ -1525,72 +1207,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } else { // warp in context -#if 0 //fetch v cache - const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { - const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); - // iterate over each v block - #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride - const int64_t vphysical_block_number = - static_cast(vphysical_blocks[b]); - const _B16x8* v_ptrh8b = - v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; - // iterate over each head elem (within head_size) - #pragma unroll - for (int h = 0; h < VHELOOP; h++) { - const int head_size_elem = h * WARP_SIZE + laneid; - const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; - // iterate over all velems within block - #pragma unroll - for (int d = 0; d < BLOCK_SIZE / 8; d++) { - Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - } - } - } - } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) - - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - const _B8x8* v_ptrh8 = 
reinterpret_cast(v_ptr); - // iterate over each v block - #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride - const int64_t vphysical_block_number = - static_cast(vphysical_blocks[b]); - const _B8x8* v_ptrh8b = - v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; - // iterate over each head elem (within head_size) - #pragma unroll - for (int h = 0; h < VHELOOP; h++) { - const int head_size_elem = h * WARP_SIZE + laneid; - const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; - // iterate over all velems within block - #pragma unroll - for (int d = 0; d < BLOCK_SIZE / 8; d++) { - Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - //const _B8x8 Vlocalb8 = v_ptrh8be[d]; - //Vlocal[h][b * BLOCK_SIZE / 8 + d] = - // scaled_convert_b8x8(Vlocalb8, v_scale); - } - } - } - } -#endif -#if 0 //cvt vf8 ->f16/bf16 up front - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - for (int vh = 0; vh < VHELOOP; vh++) { - for (int b=0; b < VTLOOP; b++) { - //Vlocal[vh][b] = scaled_convert_b8x8(Vlocalb8[vh][b], v_scale); - Vlocal[vh][b] = convert_b8x8_custom(Vlocalb8[vh][b]); - } - } - } -#endif - /*Vlocal[vh][x] = scaled_convert_b8x8(Vlocalb8[vh][x], v_scale);\*/ /*Vlocal[vh][x] = scaled_convert_b8x8_custom(Vlocalb8[vh][x], v_scale);\*/ #define SV_mfma(x) \ @@ -1603,18 +1219,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( acc[qh] = gcn_mfma_instr(logits[qh], Vlocal[vh][x].xy[1], \ acc[qh]); \ } -#if 0 - floatx4 acc[QHLOOP][VHELOOP]; - for (int qh = 0; qh < QHLOOP; qh++) { - for (int vh = 0; vh < VHELOOP; vh++) { - acc[qh][vh] = {0}; - } - } -#endif - //#pragma unroll - // for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) - //#pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { floatx4 acc[QHLOOP]; for (int qh = 0; qh < QHLOOP; qh++) { @@ -1629,33 +1235,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( SV_mfma(5); SV_mfma(6); SV_mfma(7); -#if 0 - SV_mfma(8); - SV_mfma(9); - SV_mfma(10); - SV_mfma(11); - SV_mfma(12); - SV_mfma(13); - SV_mfma(14); - SV_mfma(15); -#endif for (int qh = 0; qh < QHLOOP; qh++) { if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - acc[qh] *= v_scale; + acc[qh] *= *v_scale_ptr; } vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); } } - //} - -#if 0 - for (int qh = 0; qh < QHLOOP; qh++) { - for (int vh = 0; vh < VHELOOP; vh++) { - vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh][vh]); - } - } -#endif - #undef SV_mfma } // warp in context @@ -1722,13 +1308,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const int seq_idx = blockIdx.y; const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); -#if 0 //disable this as mfma16 kernel does not support this optimization yet - if (num_partitions == 1) { - // if num_partitions==1, main kernel will write to out directly, no work in - // reduction kernel - return; - } -#endif + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -1893,8 +1473,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); - // const float out_scale = (fp8_out_scale_ptr != nullptr) ? 
- // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f; + const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; acc *= inv_global_exp_sum; @@ -1931,7 +1510,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale, + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr, const float* __restrict__ fp8_out_scale_ptr) { UNREACHABLE_CODE } @@ -1989,7 +1568,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ - k_scale, v_scale, fp8_out_scale_ptr); + k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr); #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ paged_attention_ll4mi_QKV_kernel partition_size } @@ -2223,6 +1784,7 @@ void paged_attention_custom_launcher( CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \ } #endif + #define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ switch (block_size) { \ case 16: \ @@ -2248,6 +1810,7 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported head size: ", head_size); \ break; \ } + void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 13c92dadcd4e..ebe41ba2d40f 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -19,31 +19,36 @@ # This will change depending on the compute capability. # - 512 as a buffer MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 + # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. 
+NUM_BLOCKS = 128*1024+4321 # Arbitrary values for testing NUM_BLOCKS = 4321 # Arbitrary values for testing PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} DTYPES = [ torch.half, torch.bfloat16, torch.float -] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] -NUM_GEN_SEQS = [7] # Arbitrary values for testing +] if not current_platform.is_rocm() else [torch.half,torch.bfloat16] +NUM_GEN_SEQS = [17] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +# NUM_HEADS = [(64, 8), (26,2), (16,1), (32,32)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 120, 256] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] +if current_platform.is_rocm(): + HEAD_SIZES = [128, 64] BLOCK_SIZES = [16, 32] -USE_ALIBI = [False, True] +USE_ALIBI = [False] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] - def ref_masked_attention( query: torch.Tensor, key: torch.Tensor, @@ -117,7 +122,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( "version", - ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) + ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -182,7 +187,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32) # Call the paged attention kernel. 
output = torch.empty_like(query) @@ -213,7 +218,7 @@ def test_paged_attention( elif version in ("v2", "rocm"): if current_platform.is_rocm(): - PARTITION_SIZE = 1024 if version == "v2" else 512 + PARTITION_SIZE = 256 if version == "v2" else PARTITION_SIZE_ROCM num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -248,13 +253,13 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, + '''opcheck(torch.ops._C.paged_attention_v2, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + and block_size == BLOCK_SIZES[0]))''' else: ops.paged_attention_rocm( @@ -275,15 +280,17 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, + None, + PARTITION_SIZE, ) - opcheck(torch.ops._rocm_C.paged_attention, + '''opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale), + kv_cache_dtype, k_scale, v_scale, None, PARTITION_SIZE), cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + and block_size == BLOCK_SIZES[0]))''' else: raise AssertionError(f"Unknown version: {version}") @@ -298,14 +305,14 @@ def test_paged_attention( dtype=dtype, device=device) ops.convert_fp8(dequantized_key_cache, key_cache) - key_cache = dequantized_key_cache + key_cache = k_scale * dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) ops.convert_fp8(dequantized_value_cache, value_cache) - value_cache = dequantized_value_cache + value_cache = v_scale * dequantized_value_cache ref_output = torch.empty_like(query) ref_single_query_cached_kv_attention( @@ -328,9 +335,13 @@ def test_paged_attention( # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # so we use a relaxed tolerance for the test. 
- atol, rtol = 1e-3, 1e-5 + atol, rtol = 1e-4, 1e-5 if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 + atol, rtol = 5e-4, 1e-5 + #bf16 rounding is handled via truncation in new kernel, this increses error + if dtype == torch.bfloat16: + atol = 1e-3 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) @@ -433,4 +444,4 @@ def test_multi_query_kv_attention( ) atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) \ No newline at end of file From 7267062d004dabb2333e0a4a704e7ad090051f8b Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 20 Jan 2025 06:43:54 +0000 Subject: [PATCH 13/15] fix unit test for rocm custom attention kernel --- tests/kernels/test_attention.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 1667b458170e..18443daf44c8 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -42,7 +42,7 @@ HEAD_SIZES = [128, 64] BLOCK_SIZES = [16, 32] -USE_ALIBI = [False] +USE_ALIBI = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ @@ -253,13 +253,13 @@ def test_paged_attention( v_scale, ) - '''opcheck(torch.ops._C.paged_attention_v2, + opcheck(torch.ops._C.paged_attention_v2, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0]))''' + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -284,13 +284,13 @@ def test_paged_attention( PARTITION_SIZE, ) - '''opcheck(torch.ops._rocm_C.paged_attention, + opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, None, PARTITION_SIZE), cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0]))''' + and block_size == BLOCK_SIZES[0])) else: raise AssertionError(f"Unknown version: {version}") @@ -335,12 +335,8 @@ def test_paged_attention( # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, # so we use a relaxed tolerance for the test. 
- atol, rtol = 1e-4, 1e-5 if kv_cache_dtype == "fp8": - atol, rtol = 5e-4, 1e-5 - #bf16 rounding is handled via truncation in new kernel, this increses error - if dtype == torch.bfloat16: - atol = 1e-3 + atol, rtol = 1e-2, 1e-5 torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) From b8e66a904abc22c805b3870a386ce497559cdba4 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 22 Jan 2025 03:54:53 +0000 Subject: [PATCH 14/15] fix benchmark paged attention --- .../kernels/benchmark_paged_attention.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 5c4643dfe9b4..4ce6955dee60 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -9,8 +9,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, create_kv_caches_with_random) -NUM_BLOCKS = 1024 * 1024 -PARTITION_SIZE = 256 +NUM_BLOCKS = 128 * 1024 +PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 @torch.inference_mode() @@ -78,9 +79,12 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": - if current_platform.is_rocm() and not args.custom_paged_attn: + if current_platform.is_rocm(): global PARTITION_SIZE - PARTITION_SIZE = 1024 + if not args.custom_paged_attn: + PARTITION_SIZE = 1024 + else: + PARTITION_SIZE = PARTITION_SIZE_ROCM num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), @@ -101,7 +105,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float, device=device) for _ in range(num_iters): if version == "v1": @@ -162,7 +166,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: k_scale, v_scale, None, - PARTITION_SIZE + PARTITION_SIZE, ) else: raise ValueError(f"Invalid version: {version}") @@ -238,4 +242,4 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: seed=args.seed, do_profile=args.profile, kv_cache_dtype=args.kv_cache_dtype, - ) + ) \ No newline at end of file From 0f6ff75d424a2ac725f3cbcbff9b31033cdc3959 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 23 Jan 2025 07:17:32 +0000 Subject: [PATCH 15/15] [Bugfix]: fix v1/v2 paged attention kernel unit test. 
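The fix threads an extra trailing num_threads argument through the opcheck calls for
paged_attention_v1/v2, selected from the platform. A minimal sketch of that selection
(the helper name is hypothetical; the real test computes the value inline from
current_platform.is_rocm() and is_navi()):

    def pick_num_threads(is_rocm: bool, is_navi_gpu: bool) -> int:
        # 1024 on ROCm non-Navi GPUs, 128 otherwise, matching the value the
        # updated test appends to the v1/v2 opcheck argument tuples.
        return 1024 if is_rocm and not is_navi_gpu else 128

    assert pick_num_threads(True, False) == 1024
    assert pick_num_threads(True, True) == 128
    assert pick_num_threads(False, False) == 128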
--- tests/kernels/test_attention.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 18443daf44c8..5bbf7e3cfff4 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -7,7 +7,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import get_max_shared_memory_bytes +from vllm.utils import get_max_shared_memory_bytes, is_navi from .allclose_default import get_default_atol, get_default_rtol @@ -122,7 +122,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( "version", - ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"]) + ["v1", "v2"] if not current_platform.is_rocm() else ["v1","v2","rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -189,6 +189,10 @@ def test_paged_attention( # Using default kv_scale k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32) + # additional argument for v1/v2 pa kernel + num_threads = 1024 if current_platform.is_rocm() \ + and not is_navi() else 128 + # Call the paged attention kernel. output = torch.empty_like(query) if version == "v1": @@ -212,7 +216,7 @@ def test_paged_attention( opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) @@ -257,7 +261,7 @@ def test_paged_attention( (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]))
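The launch geometry that the custom ROCm path ends up with can be recomputed in plain
Python. This is an illustrative sketch only (the helper names are made up, and a 64-lane
wavefront is assumed, matching the 8*64*256 = 128K context-length comment in the
launcher); it mirrors the launcher's dispatch between the mfma4 and mfma16 kernels and
the reduction kernel's partition math, it is not the launcher itself.

    PARTITION_SIZE = 256   # fixed, supported by both the mfma4 and mfma16 kernels
    WARP_SIZE = 64         # assumed wavefront size
    MAX_NPAR_LOOPS = 8     # reduction kernel supports up to 8*64*256 = 128K tokens

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    def launch_plan(num_seqs, num_heads, num_kv_heads, head_size, max_context_len):
        gqa_ratio = num_heads // num_kv_heads
        assert 1 <= gqa_ratio <= 16, "unsupported gqa ratio"
        max_num_partitions = ceil_div(max_context_len, PARTITION_SIZE)
        npar_loops = ceil_div(max_num_partitions, WARP_SIZE)
        assert npar_loops <= MAX_NPAR_LOOPS, "context length beyond reduction kernel limit"
        return {
            "kernel": "mfma4" if gqa_ratio <= 4 else "mfma16",  # mfma4 is faster for gqa_ratio <= 4
            "attn_grid": (num_seqs, max_num_partitions, num_kv_heads),
            "attn_block": 256,
            "reduce_grid": (num_heads, num_seqs),
            "reduce_block": head_size,
            "npar_loops": npar_loops,
        }

    print(launch_plan(num_seqs=17, num_heads=64, num_kv_heads=8,
                      head_size=128, max_context_len=8192))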