From 42d4c5c1faf1d95a6acf8822653b3deb75ba88a6 Mon Sep 17 00:00:00 2001 From: shsanyal Date: Fri, 17 Jan 2025 18:25:36 +0000 Subject: [PATCH 1/8] integrate new cpa kernel, update tests and benchmark --- .../kernels/benchmark_paged_attention.py | 18 +- csrc/rocm/attention.cu | 1109 ++++++++++++----- tests/kernels/test_attention.py | 14 +- 3 files changed, 811 insertions(+), 330 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 483584dd804e..66d257497f91 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -9,8 +9,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, create_kv_caches_with_random) -NUM_BLOCKS = 1024 * 1024 +NUM_BLOCKS = 128 * 1024 PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 @torch.inference_mode() @@ -78,9 +79,12 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": - if current_platform.is_rocm() and not args.custom_paged_attn: + if current_platform.is_rocm(): global PARTITION_SIZE - PARTITION_SIZE = 1024 + if not args.custom_paged_attn: + PARTITION_SIZE = 1024 + else: + PARTITION_SIZE = PARTITION_SIZE_ROCM num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), @@ -101,7 +105,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float, device=device) for _ in range(num_iters): if version == "v1": @@ -161,6 +165,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: kv_cache_dtype, k_scale, v_scale, + None, + PARTITION_SIZE, ) else: raise ValueError(f"Invalid version: {version}") @@ -174,13 +180,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: # Warmup. print("Warming up...") run_benchmark = run_cuda_benchmark - run_benchmark(num_iters=3, profile=False) + run_benchmark(num_iters=500, profile=False) # Benchmark. 
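    # Sizing sketch (hypothetical numbers, for illustration only): the custom ROCm
    # kernel now runs with a 256-token partition (PARTITION_SIZE_ROCM), while the
    # stock v2 path on ROCm keeps 1024; the temporary buffers above are sized from
    # the resulting partition count, e.g.
    #   max_seq_len = 4096                               # assumed sequence length
    #   num_partitions = (max_seq_len + 256 - 1) // 256  # -> 16 partitions
    #   tmp_output: (num_seqs, num_query_heads, num_partitions, head_size)
    #   exp_sums / max_logits: (num_seqs, num_query_heads, num_partitions)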
if do_profile: latency = run_benchmark(num_iters=1, profile=True) else: - latency = run_benchmark(num_iters=1000, profile=False) + latency = run_benchmark(num_iters=10000, profile=False) print(f"Kernel running time: {latency * 1000000:.3f} us") diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index ab8edd6d0f57..4e4f5d7fb41e 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -51,6 +51,9 @@ using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; typedef float16x4 _Half4; +using float16x2 = + __attribute__((__vector_size__(2 * sizeof(_Float16)))) _Float16; +typedef float16x2 _Half2; typedef struct _Half8 { _Half4 xy[2]; } _Half8; @@ -63,22 +66,33 @@ typedef struct _B16x8 { } _B16x8; using _B8x8 = uint2; +using _B8x4 = int32_t; //used in builtins using bit8_t = uint8_t; -////// Non temporal load stores /////// +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; +////// Non temporal loads /////// template -__device__ __forceinline__ T load(T* addr) { - return addr[0]; +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); } -template -__device__ __forceinline__ void store(T value, T* addr) { - addr[0] = value; +__device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + auto res = make_float4(dat0,dat1,dat2,dat3); + return *reinterpret_cast<_B16x8*>(&res); } +/////////////////////////////////// + template -__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, +__device__ __forceinline__ floatx4 gcn_mfma4x4x4_instr(const _B16x4& inpA, const _B16x4& inpB, const floatx4& inpC) { if constexpr (std::is_same::value) { @@ -92,6 +106,21 @@ __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, } } +template +__device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + template __device__ __forceinline__ float to_float(const T& inp) { if constexpr (std::is_same::value) { @@ -140,17 +169,22 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { } t16; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; + union h2cvt { + __half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0],inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2],inp[3])); + return u.b16x4; } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { - t16.b = __float2bfloat16(inp[i]); - ret[i] = t16.u; + union fcvt { + uint32_t u32; + float f32; + } u; + u.f32 = inp[i]; + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); //BF16 RNE with no nan/inf check + ret[i] = uint16_t(u.u32 >> 16); } return ret; } else { @@ -168,21 +202,25 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } t1, t2, res; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for 
(int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; + union h2cvt { + _B16x4 b16x4; + __half2 h2[2]; + } u1,u2,s; + u1.b16x4 = inp1; + u2.b16x4 = inp2; + s.h2[0] = u1.h2[0] + u2.h2[0]; + s.h2[1] = u1.h2[1] + u2.h2[1]; + return s.b16x4; } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.b = t1.b + t2.b; - ret[i] = res.u; + union fcvt { + float f32; + uint32_t i32; + } u1,u2,s; + u1.i32 = uint32_t(inp1[i])<<16; + u2.i32 = uint32_t(inp2[i])<<16; + s.f32 = u1.f32 + u2.f32; + ret[i] = uint16_t(s.i32>>16); } return ret; } else { @@ -210,15 +248,525 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, } } +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8_custom(const _B8x8 input, + const float scale) { + union { + floatx4 f32x4[2]; + vllm::Float8_ f32x8; + } tmpf8; + tmpf8.f32x8 = vllm::fp8::vec_conversion(*reinterpret_cast(&input)); + + tmpf8.f32x4[0] *= scale; + tmpf8.f32x4[1] *= scale; + + _B16x8 ret; + ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); + ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); + return ret; +} + +__device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { +#if defined(__gfx90a__) + float4 f32x4 = vllm::fp8::vec_conversion(*reinterpret_cast(&inp)); + return *reinterpret_cast(&f32x4); +#else //MI3xx+ optimized builtins + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); + floatx4 ret; + ret[0] = f0[0]; + ret[1] = f0[1]; + ret[2] = f1[0]; + ret[3] = f1[1]; + return ret; +#endif +} + +template +__device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { + _B16x4 ret; + if constexpr (std::is_same::value) { + union h2cvt { + _Half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(inp[0],inp[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(inp[2],inp[3]); + return u.b16x4; + } else if constexpr (std::is_same::value) { + for (int i = 0; i < 4; i++) { + union fcvt { + uint32_t i32; + float f32; + } u; + u.f32 = inp[i]; + ret[i] = uint16_t(u.i32 >> 16); + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { + union { + _B8x8 b8x8; + _B8x4 b8x4[2]; + } tmp; + tmp.b8x8 = input; + _B16x8 ret; + for (int i=0; i<2; i++) { + ret.xy[i] = from_floatx4_rtz( to_float_fp8x4(tmp.b8x4[i]) ); + } + return ret; +} + /////////////////////////////////////// +// grid (num_seqs, num_partitions,num_kv_heads) +// block (256) +template +__global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* 
__restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr, + const float* __restrict__ fp8_out_scale_ptr) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; //token partition size set to 256 + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; //partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO,4); + + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + //shared_logits is used for multiple purposes + __shared__ _B16x4 shared_logits[NWARPS][4][16][4]; + + //for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes HeadElements in each lane, 4x16B HeadElements across 4 rows of warp + constexpr int ROWS_PER_WARP = WARP_SIZE / 16; //rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); //8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; //each fetch across a warp fetches these many elements + constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); //1 for 16bit types, 2 for 8bit types + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; //4xQKHE_16B across warp + + _B16x8 Qlocal[QKHELOOP][QK_SIZE_RATIO]; //note that 16 contiguous elements of Q should be fetched per lane for 8 bit cache types : QK_SIZE_RATIO changes for this + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + + constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; //sub partition of tokens per warp for qk calculation + constexpr int TLOOP = TOKENS_PER_WARP / 16; //each mfma16x16x16 instruction processes 16 tokens + + _B16x8 Klocal[TLOOP][QKHELOOP]; //this could be B8x16 too + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + //for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps + //each mfma takes QH16xT16x16HE across warp + //repeat mfmas across QKHELOOP dimension + //output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens across 4 rowsx4 tokens per lane + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + //fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? 
kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + //fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 4 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ( (local_qhead_idx < GQA_RATIO) && (qhead_element(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const int offset1 = lane16id/4; //16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; + shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; + } else { + for (int i=0; i<2; i++) { + const int head_elem = lane16id * 2 + i; //element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem /4/4; + shared_logits[offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + } + } + } + __syncthreads(); + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i=0; i<2; i++) { + Qlocal[qkhe_depth][qkratio].xy[i] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][2*qkratio + i]; + } + } + } + + //set to true to enable non temporal kv loads: has some benefit in very high batch size cases + constexpr bool NT_KV_LOAD = false; + + constexpr int KX = 16 / sizeof(cache_t); //vLLM defines x as 16 Bytes of kv cache elements + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + //fetch K values + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); + if constexpr(NT_KV_LOAD) { + Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); + } else { + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + } + + float alibi_slope; + if constexpr(ALIBI_ENABLED) { + const int alibi_head_idx = wg_start_head_idx + lane16id; + alibi_slope = (lane16id < GQA_RATIO) ? 
alibi_slopes[alibi_head_idx] : 0.f; + } + + constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP; //16 tokens per lane + constexpr int VBLOCKS_PER_LANE = 1; //assumes block size >=16, each lane can correspond to 1 block only + constexpr int VTLOOP = NWARPS; //corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP(VTOKENS_PER_LANE , CONTIGUOUS_KV_ELEMS_16B_LOAD); //optimized for 16B fetches; assumes minimum block size is 16 + constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + //fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) { + const int vlocal_token_idx = vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; //this could be B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + ((rowid * VTOKENS_PER_LANE)%BLOCK_SIZE); + + //v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = 0; + const int64_t vblock_number = static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); + if constexpr(NT_KV_LOAD) { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = load_ntmprl_16Byte(v_fetch_ptr_16B); + } else { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + } + + //calculate post qk mfma scale + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + //multiply by k_scale if fp8 kv cache + scale2 *= *k_scale_ptr; + } + + floatx4 dout[TLOOP]; + //qk mfma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i=0; i<2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr(Klocal[token_depth][qkhe_depth].xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); + } + } + } else { //kv cache dtype fp8 + auto Ktmp = Klocal[token_depth][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i=0; i<2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr(Klocaltmp.xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); + } + } + } + } + dout[token_depth] *= scale2; + } + + const int qkout_token_idx = 
partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + + //apply alibi + if constexpr(ALIBI_ENABLED) { + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + const int alibi_offset = local_token_idx - context_len + 1; + for (int i=0; i<4; i++) { + dout[token_depth][i] += alibi_slope * (alibi_offset + i); + } + } + } + + //calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i=0; i<4; i++) { + const float tmp = (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) { + qk_max = fmaxf(qk_max, __shfl_xor(qk_max,mask)); + } + + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i=0; i<4; i++) { + const float tmp = (local_token_idx + i < context_len) ? __expf(dout[token_depth][i] - qk_max) : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) { + exp_sum += __shfl_xor(exp_sum,mask); + } + + __syncthreads(); //sync before writing to shared mem + + float* shared_mem = reinterpret_cast(shared_logits); + if (laneid < 16) { + const int qk_max_offset = warpid*16 + lane16id; + shared_mem[qk_max_offset] = qk_max; + const int exp_sum_offset = NWARPS*16 + qk_max_offset; + shared_mem[exp_sum_offset] = exp_sum; + } + + __syncthreads(); + + //calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + for (int w=0; w(dout[token_depth]); + } + //write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + (wg_start_head_idx + qhead_idx) * max_num_partitions + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + _B16x4 outelems[VHELOOP]; + //Softmax V mfma + //v layout: 16he across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx4 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int i=0; i<2; i++) { + //generalize this for 8 bit dtypes: each lane needs 2*vfetch_depth + 2 _B16x4 K/token dimension elems; each row is multiplied by a factor of 4 + //layout: lane in depth dimension | row across -> + //0 4 8 12 + //1 5 9 13 + //2 6 10 14 + //3 7 11 15 + const int offset = rowid * VTLANELOOP * 2 + 2*vfetch_depth + i; + const int offset1 = offset % 4; //4 corresponds to ROWS_PER_WARP + const int offset2 = offset / 4; + //output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + } else { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); 
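          // Vtmp8x16 now holds 16 fp8 V elements for this lane; each 8-byte half is
          // dequantized to scalar_t just below, right before it feeds the mfma.
          // The per-tensor v_scale is deliberately not applied during this
          // dequantization: since the matmul is linear, (v_scale * V) x P equals
          // v_scale * (V x P), so the scale is folded into tmp_out once after the
          // token loop (the v_scale_ptr multiply further down) instead of once per
          // dequantized element.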
+ for (int j=0; j<2; j++) { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i=0; i<2; i++) { + const int offset = 4*rowid + 2*j + i; + const int offset1 = offset % 4; + const int offset2 = offset / 4; + //output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr(Vlocaltmp.xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + } + + } + } + //apply post Softmax V mfma v_scale + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + tmp_out *= *v_scale_ptr; + } + outelems[vhe_depth] = from_floatx4(tmp_out); + } + + __syncthreads(); + + //store Softmax-V mfma output to shared mem + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + //lane16 id head dimension; rowid head element dimension + shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; + } + + __syncthreads(); + + //write to tmp_out with coalesced writes + if (warpid == 0) { + _B16x8 vout[GQA_RATIO4]; + //each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + const int offset1 = (head_elem_idx / 16)%4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4)%4; + for (int i=0; i<2; i++) { + vout[h].xy[i] = shared_logits[offset1][offset2][local_head_idx][offset3+i]; + } + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + + seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} -// grid (num_seqs, num_partitions,num_heads/gqa_ratio) -// block (partition size) +///////////////////////////////////////////////////////////// +// grid (num_seqs, num_partitions, num_kv_heads) +// block (256 : partition size) template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -255,8 +803,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( return; } constexpr int QHLOOP = - DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, - // total qheads =8, so qhloop is 2 + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads constexpr int GQA_RATIO4 = 4 * QHLOOP; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; @@ -265,16 +812,21 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; _B8x8 Klocalb8[KHELOOP]; + // v head_size dimension is distributed across warp constexpr int VHELOOP = HEAD_SIZE / - WARP_SIZE; // v head_size dimension is distributed across lanes + WARP_SIZE; constexpr int VTLOOP = 8; // 16 separate 
4xtokens across warp -> 16/2 // 8xtokens + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; _B16x8 Vlocal[VHELOOP][VTLOOP]; _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; - #pragma unroll + + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + for (int h = 0; h < QHLOOP; h++) { dout[h] = {0}; qk_max[h] = -FLT_MAX; @@ -302,28 +854,22 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int local_token_idx = threadIdx.x; const int global_token_idx = partition_start_token_idx + local_token_idx; + // fetch block number for k const int block_idx = (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block; - // fetch block number for q and k // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride const int64_t physical_block_number = static_cast(block_table[block_idx]); // fetch vphysical block numbers up front - constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; - const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; - if constexpr (GQA_RATIO < 12) { - #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { + for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; const int vblock_idx_ctx = (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; vphysical_blocks[b] = block_table[vblock_idx_ctx]; - } } // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems @@ -331,7 +877,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); const int qhead_elemh8 = laneid / 4; - #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { const int qhead_idx = h * 4 + lane4id; Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; @@ -353,14 +899,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // is already cast as _H8 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); - #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; } } else { + //vllm defines X as 16 Bytes of elements of cache_t constexpr int X = 16 / sizeof(cache_t); const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; - #pragma unroll for (int d = 0; d < KHELOOP; d++) { const int head_elem = d * 8; const int offset1 = head_elem / X; @@ -371,8 +916,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } float alibi_slope[QHLOOP]; - if (alibi_slopes != nullptr) { - #pragma unroll + if constexpr(ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { const int qhead_idx = h * 4 + lane4id; alibi_slope[h] = (qhead_idx < GQA_RATIO) @@ -381,22 +925,11 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } - // fetch vphysical block numbers up front - if constexpr (GQA_RATIO >= 12) { - #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - const int vblock_idx = warp_start_block_idx + b; - const int vblock_idx_ctx = - (vblock_idx <= last_ctx_block) ? 
vblock_idx : last_ctx_block; - vphysical_blocks[b] = block_table[vblock_idx_ctx]; - } - } - const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + //fetch vcache in kv cache auto case if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -405,21 +938,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B16x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } - } else { + } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + //fetch vcache in fp8 case + else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -428,165 +960,133 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B8x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { - // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - const _B8x8 Vlocalb8 = v_ptrh8be[d]; - Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, *v_scale_ptr); + Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } } - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], *k_scale_ptr); +#define QK_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ + } \ + for (int h = 0; h < QHLOOP; h++) { \ + dout[h] = gcn_mfma4x4x4_instr(Qlocal[h].xy[0], \ + Klocal[x].xy[0], dout[h]);\ + dout[h] = gcn_mfma4x4x4_instr(Qlocal[h].xy[1], \ + Klocal[x].xy[1], dout[h]);\ + } + //QK mfma + QK_mfma(0); + QK_mfma(1); + QK_mfma(2); + QK_mfma(3); + QK_mfma(4); + QK_mfma(5); + QK_mfma(6); + QK_mfma(7); + //below only needed for head size 128 + if constexpr (KHELOOP > 8) { + QK_mfma(8); + QK_mfma(9); + QK_mfma(10); + QK_mfma(11); + QK_mfma(12); + QK_mfma(13); + QK_mfma(14); + QK_mfma(15); } +#undef QK_mfma + + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + scale2 *= *k_scale_ptr; } - #pragma unroll for (int h = 0; h < QHLOOP; h++) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[0].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[0].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[1].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[1].xy[1], dout[h]); - dout[h] = 
gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[2].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[2].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[3].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[3].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[4].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[4].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[5].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[5].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[6].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[6].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[7].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[7].xy[1], dout[h]); - if constexpr (KHELOOP > 8) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[8].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[8].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[9].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[9].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[10].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[10].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[11].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[11].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[12].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[12].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[13].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[13].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[14].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[14].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[15].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[15].xy[1], dout[h]); - } // KHELOOP>8 - dout[h] *= scale; + dout[h] *= scale2; } + // transpose dout so that 4 token ids are in each lane, and 4 heads are across // 4 lanes - #pragma unroll for (int h = 0; h < QHLOOP; h++) { floatx4 tmp = {0}; - #pragma unroll for (int i = 0; i < 4; i++) { const float B = (lane4id == i) ? 1.0f : 0.0f; - // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); - // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } dout[h] = tmp; } const int lane4_token_idx = 4 * (global_token_idx >> 2); - const int alibi_offset = lane4_token_idx - context_len + 1; - if (alibi_slopes != nullptr) { - #pragma unroll + + if constexpr(ALIBI_ENABLED) { + const int alibi_offset = lane4_token_idx - context_len + 1; for (int h = 0; h < QHLOOP; h++) { - #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] += alibi_slope[h] * (alibi_offset + i); } } } - #pragma unroll + const int bpermute_mask = 4*(16*((laneid>>2)%4) + lane4id); + for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; - #pragma unroll for (int i = 0; i < 4; i++) { qk_max[h] = (lane4_token_idx + i < context_len) ? 
fmaxf(qk_max[h], dout[h][i]) : qk_max[h]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); - } + + //for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + //} + //faster version of above code with dpp + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + + auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&qk_max[h])); + qk_max[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); } + float exp_sum[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; - #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] = (lane4_token_idx + i < context_len) ? __expf(dout[h][i] - qk_max[h]) : 0.0f; exp_sum[h] += dout[h][i]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - exp_sum[h] += __shfl_xor(exp_sum[h], mask); - } + //for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // exp_sum[h] += __shfl_xor(exp_sum[h], mask); + //} + //faster version of above code with dpp + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + + auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&exp_sum[h])); + exp_sum[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); } - #pragma unroll + if (laneid<4) { for (int h = 0; h < QHLOOP; h++) { const int head_idx = 4 * h + lane4id; shared_qk_max[warpid][head_idx] = qk_max[h]; shared_exp_sum[warpid][head_idx] = exp_sum[h]; } + } } // warp within context __syncthreads(); @@ -596,18 +1096,15 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; const int head_idx = 4 * h + lane4id; - #pragma unroll for (int w = 0; w < NWARPS; w++) { warp_qk_max[w] = shared_qk_max[w][head_idx]; global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); } float global_exp_sum = 0.0f; - #pragma unroll for (int w = 0; w < NWARPS; w++) { global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); @@ -625,83 +1122,66 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { - logits[h] = from_floatx4(dout[h]); + logits[h] = from_floatx4_rtz(dout[h]); } - __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; if (warp_start_token_idx >= context_len) { // warp 
out of context - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; } } } else { // warp in context - // iterate across heads - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - #pragma unroll + #define SV_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {\ + Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]);\ + }\ + for (int qh = 0; qh < QHLOOP; qh++) { \ + acc[qh] = gcn_mfma4x4x4_instr(logits[qh], Vlocal[vh][x].xy[0], \ + acc[qh]); \ + acc[qh] = gcn_mfma4x4x4_instr(logits[qh], Vlocal[vh][x].xy[1], \ + acc[qh]); \ + } + for (int vh = 0; vh < VHELOOP; vh++) { - floatx4 acc = {0}; + floatx4 acc[QHLOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + acc[qh] = {0}; + } // iterate over tokens - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[1], acc); - vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + SV_mfma(0); + SV_mfma(1); + SV_mfma(2); + SV_mfma(3); + SV_mfma(4); + SV_mfma(5); + SV_mfma(6); + SV_mfma(7); + + for (int qh = 0; qh < QHLOOP; qh++) { + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + acc[qh] *= *v_scale_ptr; + } + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); + } } - } + +#undef SV_mfma } // warp in context __syncthreads(); if (warpid == 0) { - // const float out_scale = (fp8_out_scale_ptr != nullptr) ? - // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f; const float out_scale = (fp8_out_scale_ptr != nullptr) ? 
1.0f / (*fp8_out_scale_ptr) : 1.0f; _B16x4 vout[QHLOOP][VHELOOP]; // iterate across heads - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { // iterate over each v head elem (within head_size) - #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout[qh][vh] = {0}; - #pragma unroll for (int w = 0; w < NWARPS; w++) { vout[qh][vh] = addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); @@ -709,58 +1189,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } - if (context_len > partition_size) { - scalar_t* out_ptr = out + - seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - const int out_num_partitions = max_num_partitions; - bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll - for (int vh = 0; vh < VHELOOP; vh++) { - const int head_size_elem = vh * WARP_SIZE + laneid; - #pragma unroll - for (int i = 0; i < 4; i++) { - const int head_idx = 4 * qh + i; - if (head_idx < GQA_RATIO) { - out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * - HEAD_SIZE + - head_size_elem] = vout[qh][vh][i]; - } - } - } - } - } // context_len > partition_size - else { - bit8_t* final_out_ptr_b8; - bit16_t* final_out_ptr_b16; - if constexpr (std::is_same::value) { - final_out_ptr_b8 = final_out + seq_idx * num_heads * HEAD_SIZE; - } else { - OUTT* out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; - final_out_ptr_b16 = reinterpret_cast(out_ptr); - } - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll - for (int vh = 0; vh < VHELOOP; vh++) { - const int head_size_elem = vh * WARP_SIZE + laneid; - #pragma unroll - for (int i = 0; i < 4; i++) { - const int head_idx = 4 * qh + i; - if (head_idx < GQA_RATIO) { - if constexpr (std::is_same::value) { - const float tmpf = - out_scale * to_float_b16(vout[qh][vh][i]); - const OUTT tmp = hip_fp8(tmpf).data; - final_out_ptr_b8[(wg_start_head_idx + head_idx) * HEAD_SIZE + - head_size_elem] = tmp; - } else { - final_out_ptr_b16[(wg_start_head_idx + head_idx) * HEAD_SIZE + - head_size_elem] = vout[qh][vh][i]; - } - } + scalar_t* out_ptr = out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + const int out_num_partitions = max_num_partitions; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { + const int head_size_elem = vh * WARP_SIZE + laneid; + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; } } } @@ -787,12 +1229,6 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const int seq_idx = blockIdx.y; const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // if num_partitions==1, main kernel will write to out directly, no work in - // reduction kernel - return; - } - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -957,8 +1393,6 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); - // const float out_scale = (fp8_out_scale_ptr != nullptr) ? 
- // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f; const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; acc *= inv_global_exp_sum; @@ -975,9 +1409,36 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr, + const float* __restrict__ fp8_out_scale_ptr) { + UNREACHABLE_CODE +} + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -1018,9 +1479,19 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ +#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma16_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr); + +#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma4_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ @@ -1036,7 +1507,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( context_lens_ptr, max_num_partitions, fp8_out_scale_ptr); template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, bool ALIBI_ENABLED> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -1070,7 +1541,6 @@ void paged_attention_custom_launcher( const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); - // NOTE: fp8_out_scale is optional. 
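  // Launch-geometry sketch (hypothetical sizes): with max_context_len = 8192 and the
  // fixed 256-token partition set below, max_num_partitions = ceil(8192 / 256) = 32,
  // so the QKV kernel is launched with grid = (num_seqs, 32, num_kv_heads) and 256
  // threads per block; each workgroup covers one 256-token partition of one sequence
  // for one KV head, and the reduction kernel then combines the per-partition results.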
const float* fp8_out_scale_ptr = fp8_out_scale @@ -1079,81 +1549,81 @@ void paged_attention_custom_launcher( OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + //partition size is fixed at 256 since both mfma4 and mfma16 kernels support it + //mfma4 kernel also supports partition size 512 + constexpr int PARTITION_SIZE = 256; const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); - constexpr int NTHR = PARTITION_SIZE; + constexpr int NTHR = 256; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + //mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 switch (gqa_ratio) { case 1: - LAUNCH_CUSTOM_ATTENTION(1); + LAUNCH_CUSTOM_ATTENTION_MFMA4(1); break; case 2: - LAUNCH_CUSTOM_ATTENTION(2); + LAUNCH_CUSTOM_ATTENTION_MFMA4(2); break; case 3: - LAUNCH_CUSTOM_ATTENTION(3); + LAUNCH_CUSTOM_ATTENTION_MFMA4(3); break; case 4: - LAUNCH_CUSTOM_ATTENTION(4); + LAUNCH_CUSTOM_ATTENTION_MFMA4(4); break; case 5: - LAUNCH_CUSTOM_ATTENTION(5); + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; case 6: - LAUNCH_CUSTOM_ATTENTION(6); + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); break; case 7: - LAUNCH_CUSTOM_ATTENTION(7); + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); break; case 8: - LAUNCH_CUSTOM_ATTENTION(8); + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break; case 9: - LAUNCH_CUSTOM_ATTENTION(9); + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); break; case 10: - LAUNCH_CUSTOM_ATTENTION(10); + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); break; case 11: - LAUNCH_CUSTOM_ATTENTION(11); + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); break; case 12: - LAUNCH_CUSTOM_ATTENTION(12); + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); break; case 13: - LAUNCH_CUSTOM_ATTENTION(13); + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); break; case 14: - LAUNCH_CUSTOM_ATTENTION(14); + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break; case 15: - LAUNCH_CUSTOM_ATTENTION(15); + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break; case 16: - LAUNCH_CUSTOM_ATTENTION(16); + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break; default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break; } - // reduction kernel is only required if max_context_len > partition size, - // otherwise main kernel writes directly to final output - // note there are cases with graphing where max_context_len is the max - // supported by graphing, not the actual max among all the sequences: in that - // case reduction kernel will still run but return immediately - if (max_context_len > PARTITION_SIZE) { dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_block(head_size); const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); - // support upto 8*64*256=128K context length + //reduction kernel supports upto 8 NPAR_loops * 64 (warp_size) * 256 (partition size) = 128K context length switch (npar_loops) { case 1: LAUNCH_CUSTOM_REDUCTION(1); @@ -1183,25 +1653,29 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); break; } - } } #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ - PSIZE) \ + PSIZE, ALIBI_ENABLED) \ paged_attention_custom_launcher( \ + PSIZE, ALIBI_ENABLED>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, context_lens, max_context_len, 
\ alibi_slopes, k_scale, v_scale, fp8_out_scale); +#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT, PSIZE) \ + if (alibi_slopes) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, true); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, false); \ + } + #define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ OUTT) \ switch (partition_size) { \ case 256: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \ - break; \ - case 512: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 512); \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \ break; \ default: \ TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \ @@ -1249,7 +1723,6 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported head size: ", head_size); \ break; \ } - void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 13c92dadcd4e..e53243b9387c 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -117,7 +117,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( "version", - ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) + ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -182,7 +182,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(0.3, dtype=torch.float) # Call the paged attention kernel. 
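    # k_scale = v_scale = 0.3 above (instead of the previous 1.0) so that, for the
    # fp8 KV-cache cases, the kernel's per-tensor dequantization (value = scale *
    # fp8_value) is exercised with a non-trivial scale; the fp8 reference path
    # further down mirrors this by multiplying the dequantized key/value caches by
    # k_scale and v_scale before running the reference attention.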
output = torch.empty_like(query) @@ -213,7 +213,7 @@ def test_paged_attention( elif version in ("v2", "rocm"): if current_platform.is_rocm(): - PARTITION_SIZE = 1024 if version == "v2" else 512 + PARTITION_SIZE = 1024 if version == "v2" else 256 num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -275,13 +275,15 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, + None, + PARTITION_SIZE, ) opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale), + kv_cache_dtype, k_scale, v_scale, None, PARTITION_SIZE), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) @@ -298,14 +300,14 @@ def test_paged_attention( dtype=dtype, device=device) ops.convert_fp8(dequantized_key_cache, key_cache) - key_cache = dequantized_key_cache + key_cache = k_scale * dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) ops.convert_fp8(dequantized_value_cache, value_cache) - value_cache = dequantized_value_cache + value_cache = v_scale * dequantized_value_cache ref_output = torch.empty_like(query) ref_single_query_cached_kv_attention( From ba5df8bd0d25843a1a3d921e9bf8d4a27bce8f23 Mon Sep 17 00:00:00 2001 From: sanyalington Date: Mon, 20 Jan 2025 09:46:09 +0000 Subject: [PATCH 2/8] added comments to mfma4 kernel --- csrc/rocm/attention.cu | 54 ++++++++++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 4e4f5d7fb41e..4d83d9c67ba4 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -762,6 +762,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 ///////////////////////////////////////////////////////////// // grid (num_seqs, num_partitions, num_kv_heads) // block (256 : partition size) +//each WG handles 1 partition per sequence template = context_len) { return; } + // every 4 lanes fetch 4 different qheads + //qhloop = num loops over qhead dimension constexpr int QHLOOP = - DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads + DIVIDE_ROUND_UP(GQA_RATIO, 4); constexpr int GQA_RATIO4 = 4 * QHLOOP; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; _B16x8 Qlocal[QHLOOP]; constexpr int x = 16 / sizeof(scalar_t); + //kheloop = num loops over head_size for 16Bytes of Q/dequantized K elements constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; _B8x8 Klocalb8[KHELOOP]; - // v head_size dimension is distributed across warp + // for SoftMax-V Gemm, V head_size dimension is distributed across warp + //vheloop = num loops to cover v head size dimension constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; - constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 - // 8xtokens + //softmax out has warp_size tokens across warp + //vtloop = num loops to cover warp_size(64) tokens with 16Bytes of dequantized V elements + constexpr int VTLOOP = WARP_SIZE/8; + //num vblocks to cover warp_size(64) v elements constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; int vphysical_blocks[VBLOCKS]; _B16x8 Vlocal[VHELOOP][VTLOOP]; @@ -838,32 +845,37 @@ __global__ __launch_bounds__(NUM_THREADS) void 
paged_attention_ll4mi_QKV_mfma4_k const int warp_start_token_idx = partition_start_token_idx + warpid * WARP_SIZE; - if (warp_start_token_idx >= context_len) { // warp out of context + // entire warp out of context + if (warp_start_token_idx >= context_len) { #pragma unroll for (int h = 0; h < GQA_RATIO4; h++) { shared_qk_max[warpid][h] = -FLT_MAX; shared_exp_sum[warpid][h] = 0.0f; } - } else { // warp within context + // warp within context + } else { const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int last_ctx_block = num_context_blocks - 1; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - + //token id within partition const int local_token_idx = threadIdx.x; + //token id within sequence const int global_token_idx = partition_start_token_idx + local_token_idx; // fetch block number for k const int block_idx = (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block; + + //fetch k physical block number // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride const int64_t physical_block_number = static_cast(block_table[block_idx]); - // fetch vphysical block numbers up front + //fetch vphysical block numbers up front const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; @@ -872,7 +884,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k vphysical_blocks[b] = block_table[vblock_idx_ctx]; } - // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + //fetch q elements + //every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); @@ -891,12 +904,15 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k Qlocal[QHLOOP - 1].xy[1] = {0}; } + //fetch k elements const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + wg_start_kv_head_idx * kv_head_stride; + // physical_block_offset is already cast in terms of _B16x8 const int physical_block_offset = - local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset - // is already cast as _H8 + local_token_idx % BLOCK_SIZE; + + //each K fetch is for 8 elements of cache_t which are later dequantized to scalar_t for fp8 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); for (int d = 0; d < KHELOOP; d++) { @@ -915,6 +931,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k } } + //optional alibi fetch float alibi_slope[QHLOOP]; if constexpr(ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { @@ -981,7 +998,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k dout[h] = gcn_mfma4x4x4_instr(Qlocal[h].xy[1], \ Klocal[x].xy[1], dout[h]);\ } - //QK mfma + //QK mfma with Q mfma block broadcast + //Q values across head_size dimension stored across lanes + //K values across head_size dimension are stored depthwise within lane + //Q broadcast with absz, cbid of mfma instruction QK_mfma(0); QK_mfma(1); QK_mfma(2); @@ -1005,6 +1025,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + //post mfma scaling for fp8 scale2 *= *k_scale_ptr; } @@ -1096,6 +1117,7 @@ __global__ 
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; + //calculate qk_max and exp_sums for partition for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; @@ -1123,6 +1145,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; for (int h = 0; h < QHLOOP; h++) { + //use rtz for faster performance with no perceivable accuracy loss logits[h] = from_floatx4_rtz(dout[h]); } @@ -1150,7 +1173,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k for (int qh = 0; qh < QHLOOP; qh++) { acc[qh] = {0}; } - // iterate over tokens + //SoftMax-V calculation + //logits -> token dimension is distributed across lanes + //Vlocal -> token dimension is depthwise within lane + //uses mfma instruction block broadcast for logits SV_mfma(0); SV_mfma(1); SV_mfma(2); @@ -1162,6 +1188,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k for (int qh = 0; qh < QHLOOP; qh++) { if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + //post mfma v scale for fp8 acc[qh] *= *v_scale_ptr; } vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); @@ -1173,6 +1200,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k __syncthreads(); + //final write to tmp_out after vout accumulation if (warpid == 0) { const float out_scale = (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; From 4fdcd75d258eeb6526aad5f50f3d0de462f22dda Mon Sep 17 00:00:00 2001 From: sanyalington Date: Mon, 20 Jan 2025 17:48:43 +0000 Subject: [PATCH 3/8] further comments for mfma16 kernel --- csrc/rocm/attention.cu | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 4d83d9c67ba4..59c73737399e 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -392,7 +392,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; //sub partition of tokens per warp for qk calculation constexpr int TLOOP = TOKENS_PER_WARP / 16; //each mfma16x16x16 instruction processes 16 tokens - _B16x8 Klocal[TLOOP][QKHELOOP]; //this could be B8x16 too + _B16x8 Klocal[TLOOP][QKHELOOP]; //can be interpreted as B8x16 for 8 bit types const int wg_start_head_idx = blockIdx.z * GQA_RATIO; const int wg_start_kv_head_idx = blockIdx.z; @@ -401,7 +401,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 //for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps //each mfma takes QH16xT16x16HE across warp //repeat mfmas across QKHELOOP dimension - //output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens across 4 rowsx4 tokens per lane + //output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens across 4 rows x 4 tokens per lane const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int last_ctx_block = num_context_blocks - 1; @@ -490,11 +490,11 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 alibi_slope = (lane16id < GQA_RATIO) ? 
alibi_slopes[alibi_head_idx] : 0.f; } - constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP; //16 tokens per lane + constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP; //64/4 = 16 contiguous vtokens per lane constexpr int VBLOCKS_PER_LANE = 1; //assumes block size >=16, each lane can correspond to 1 block only constexpr int VTLOOP = NWARPS; //corresponds to tokens across warps constexpr int VTLANELOOP = DIVIDE_ROUND_UP(VTOKENS_PER_LANE , CONTIGUOUS_KV_ELEMS_16B_LOAD); //optimized for 16B fetches; assumes minimum block size is 16 - constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; + constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; //head_size distributed across warps; each mfma instr works on 16 head elements int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; @@ -511,7 +511,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } - _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; //this could be B8x16 too + _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; //this can be interpreted as B8x16 too const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + ((rowid * VTOKENS_PER_LANE)%BLOCK_SIZE); @@ -662,6 +662,9 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 __syncthreads(); + constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; + constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; + _B16x4 outelems[VHELOOP]; //Softmax V mfma //v layout: 16he across lanes x 16 tokens per lane @@ -672,33 +675,29 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - for (int i=0; i<2; i++) { - //generalize this for 8 bit dtypes: each lane needs 2*vfetch_depth + 2 _B16x4 K/token dimension elems; each row is multiplied by a factor of 4 - //layout: lane in depth dimension | row across -> - //0 4 8 12 - //1 5 9 13 - //2 6 10 14 - //3 7 11 15 - const int offset = rowid * VTLANELOOP * 2 + 2*vfetch_depth + i; - const int offset1 = offset % 4; //4 corresponds to ROWS_PER_WARP - const int offset2 = offset / 4; + for (int i=0; i(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], shared_logits[vtoken_depth][offset2][lane16id][offset1], tmp_out); } } + //KV cache fp8 } else { for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + //reinterpret V format as 16 elements of 8bits _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); - for (int j=0; j<2; j++) { + for (int j=0; j(Vtmp8x8); - for (int i=0; i<2; i++) { - const int offset = 4*rowid + 2*j + i; - const int offset1 = offset % 4; - const int offset2 = offset / 4; + for (int i=0; i(Vlocaltmp.xy[i], shared_logits[vtoken_depth][offset2][lane16id][offset1], @@ -706,7 +705,6 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 } } } - } } //apply post Softmax V mfma v_scale @@ -726,7 +724,7 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 __syncthreads(); - //write to tmp_out with coalesced writes + //write to tmp_out with coalesced writes after reading from shared mem if (warpid == 0) { _B16x8 vout[GQA_RATIO4]; //each lane writes out 16Bytes of tmp_out along head elem dimension From 6f8e708d36009a9e7e461c443149688edd055af9 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 22 Jan 2025 20:59:54 +0000 Subject: [PATCH 4/8] clang-format --- csrc/rocm/attention.cu | 1274 
+++++++++++++++++++++------------------- 1 file changed, 681 insertions(+), 593 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 59c73737399e..f4d7e5c6d3a4 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -66,7 +66,7 @@ typedef struct _B16x8 { } _B16x8; using _B8x8 = uint2; -using _B8x4 = int32_t; //used in builtins +using _B8x4 = int32_t; // used in builtins using bit8_t = uint8_t; typedef struct _B8x16 { @@ -76,25 +76,24 @@ typedef struct _B8x16 { ////// Non temporal loads /////// template __device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); + return __builtin_nontemporal_load(addr); } __device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) { - auto addr_alias = reinterpret_cast(addr); - auto dat0 = loadnt(addr_alias); - auto dat1 = loadnt(addr_alias + 1); - auto dat2 = loadnt(addr_alias + 2); - auto dat3 = loadnt(addr_alias + 3); - auto res = make_float4(dat0,dat1,dat2,dat3); - return *reinterpret_cast<_B16x8*>(&res); + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + auto res = make_float4(dat0, dat1, dat2, dat3); + return *reinterpret_cast<_B16x8*>(&res); } /////////////////////////////////// - template __device__ __forceinline__ floatx4 gcn_mfma4x4x4_instr(const _B16x4& inpA, - const _B16x4& inpB, - const floatx4& inpC) { + const _B16x4& inpB, + const floatx4& inpC) { if constexpr (std::is_same::value) { return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, blgp); @@ -108,14 +107,14 @@ __device__ __forceinline__ floatx4 gcn_mfma4x4x4_instr(const _B16x4& inpA, template __device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, - const _B16x4& inpB, - const floatx4& inpC) { + const _B16x4& inpB, + const floatx4& inpC) { if constexpr (std::is_same::value) { return __builtin_amdgcn_mfma_f32_16x16x16f16(inpA, inpB, inpC, absz, cbid, - blgp); + blgp); } else if constexpr (std::is_same::value) { - return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(inpA, inpB, inpC, absz, cbid, - blgp); + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(inpA, inpB, inpC, absz, + cbid, blgp); } else { static_assert(false, "unsupported 16b dtype"); } @@ -170,20 +169,20 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { _B16x4 ret; if constexpr (std::is_same::value) { union h2cvt { - __half2 h2[2]; - _B16x4 b16x4; + __half2 h2[2]; + _B16x4 b16x4; } u; - u.h2[0] = __float22half2_rn(make_float2(inp[0],inp[1])); - u.h2[1] = __float22half2_rn(make_float2(inp[2],inp[3])); + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); return u.b16x4; } else if constexpr (std::is_same::value) { for (int i = 0; i < 4; i++) { union fcvt { - uint32_t u32; - float f32; + uint32_t u32; + float f32; } u; u.f32 = inp[i]; - u.u32 += 0x7fff + ((u.u32 >> 16) & 1); //BF16 RNE with no nan/inf check + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // BF16 RNE with no nan/inf check ret[i] = uint16_t(u.u32 >> 16); } return ret; @@ -203,24 +202,24 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, _B16x4 ret; if constexpr (std::is_same::value) { union h2cvt { - _B16x4 b16x4; - __half2 h2[2]; - } u1,u2,s; - u1.b16x4 = inp1; - u2.b16x4 = inp2; + _B16x4 b16x4; + __half2 h2[2]; + } u1, u2, s; + u1.b16x4 = inp1; + u2.b16x4 = inp2; s.h2[0] = u1.h2[0] + u2.h2[0]; s.h2[1] = u1.h2[1] + 
u2.h2[1]; return s.b16x4; } else if constexpr (std::is_same::value) { for (int i = 0; i < 4; i++) { union fcvt { - float f32; - uint32_t i32; - } u1,u2,s; - u1.i32 = uint32_t(inp1[i])<<16; - u2.i32 = uint32_t(inp2[i])<<16; + float f32; + uint32_t i32; + } u1, u2, s; + u1.i32 = uint32_t(inp1[i]) << 16; + u2.i32 = uint32_t(inp2[i]) << 16; s.f32 = u1.f32 + u2.f32; - ret[i] = uint16_t(s.i32>>16); + ret[i] = uint16_t(s.i32 >> 16); } return ret; } else { @@ -249,17 +248,18 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, } template -__device__ __forceinline__ _B16x8 scaled_convert_b8x8_custom(const _B8x8 input, - const float scale) { +__device__ __forceinline__ _B16x8 +scaled_convert_b8x8_custom(const _B8x8 input, const float scale) { union { floatx4 f32x4[2]; vllm::Float8_ f32x8; } tmpf8; - tmpf8.f32x8 = vllm::fp8::vec_conversion(*reinterpret_cast(&input)); - + tmpf8.f32x8 = vllm::fp8::vec_conversion( + *reinterpret_cast(&input)); + tmpf8.f32x4[0] *= scale; tmpf8.f32x4[1] *= scale; - + _B16x8 ret; ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); @@ -267,19 +267,20 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8_custom(const _B8x8 input, } __device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { -#if defined(__gfx90a__) - float4 f32x4 = vllm::fp8::vec_conversion(*reinterpret_cast(&inp)); - return *reinterpret_cast(&f32x4); -#else //MI3xx+ optimized builtins - const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); - const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); - floatx4 ret; - ret[0] = f0[0]; - ret[1] = f0[1]; - ret[2] = f1[0]; - ret[3] = f1[1]; - return ret; -#endif + #if defined(__gfx90a__) + float4 f32x4 = vllm::fp8::vec_conversion( + *reinterpret_cast(&inp)); + return *reinterpret_cast(&f32x4); + #else // MI3xx+ optimized builtins + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); + floatx4 ret; + ret[0] = f0[0]; + ret[1] = f0[1]; + ret[2] = f1[0]; + ret[3] = f1[1]; + return ret; + #endif } template @@ -287,17 +288,17 @@ __device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { _B16x4 ret; if constexpr (std::is_same::value) { union h2cvt { - _Half2 h2[2]; - _B16x4 b16x4; + _Half2 h2[2]; + _B16x4 b16x4; } u; - u.h2[0] = __builtin_amdgcn_cvt_pkrtz(inp[0],inp[1]); - u.h2[1] = __builtin_amdgcn_cvt_pkrtz(inp[2],inp[3]); + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(inp[0], inp[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(inp[2], inp[3]); return u.b16x4; } else if constexpr (std::is_same::value) { for (int i = 0; i < 4; i++) { union fcvt { - uint32_t i32; - float f32; + uint32_t i32; + float f32; } u; u.f32 = inp[i]; ret[i] = uint16_t(u.i32 >> 16); @@ -311,13 +312,13 @@ __device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { template __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { union { - _B8x8 b8x8; - _B8x4 b8x4[2]; + _B8x8 b8x8; + _B8x4 b8x4[2]; } tmp; tmp.b8x8 = input; _B16x8 ret; - for (int i=0; i<2; i++) { - ret.xy[i] = from_floatx4_rtz( to_float_fp8x4(tmp.b8x4[i]) ); + for (int i = 0; i < 2; i++) { + ret.xy[i] = from_floatx4_rtz(to_float_fp8x4(tmp.b8x4[i])); } return ret; } @@ -327,9 +328,9 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { // block (256) template -__global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma16_kernel( + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO> +__global__ 
+__launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -358,414 +359,479 @@ __global__ __launch_bounds__(NUM_THREADS,5) void paged_attention_ll4mi_QKV_mfma1 const int seq_idx = blockIdx.x; const int partition_idx = blockIdx.y; - - constexpr int T_PAR_SIZE = 256; //token partition size set to 256 + + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 const int max_num_partitions = gridDim.y; const int context_len = context_lens[seq_idx]; - - const int partition_start_token_idx = partition_idx * T_PAR_SIZE; //partition_size; + + const int partition_start_token_idx = + partition_idx * T_PAR_SIZE; // partition_size; // exit if partition is out of context for seq if (partition_start_token_idx >= context_len) { return; } - constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO,4); + constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO, 4); __shared__ float shared_qk_max[NWARPS][16 + 1]; __shared__ float shared_exp_sum[NWARPS][16 + 1]; - //shared_logits is used for multiple purposes + // shared_logits is used for multiple purposes __shared__ _B16x4 shared_logits[NWARPS][4][16][4]; - - //for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes HeadElements in each lane, 4x16B HeadElements across 4 rows of warp - constexpr int ROWS_PER_WARP = WARP_SIZE / 16; //rows refers to 16 lanes; refer dpp terminology - constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); //8 for 16 bit cache type, 16 for 8 bit types - constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; //each fetch across a warp fetches these many elements - constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); //1 for 16bit types, 2 for 8bit types - constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; //4xQKHE_16B across warp - _B16x8 Qlocal[QKHELOOP][QK_SIZE_RATIO]; //note that 16 contiguous elements of Q should be fetched per lane for 8 bit cache types : QK_SIZE_RATIO changes for this + // for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes + // HeadElements in each lane, 4x16B HeadElements across 4 rows of warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QK_SIZE_RATIO = + sizeof(scalar_t) / + sizeof(cache_t); // 1 for 16bit types, 2 for 8bit types + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 4xQKHE_16B across + // warp + + _B16x8 Qlocal[QKHELOOP] + [QK_SIZE_RATIO]; // note that 16 contiguous elements of Q should + // be fetched per lane for 8 bit cache types : + // QK_SIZE_RATIO changes for this constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); - constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; //sub partition of tokens per warp for qk calculation - constexpr int TLOOP = TOKENS_PER_WARP / 16; //each mfma16x16x16 instruction processes 16 tokens + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each mfma16x16x16 instruction processes 16 tokens - _B16x8 
Klocal[TLOOP][QKHELOOP]; //can be interpreted as B8x16 for 8 bit types + _B16x8 Klocal[TLOOP][QKHELOOP]; // can be interpreted as B8x16 for 8 bit + // types const int wg_start_head_idx = blockIdx.z * GQA_RATIO; const int wg_start_kv_head_idx = blockIdx.z; const int total_num_heads = gridDim.z * GQA_RATIO; - //for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps - //each mfma takes QH16xT16x16HE across warp - //repeat mfmas across QKHELOOP dimension - //output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens across 4 rows x 4 tokens per lane + // for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each mfma takes QH16xT16x16HE across warp + // repeat mfmas across QKHELOOP dimension + // output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens + // across 4 rows x 4 tokens per lane - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int last_ctx_block = num_context_blocks - 1; + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; - const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; - - int kphysical_block_number[TLOOP]; + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; - //fetch k physical block numbers - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; - const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; - const int kblock_idx = (kglobal_token_idx < context_len) - ? kglobal_token_idx / BLOCK_SIZE - : last_ctx_block; - kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; - } - - //fetch Q in shared across warps and then write to registers - const int local_qhead_idx = 4 * warpid + rowid; - const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; - const int64_t seq_idx64 = static_cast(seq_idx); - const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; - - const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; - if ( (local_qhead_idx < GQA_RATIO) && (qhead_element(q_fetch_ptr); - _B16x8 tmp = *q_fetch_ptr_16B; - if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { - const int offset1 = lane16id/4; //16 contiguous chunks of head elems are spread across 4x4lanes - shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; - shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; - } else { - for (int i=0; i<2; i++) { - const int head_elem = lane16id * 2 + i; //element id in _B16x4 terms - const int offset3 = head_elem % 4; - const int offset2 = (head_elem / 4) % 4; - const int offset1 = head_elem /4/4; - shared_logits[offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; - } - } - } - __syncthreads(); - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - for (int i=0; i<2; i++) { - Qlocal[qkhe_depth][qkratio].xy[i] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][2*qkratio + i]; - } - } - } - - //set to true to enable non temporal kv loads: has some benefit in very high batch size cases - constexpr bool NT_KV_LOAD = false; + int kphysical_block_number[TLOOP]; - constexpr int KX = 16 / sizeof(cache_t); //vLLM defines x as 16 Bytes of kv cache elements - const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + // fetch k physical block numbers 
+ for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } - const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; - //fetch K values - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - const int64_t kblock_number = static_cast(kphysical_block_number[token_depth]); - const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; - const int klocal_token_idx = TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; - const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; - const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; - const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; - - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; - const int offset1 = head_elem / KX; - const int offset2 = head_elem % KX; - const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; - const _B16x8* k_fetch_ptr_16B = reinterpret_cast(k_fetch_ptr); - if constexpr(NT_KV_LOAD) { - Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); - } else { - Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; - } + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 4 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = + q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const int offset1 = + lane16id / + 4; // 16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; + shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; + } else { + for (int i = 0; i < 2; i++) { + const int head_elem = lane16id * 2 + i; // element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem / 4 / 4; + shared_logits[offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + } + } + } + __syncthreads(); + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i = 0; i < 2; i++) { + Qlocal[qkhe_depth][qkratio].xy[i] = + shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO] + [2 * qkratio + i]; } } + } - float alibi_slope; - if constexpr(ALIBI_ENABLED) { - const int alibi_head_idx = wg_start_head_idx + lane16id; - alibi_slope = (lane16id < GQA_RATIO) ? 
alibi_slopes[alibi_head_idx] : 0.f; + // set to true to enable non temporal kv loads: has some benefit in very high + // batch size cases + constexpr bool NT_KV_LOAD = false; + + constexpr int KX = + 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + // fetch K values + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + if constexpr (NT_KV_LOAD) { + Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); + } else { + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } } + } - constexpr int VTOKENS_PER_LANE = TOKENS_PER_WARP / ROWS_PER_WARP; //64/4 = 16 contiguous vtokens per lane - constexpr int VBLOCKS_PER_LANE = 1; //assumes block size >=16, each lane can correspond to 1 block only - constexpr int VTLOOP = NWARPS; //corresponds to tokens across warps - constexpr int VTLANELOOP = DIVIDE_ROUND_UP(VTOKENS_PER_LANE , CONTIGUOUS_KV_ELEMS_16B_LOAD); //optimized for 16B fetches; assumes minimum block size is 16 - constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; //head_size distributed across warps; each mfma instr works on 16 head elements - - int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + float alibi_slope; + if constexpr (ALIBI_ENABLED) { + const int alibi_head_idx = wg_start_head_idx + lane16id; + alibi_slope = (lane16id < GQA_RATIO) ? 
alibi_slopes[alibi_head_idx] : 0.f; + } - //fetch v physical block numbers - for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; vblock_depth++) { - const int vlocal_token_idx = vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; - const int vglobal_token_idx = partition_start_token_idx + vlocal_token_idx; + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 64/4 = 16 contiguous vtokens per lane + constexpr int VBLOCKS_PER_LANE = + 1; // assumes block size >=16, each lane can correspond to 1 block only + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = + HEAD_SIZE / 16 / NWARPS; // head_size distributed across warps; each mfma + // instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; const int vblock_idx = (vglobal_token_idx < context_len) - ? vglobal_token_idx / BLOCK_SIZE - : last_ctx_block; + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; vphysical_block_number[vtoken_depth][vblock_depth] = - block_table_seq[vblock_idx]; - } + block_table_seq[vblock_idx]; } + } - _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; //this can be interpreted as B8x16 too - - const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + ((rowid * VTOKENS_PER_LANE)%BLOCK_SIZE); - - //v fetches are 16head elems across lanes x 16 tokens per lane - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; - const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; - - for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - const int vblock_depth = 0; - const int64_t vblock_number = static_cast(vphysical_block_number[vtoken_depth][vblock_depth]); - const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); - - const cache_t* v_fetch_ptr = v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; - const _B16x8* v_fetch_ptr_16B = reinterpret_cast(v_fetch_ptr); - if constexpr(NT_KV_LOAD) { - Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = load_ntmprl_16Byte(v_fetch_ptr_16B); - } else { - Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; - } - } - } - } - - //calculate post qk mfma scale - float scale2 = scale; - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - //multiply by k_scale if fp8 kv cache - scale2 *= *k_scale_ptr; - } + _B16x8 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP]; // this can be interpreted as B8x16 too - floatx4 dout[TLOOP]; - //qk mfma - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - dout[token_depth] = {0}; - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; 
qkratio++) { - for (int i=0; i<2; i++) { - dout[token_depth] = gcn_mfma16x16x16_instr(Klocal[token_depth][qkhe_depth].xy[i], - Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); - } - } - } else { //kv cache dtype fp8 - auto Ktmp = Klocal[token_depth][qkhe_depth]; - _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); - for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; - _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); - for (int i=0; i<2; i++) { - dout[token_depth] = gcn_mfma16x16x16_instr(Klocaltmp.xy[i], - Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); - } - } + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE); + + // v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = 0; + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + if constexpr (NT_KV_LOAD) { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = + load_ntmprl_16Byte(v_fetch_ptr_16B); + } else { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; } } - dout[token_depth] *= scale2; } + } - const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; - - //apply alibi - if constexpr(ALIBI_ENABLED) { - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - const int local_token_idx = qkout_token_idx + token_depth * 16; - const int alibi_offset = local_token_idx - context_len + 1; - for (int i=0; i<4; i++) { - dout[token_depth][i] += alibi_slope * (alibi_offset + i); - } + // calculate post qk mfma scale + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // multiply by k_scale if fp8 kv cache + scale2 *= *k_scale_ptr; + } + + floatx4 dout[TLOOP]; + // qk mfma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i = 0; i < 2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr( + Klocal[token_depth][qkhe_depth].xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); + } + } + } else { // kv cache dtype fp8 + auto Ktmp = Klocal[token_depth][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i = 0; i < 2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr( + Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i], + dout[token_depth]); + } } + } } - - //calculate qk_max and exp_sum per warp and write to shared memory - float qk_max = -FLT_MAX; - float exp_sum = 0.0f; + dout[token_depth] *= scale2; + } + + const int qkout_token_idx = + 
partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + // apply alibi + if constexpr (ALIBI_ENABLED) { for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - const int local_token_idx = qkout_token_idx + token_depth * 16; - for (int i=0; i<4; i++) { - const float tmp = (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; - qk_max = fmaxf(qk_max, tmp); - } + const int local_token_idx = qkout_token_idx + token_depth * 16; + const int alibi_offset = local_token_idx - context_len + 1; + for (int i = 0; i < 4; i++) { + dout[token_depth][i] += alibi_slope * (alibi_offset + i); + } } + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; - for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) { - qk_max = fmaxf(qk_max, __shfl_xor(qk_max,mask)); + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = + (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); } + } + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); + } - for (int token_depth = 0; token_depth < TLOOP; token_depth++) { - const int local_token_idx = qkout_token_idx + token_depth * 16; - for (int i=0; i<4; i++) { - const float tmp = (local_token_idx + i < context_len) ? __expf(dout[token_depth][i] - qk_max) : 0.0f; - dout[token_depth][i] = tmp; - exp_sum += tmp; - } + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = (local_token_idx + i < context_len) + ? 
__expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; } + } - for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) { - exp_sum += __shfl_xor(exp_sum,mask); - } + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + exp_sum += __shfl_xor(exp_sum, mask); + } - __syncthreads(); //sync before writing to shared mem + __syncthreads(); // sync before writing to shared mem - float* shared_mem = reinterpret_cast(shared_logits); - if (laneid < 16) { - const int qk_max_offset = warpid*16 + lane16id; - shared_mem[qk_max_offset] = qk_max; - const int exp_sum_offset = NWARPS*16 + qk_max_offset; - shared_mem[exp_sum_offset] = exp_sum; - } + float* shared_mem = reinterpret_cast(shared_logits); + if (laneid < 16) { + const int qk_max_offset = warpid * 16 + lane16id; + shared_mem[qk_max_offset] = qk_max; + const int exp_sum_offset = NWARPS * 16 + qk_max_offset; + shared_mem[exp_sum_offset] = exp_sum; + } - __syncthreads(); - - //calculate partition qk_max and exp_sum - float partition_qk_max = -FLT_MAX; - float warp_qk_max_exp[NWARPS]; - float partition_exp_sum = 0.0f; + __syncthreads(); - for (int w=0; w(dout[token_depth]); - } - //write out partition max_logits and exp_sum - if (threadIdx.x < GQA_RATIO) { - const int qhead_idx = lane16id; - const int offset = seq_idx * total_num_heads * max_num_partitions + (wg_start_head_idx + qhead_idx) * max_num_partitions + partition_idx; - max_logits[offset] = partition_qk_max; - exp_sums[offset] = partition_exp_sum; - } - - __syncthreads(); - - constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; - constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; - - _B16x4 outelems[VHELOOP]; - //Softmax V mfma - //v layout: 16he across lanes x 16 tokens per lane - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - floatx4 tmp_out = {0}; - - for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - - if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { - for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - for (int i=0; i(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], - shared_logits[vtoken_depth][offset2][lane16id][offset1], - tmp_out); - } - } - //KV cache fp8 - } else { - for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; - //reinterpret V format as 16 elements of 8bits - _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); - for (int j=0; j(Vtmp8x8); - for (int i=0; i(Vlocaltmp.xy[i], - shared_logits[vtoken_depth][offset2][lane16id][offset1], - tmp_out); - } - } + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // write logits to shared mem + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + // use rtz conversion for performance, with no visible impact on accuracy + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4_rtz(dout[token_depth]); + } + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; + constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; + + _B16x4 outelems[VHELOOP]; + // Softmax V mfma + // v layout: 16he 
across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx4 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems spread + // across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); } } + // KV cache fp8 + } else { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + // reinterpret V format as 16 elements of 8bits + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocaltmp.xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } } - //apply post Softmax V mfma v_scale - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - tmp_out *= *v_scale_ptr; - } - outelems[vhe_depth] = from_floatx4(tmp_out); + } + } + // apply post Softmax V mfma v_scale + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + tmp_out *= *v_scale_ptr; } + outelems[vhe_depth] = from_floatx4(tmp_out); + } - __syncthreads(); + __syncthreads(); - //store Softmax-V mfma output to shared mem - for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { - //lane16 id head dimension; rowid head element dimension - shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; - } + // store Softmax-V mfma output to shared mem + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + // lane16 id head dimension; rowid head element dimension + shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; + } - __syncthreads(); - - //write to tmp_out with coalesced writes after reading from shared mem - if (warpid == 0) { - _B16x8 vout[GQA_RATIO4]; - //each lane writes out 16Bytes of tmp_out along head elem dimension - const int head_elem_idx = lane16id * 8; - if (head_elem_idx < HEAD_SIZE) { - for (int h = 0; h < GQA_RATIO4; h++) { - const int local_head_idx = 4 * h + rowid; - const int offset1 = (head_elem_idx / 16)%4; - const int offset2 = head_elem_idx / 16 / NWARPS; - const int offset3 = (head_elem_idx / 4)%4; - for (int i=0; i<2; i++) { - vout[h].xy[i] = shared_logits[offset1][offset2][local_head_idx][offset3+i]; - } + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO4]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < 
GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + const int offset1 = (head_elem_idx / 16) % 4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4) % 4; + for (int i = 0; i < 2; i++) { + vout[h].xy[i] = + shared_logits[offset1][offset2][local_head_idx][offset3 + i]; } + } - const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; - scalar_t* out_ptr = out + - seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE; - for (int h = 0; h < GQA_RATIO4; h++) { - const int local_head_idx = 4 * h + rowid; - if (local_head_idx < GQA_RATIO) { - const int out_head_idx = wg_start_head_idx + local_head_idx; - scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; - scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; - _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); - *out_ptr_B16x8 = vout[h]; - } + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; } } } + } } ///////////////////////////////////////////////////////////// // grid (num_seqs, num_partitions, num_kv_heads) // block (256 : partition size) -//each WG handles 1 partition per sequence +// each WG handles 1 partition per sequence template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -802,27 +868,25 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k return; } // every 4 lanes fetch 4 different qheads - //qhloop = num loops over qhead dimension - constexpr int QHLOOP = - DIVIDE_ROUND_UP(GQA_RATIO, 4); + // qhloop = num loops over qhead dimension + constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO, 4); constexpr int GQA_RATIO4 = 4 * QHLOOP; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; _B16x8 Qlocal[QHLOOP]; constexpr int x = 16 / sizeof(scalar_t); - //kheloop = num loops over head_size for 16Bytes of Q/dequantized K elements + // kheloop = num loops over head_size for 16Bytes of Q/dequantized K elements constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; _B8x8 Klocalb8[KHELOOP]; // for SoftMax-V Gemm, V head_size dimension is distributed across warp - //vheloop = num loops to cover v head size dimension - constexpr int VHELOOP = - HEAD_SIZE / - WARP_SIZE; - //softmax out has warp_size tokens across warp - //vtloop = num loops to cover warp_size(64) tokens with 16Bytes of dequantized V elements - constexpr int VTLOOP = WARP_SIZE/8; - //num vblocks to cover warp_size(64) v elements + // vheloop = num loops to cover v head size dimension + constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; + // softmax out has warp_size tokens across warp + // vtloop = num loops to cover warp_size(64) tokens with 16Bytes of + // dequantized V elements + constexpr int VTLOOP = WARP_SIZE / 8; + // 
num vblocks to cover warp_size(64) v elements constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; int vphysical_blocks[VBLOCKS]; _B16x8 Vlocal[VHELOOP][VTLOOP]; @@ -844,46 +908,45 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k partition_start_token_idx + warpid * WARP_SIZE; // entire warp out of context - if (warp_start_token_idx >= context_len) { + if (warp_start_token_idx >= context_len) { #pragma unroll for (int h = 0; h < GQA_RATIO4; h++) { shared_qk_max[warpid][h] = -FLT_MAX; shared_exp_sum[warpid][h] = 0.0f; } - // warp within context - } else { - + // warp within context + } else { const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int last_ctx_block = num_context_blocks - 1; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - //token id within partition + // token id within partition const int local_token_idx = threadIdx.x; - //token id within sequence + // token id within sequence const int global_token_idx = partition_start_token_idx + local_token_idx; // fetch block number for k const int block_idx = (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block; - - //fetch k physical block number - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride + + // fetch k physical block number + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride const int64_t physical_block_number = static_cast(block_table[block_idx]); - //fetch vphysical block numbers up front + // fetch vphysical block numbers up front const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; for (int b = 0; b < VBLOCKS; b++) { - const int vblock_idx = warp_start_block_idx + b; - const int vblock_idx_ctx = - (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; - vphysical_blocks[b] = block_table[vblock_idx_ctx]; + const int vblock_idx = warp_start_block_idx + b; + const int vblock_idx_ctx = + (vblock_idx <= last_ctx_block) ? 
vblock_idx : last_ctx_block; + vphysical_blocks[b] = block_table[vblock_idx_ctx]; } - //fetch q elements - //every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems + // fetch q elements + // every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); @@ -902,22 +965,22 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k Qlocal[QHLOOP - 1].xy[1] = {0}; } - //fetch k elements + // fetch k elements const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + wg_start_kv_head_idx * kv_head_stride; // physical_block_offset is already cast in terms of _B16x8 - const int physical_block_offset = - local_token_idx % BLOCK_SIZE; + const int physical_block_offset = local_token_idx % BLOCK_SIZE; - //each K fetch is for 8 elements of cache_t which are later dequantized to scalar_t for fp8 + // each K fetch is for 8 elements of cache_t which are later dequantized to + // scalar_t for fp8 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); for (int d = 0; d < KHELOOP; d++) { Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; } } else { - //vllm defines X as 16 Bytes of elements of cache_t + // vllm defines X as 16 Bytes of elements of cache_t constexpr int X = 16 / sizeof(cache_t); const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; for (int d = 0; d < KHELOOP; d++) { @@ -929,9 +992,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k } } - //optional alibi fetch + // optional alibi fetch float alibi_slope[QHLOOP]; - if constexpr(ALIBI_ENABLED) { + if constexpr (ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { const int qhead_idx = h * 4 + lane4id; alibi_slope[h] = (qhead_idx < GQA_RATIO) @@ -941,7 +1004,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k } const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - //fetch vcache in kv cache auto case + // fetch vcache in kv cache auto case if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block @@ -962,9 +1025,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k } } } - } //if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) - //fetch vcache in fp8 case - else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) + } // if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + // fetch vcache in fp8 case + else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block for (int b = 0; b < VBLOCKS; b++) { @@ -986,53 +1049,53 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k } } -#define QK_mfma(x) \ + #define QK_mfma(x) \ if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ - Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ - } \ - for (int h = 0; h < QHLOOP; h++) { \ - dout[h] = gcn_mfma4x4x4_instr(Qlocal[h].xy[0], \ - Klocal[x].xy[0], dout[h]);\ - dout[h] = gcn_mfma4x4x4_instr(Qlocal[h].xy[1], \ - Klocal[x].xy[1], dout[h]);\ - } - //QK mfma with Q mfma block broadcast - //Q values across head_size dimension stored across lanes - //K values across head_size dimension are stored depthwise within lane - //Q broadcast with absz, cbid of mfma instruction - 
QK_mfma(0); - QK_mfma(1); - QK_mfma(2); - QK_mfma(3); - QK_mfma(4); - QK_mfma(5); - QK_mfma(6); - QK_mfma(7); - //below only needed for head size 128 - if constexpr (KHELOOP > 8) { - QK_mfma(8); - QK_mfma(9); - QK_mfma(10); - QK_mfma(11); - QK_mfma(12); - QK_mfma(13); - QK_mfma(14); - QK_mfma(15); - } -#undef QK_mfma + Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ + } \ + for (int h = 0; h < QHLOOP; h++) { \ + dout[h] = gcn_mfma4x4x4_instr( \ + Qlocal[h].xy[0], Klocal[x].xy[0], dout[h]); \ + dout[h] = gcn_mfma4x4x4_instr( \ + Qlocal[h].xy[1], Klocal[x].xy[1], dout[h]); \ + } + // QK mfma with Q mfma block broadcast + // Q values across head_size dimension stored across lanes + // K values across head_size dimension are stored depthwise within lane + // Q broadcast with absz, cbid of mfma instruction + QK_mfma(0); + QK_mfma(1); + QK_mfma(2); + QK_mfma(3); + QK_mfma(4); + QK_mfma(5); + QK_mfma(6); + QK_mfma(7); + // below only needed for head size 128 + if constexpr (KHELOOP > 8) { + QK_mfma(8); + QK_mfma(9); + QK_mfma(10); + QK_mfma(11); + QK_mfma(12); + QK_mfma(13); + QK_mfma(14); + QK_mfma(15); + } + #undef QK_mfma float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - //post mfma scaling for fp8 - scale2 *= *k_scale_ptr; + // post mfma scaling for fp8 + scale2 *= *k_scale_ptr; } for (int h = 0; h < QHLOOP; h++) { dout[h] *= scale2; } - // transpose dout so that 4 token ids are in each lane, and 4 heads are across - // 4 lanes + // transpose dout so that 4 token ids are in each lane, and 4 heads are + // across 4 lanes for (int h = 0; h < QHLOOP; h++) { floatx4 tmp = {0}; for (int i = 0; i < 4; i++) { @@ -1044,7 +1107,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k const int lane4_token_idx = 4 * (global_token_idx >> 2); - if constexpr(ALIBI_ENABLED) { + if constexpr (ALIBI_ENABLED) { const int alibi_offset = lane4_token_idx - context_len + 1; for (int h = 0; h < QHLOOP; h++) { for (int i = 0; i < 4; i++) { @@ -1053,7 +1116,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k } } - const int bpermute_mask = 4*(16*((laneid>>2)%4) + lane4id); + const int bpermute_mask = 4 * (16 * ((laneid >> 2) % 4) + lane4id); for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; @@ -1063,20 +1126,28 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k : qk_max[h]; } - //for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - // qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); - //} - //faster version of above code with dpp - asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); - asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); - - auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&qk_max[h])); + // for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + // } + // faster version of above code with dpp + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + + auto tmp = __builtin_amdgcn_ds_bpermute( + bpermute_mask, *reinterpret_cast(&qk_max[h])); qk_max[h] = *reinterpret_cast(&tmp); - asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" : "=v"(qk_max[h]) : "v"(qk_max[h]), 
"v"(qk_max[h]) ); - asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" : "=v"(qk_max[h]) : "v"(qk_max[h]), "v"(qk_max[h]) ); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); } - float exp_sum[QHLOOP]; for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; @@ -1086,25 +1157,34 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k : 0.0f; exp_sum[h] += dout[h][i]; } - //for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - // exp_sum[h] += __shfl_xor(exp_sum[h], mask); - //} - //faster version of above code with dpp - asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); - asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); - - auto tmp = __builtin_amdgcn_ds_bpermute(bpermute_mask, *reinterpret_cast(&exp_sum[h])); + // for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // exp_sum[h] += __shfl_xor(exp_sum[h], mask); + // } + // faster version of above code with dpp + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + + auto tmp = __builtin_amdgcn_ds_bpermute( + bpermute_mask, *reinterpret_cast(&exp_sum[h])); exp_sum[h] = *reinterpret_cast(&tmp); - asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); - asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" : "=v"(exp_sum[h]) : "v"(exp_sum[h]), "v"(exp_sum[h]) ); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); } - if (laneid<4) { - for (int h = 0; h < QHLOOP; h++) { - const int head_idx = 4 * h + lane4id; - shared_qk_max[warpid][head_idx] = qk_max[h]; - shared_exp_sum[warpid][head_idx] = exp_sum[h]; - } + if (laneid < 4) { + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } } } // warp within context @@ -1115,7 +1195,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; - //calculate qk_max and exp_sums for partition + // calculate qk_max and exp_sums for partition for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; @@ -1143,11 +1223,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_k // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; for (int h = 0; h < QHLOOP; h++) { - //use rtz for faster performance with no perceivable accuracy loss + // use rtz for faster performance with no perceivable accuracy loss logits[h] = from_floatx4_rtz(dout[h]); } - if (warp_start_token_idx >= context_len) { // warp out of context for (int qh = 0; qh < QHLOOP; qh++) { for (int vh = 0; vh < VHELOOP; vh++) { @@ -1155,57 +1234,57 @@ __global__ __launch_bounds__(NUM_THREADS) void 
paged_attention_ll4mi_QKV_mfma4_k } } } else { // warp in context - #define SV_mfma(x) \ - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {\ - Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]);\ - }\ - for (int qh = 0; qh < QHLOOP; qh++) { \ - acc[qh] = gcn_mfma4x4x4_instr(logits[qh], Vlocal[vh][x].xy[0], \ - acc[qh]); \ - acc[qh] = gcn_mfma4x4x4_instr(logits[qh], Vlocal[vh][x].xy[1], \ - acc[qh]); \ - } - - for (int vh = 0; vh < VHELOOP; vh++) { - floatx4 acc[QHLOOP]; - for (int qh = 0; qh < QHLOOP; qh++) { - acc[qh] = {0}; - } - //SoftMax-V calculation - //logits -> token dimension is distributed across lanes - //Vlocal -> token dimension is depthwise within lane - //uses mfma instruction block broadcast for logits - SV_mfma(0); - SV_mfma(1); - SV_mfma(2); - SV_mfma(3); - SV_mfma(4); - SV_mfma(5); - SV_mfma(6); - SV_mfma(7); - - for (int qh = 0; qh < QHLOOP; qh++) { - if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - //post mfma v scale for fp8 - acc[qh] *= *v_scale_ptr; - } - vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); + #define SV_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]); \ + } \ + for (int qh = 0; qh < QHLOOP; qh++) { \ + acc[qh] = gcn_mfma4x4x4_instr( \ + logits[qh], Vlocal[vh][x].xy[0], acc[qh]); \ + acc[qh] = gcn_mfma4x4x4_instr( \ + logits[qh], Vlocal[vh][x].xy[1], acc[qh]); \ + } + + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc[QHLOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + acc[qh] = {0}; + } + // SoftMax-V calculation + // logits -> token dimension is distributed across lanes + // Vlocal -> token dimension is depthwise within lane + // uses mfma instruction block broadcast for logits + SV_mfma(0); + SV_mfma(1); + SV_mfma(2); + SV_mfma(3); + SV_mfma(4); + SV_mfma(5); + SV_mfma(6); + SV_mfma(7); + + for (int qh = 0; qh < QHLOOP; qh++) { + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // post mfma v scale for fp8 + acc[qh] *= *v_scale_ptr; } + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); } + } -#undef SV_mfma + #undef SV_mfma } // warp in context __syncthreads(); - //final write to tmp_out after vout accumulation + // final write to tmp_out after vout accumulation if (warpid == 0) { const float out_scale = (fp8_out_scale_ptr != nullptr) ? 
1.0f / (*fp8_out_scale_ptr) : 1.0f; _B16x4 vout[QHLOOP][VHELOOP]; // iterate across heads for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) + // iterate over each v head elem (within head_size) for (int vh = 0; vh < VHELOOP; vh++) { vout[qh][vh] = {0}; for (int w = 0; w < NWARPS; w++) { @@ -1437,7 +1516,8 @@ template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -1464,7 +1544,8 @@ template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -1505,9 +1586,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ - paged_attention_ll4mi_QKV_mfma16_kernel \ +#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma16_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ @@ -1515,9 +1597,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr); -#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ - paged_attention_ll4mi_QKV_mfma4_kernel \ +#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma4_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ @@ -1533,7 +1616,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( context_lens_ptr, max_num_partitions, fp8_out_scale_ptr); template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, + bool ALIBI_ENABLED> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -1576,8 +1660,8 @@ void paged_attention_custom_launcher( const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); - //partition size is fixed at 256 since both mfma4 and mfma16 kernels support it - //mfma4 kernel also supports partition size 512 + // partition size is fixed at 256 since both mfma4 and mfma16 kernels support + // it mfma4 kernel also supports partition size 512 constexpr int PARTITION_SIZE = 256; const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); @@ -1591,7 +1675,7 @@ void paged_attention_custom_launcher( const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - //mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 + // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 switch (gqa_ratio) { case 1: LAUNCH_CUSTOM_ATTENTION_MFMA4(1); @@ -1646,62 +1730,66 @@ void paged_attention_custom_launcher( break; } - dim3 reduce_grid(num_heads, num_seqs); - dim3 
reduce_block(head_size); - const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); - //reduction kernel supports upto 8 NPAR_loops * 64 (warp_size) * 256 (partition size) = 128K context length - switch (npar_loops) { - case 1: - LAUNCH_CUSTOM_REDUCTION(1); - break; - case 2: - LAUNCH_CUSTOM_REDUCTION(2); - break; - case 3: - LAUNCH_CUSTOM_REDUCTION(3); - break; - case 4: - LAUNCH_CUSTOM_REDUCTION(4); - break; - case 5: - LAUNCH_CUSTOM_REDUCTION(5); - break; - case 6: - LAUNCH_CUSTOM_REDUCTION(6); - break; - case 7: - LAUNCH_CUSTOM_REDUCTION(7); - break; - case 8: - LAUNCH_CUSTOM_REDUCTION(8); - break; - default: - TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); - break; - } + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); + // reduction kernel supports upto 8 NPAR_loops * 64 (warp_size) * 256 + // (partition size) = 128K context length + switch (npar_loops) { + case 1: + LAUNCH_CUSTOM_REDUCTION(1); + break; + case 2: + LAUNCH_CUSTOM_REDUCTION(2); + break; + case 3: + LAUNCH_CUSTOM_REDUCTION(3); + break; + case 4: + LAUNCH_CUSTOM_REDUCTION(4); + break; + case 5: + LAUNCH_CUSTOM_REDUCTION(5); + break; + case 6: + LAUNCH_CUSTOM_REDUCTION(6); + break; + case 7: + LAUNCH_CUSTOM_REDUCTION(7); + break; + case 8: + LAUNCH_CUSTOM_REDUCTION(8); + break; + default: + TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); + break; + } } #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ - PSIZE, ALIBI_ENABLED) \ + PSIZE, ALIBI_ENABLED) \ paged_attention_custom_launcher( \ + PSIZE, ALIBI_ENABLED>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, context_lens, max_context_len, \ alibi_slopes, k_scale, v_scale, fp8_out_scale); -#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ - OUTT, PSIZE) \ - if (alibi_slopes) { \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, true); \ - } else { \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, false); \ +#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT, PSIZE) \ + if (alibi_slopes) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + true); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + false); \ } #define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ OUTT) \ switch (partition_size) { \ case 256: \ - CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + 256); \ break; \ default: \ TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \ From 7aa547ff9b9d6b1527a6b2558cc307bfd1043907 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 22 Jan 2025 21:07:18 +0000 Subject: [PATCH 5/8] Lint --- tests/kernels/test_attention.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 8309446e358f..b826559cd8db 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -116,8 +116,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( - "version", - ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"]) + "version", ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"]) 
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) From 4d0166cecfbd974a871e6cb4147076a9fc781412 Mon Sep 17 00:00:00 2001 From: sanyalington Date: Thu, 23 Jan 2025 15:05:33 +0000 Subject: [PATCH 6/8] add flag for logits rtz conversion and disable by default --- csrc/rocm/attention.cu | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index f4d7e5c6d3a4..59df85adc521 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -701,12 +701,19 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( __syncthreads(); + constexpr bool LOGITS_RTZ_CONVERSION = false; + // write logits to shared mem for (int token_depth = 0; token_depth < TLOOP; token_depth++) { dout[token_depth] *= inv_sum_scale; - // use rtz conversion for performance, with no visible impact on accuracy - shared_logits[warpid][token_depth][lane16id][rowid] = + if constexpr(LOGITS_RTZ_CONVERSION) { + // use rtz conversion for performance, with no visible impact on accuracy + shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); + } else { + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4(dout[token_depth]); + } } // write out partition max_logits and exp_sum if (threadIdx.x < GQA_RATIO) { @@ -1219,12 +1226,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( __expf(qk_max[h] - global_qk_max); dout[h] *= global_inv_sum_scale; } + constexpr bool LOGITS_RTZ_CONVERSION = false; // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; for (int h = 0; h < QHLOOP; h++) { - // use rtz for faster performance with no perceivable accuracy loss - logits[h] = from_floatx4_rtz(dout[h]); + if constexpr(LOGITS_RTZ_CONVERSION) { + // use rtz for faster performance with no perceivable accuracy loss + logits[h] = from_floatx4_rtz(dout[h]); + } else { + logits[h] = from_floatx4(dout[h]); + } } if (warp_start_token_idx >= context_len) { // warp out of context From 52a0b95ba5002b1bdc505cc6ac8dddcc7490936b Mon Sep 17 00:00:00 2001 From: sanyalington Date: Thu, 23 Jan 2025 15:15:02 +0000 Subject: [PATCH 7/8] lint --- csrc/rocm/attention.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 59df85adc521..e2f86d02dbec 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -706,13 +706,13 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( // write logits to shared mem for (int token_depth = 0; token_depth < TLOOP; token_depth++) { dout[token_depth] *= inv_sum_scale; - if constexpr(LOGITS_RTZ_CONVERSION) { + if constexpr (LOGITS_RTZ_CONVERSION) { // use rtz conversion for performance, with no visible impact on accuracy shared_logits[warpid][token_depth][lane16id][rowid] = - from_floatx4_rtz(dout[token_depth]); + from_floatx4_rtz(dout[token_depth]); } else { shared_logits[warpid][token_depth][lane16id][rowid] = - from_floatx4(dout[token_depth]); + from_floatx4(dout[token_depth]); } } // write out partition max_logits and exp_sum @@ -1231,7 +1231,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; for (int h = 0; h < QHLOOP; h++) { - if constexpr(LOGITS_RTZ_CONVERSION) { + if constexpr 
(LOGITS_RTZ_CONVERSION) { // use rtz for faster performance with no perceivable accuracy loss logits[h] = from_floatx4_rtz(dout[h]); } else { From 49dfc1d27af07033edf3f1443ce74b5d5c44118c Mon Sep 17 00:00:00 2001 From: TJian Date: Wed, 29 Jan 2025 01:22:10 +0800 Subject: [PATCH 8/8] [Bugfix]: Fix paged attention unit tests of https://github.com/ROCm/vllm/pull/372 (#389) * [Bugfix]: fix paged attention tests based on the updated kernels in `csrc/attention/paged_attention_v1.cu`,`csrc/attention/paged_attention_v2.cu` and `csrc/rocm/attention.cu`. * improve code documentation. * lint --------- Co-authored-by: vllmellm --- csrc/rocm/attention.cu | 4 ++- tests/kernels/test_attention.py | 46 +++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index e2f86d02dbec..d152292635fe 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -701,13 +701,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( __syncthreads(); + // disable rtz conversion due to its impact on accuracy. constexpr bool LOGITS_RTZ_CONVERSION = false; // write logits to shared mem for (int token_depth = 0; token_depth < TLOOP; token_depth++) { dout[token_depth] *= inv_sum_scale; if constexpr (LOGITS_RTZ_CONVERSION) { - // use rtz conversion for performance, with no visible impact on accuracy + // use rtz conversion for better performance, with negligible impact on + // accuracy. shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4_rtz(dout[token_depth]); } else { diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index b826559cd8db..10d984351f67 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -7,7 +7,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import get_max_shared_memory_bytes +from vllm.utils import get_max_shared_memory_bytes, is_navi from .allclose_default import get_default_atol, get_default_rtol @@ -33,7 +33,7 @@ # This should be sync with get_supported_head_sizes() in # vllm.attention.ops.paged_attn.PagedAttention -HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] @@ -116,7 +116,8 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( - "version", ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"]) + "version", + ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -181,7 +182,11 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = torch.tensor(0.3, dtype=torch.float) + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32) + + # additional argument for v1/v2 pa kernel + num_threads = 1024 if current_platform.is_rocm() \ + and not is_navi() else 128 # Call the paged attention kernel. 
output = torch.empty_like(query) @@ -203,12 +208,12 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v1, - (output, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v1, + (output, query, key_cache, value_cache, num_kv_heads, scale, + block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): if current_platform.is_rocm(): @@ -247,13 +252,14 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, key_cache, + value_cache, num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, 0, 0, 0, 64, 0, num_threads), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -299,14 +305,14 @@ def test_paged_attention( dtype=dtype, device=device) ops.convert_fp8(dequantized_key_cache, key_cache) - key_cache = k_scale * dequantized_key_cache + key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) ops.convert_fp8(dequantized_value_cache, value_cache) - value_cache = v_scale * dequantized_value_cache + value_cache = dequantized_value_cache ref_output = torch.empty_like(query) ref_single_query_cached_kv_attention( @@ -434,4 +440,4 @@ def test_multi_query_kv_attention( ) atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) \ No newline at end of file
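
The mfma4 and mfma16 kernels in this series each handle one 256-token partition of the context: they compute the scaled QK^T logits, a per-partition qk_max and exp_sum (written out through max_logits_ptr and exp_sums_ptr), and a per-partition softmax*V partial result (written to tmp_out), which the reduction kernel launched by LAUNCH_CUSTOM_REDUCTION then rescales and merges. The NumPy sketch below mirrors only the math of that two-stage, partitioned softmax attention for a single query head; it ignores the MFMA data layout, GQA, fp8, and the kernels' exact normalization bookkeeping, and the function name is illustrative, not a vLLM API.

import numpy as np

def partitioned_attention(q, k, v, scale, partition_size=256):
    """q: [head_size]; k, v: [context_len, head_size]. Returns [head_size]."""
    maxes, sums, parts = [], [], []
    for start in range(0, k.shape[0], partition_size):
        kk = k[start:start + partition_size]
        vv = v[start:start + partition_size]
        logits = scale * (kk @ q)          # QK^T for this partition
        m = logits.max()                   # per-partition qk_max
        e = np.exp(logits - m)             # softmax numerator
        maxes.append(m)
        sums.append(e.sum())               # per-partition exp_sum
        parts.append(e @ vv)               # partial V accumulation
    # Reduction step: rescale each partition to the global max and merge.
    maxes = np.asarray(maxes)
    w = np.exp(maxes - maxes.max())
    total = (np.asarray(sums) * w).sum()
    return sum(wi * pi for wi, pi in zip(w, parts)) / total

rng = np.random.default_rng(0)
head_size, context_len = 128, 1000
q = rng.standard_normal(head_size)
k = rng.standard_normal((context_len, head_size))
v = rng.standard_normal((context_len, head_size))
scale = 1.0 / np.sqrt(head_size)

logits = scale * (k @ q)
p = np.exp(logits - logits.max())
reference = (p @ v) / p.sum()
np.testing.assert_allclose(partitioned_attention(q, k, v, scale), reference,
                           rtol=1e-9)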
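
The inline v_max_f32_dpp / ds_bpermute sequences in the mfma4 kernel are, per the comments left in the code, a faster replacement for the __shfl_xor butterfly loops shown commented out above them. The plain-Python model below reproduces only the semantics of those reference loops on a 64-lane wavefront: after it runs, every lane holds the reduction over all lanes sharing its laneid % 4, which combines qk_max (and, with max swapped for addition, exp_sum) across the warp for each of the four head slots. Lane numbering and names here are illustrative only.

import random

def shfl_xor_reduce(lane_vals, combine, warp_size=64, stop_mask=4):
    """Model of: for (mask = warp_size/2; mask >= stop_mask; mask /= 2)
                     val = combine(val, __shfl_xor(val, mask));"""
    vals = list(lane_vals)
    mask = warp_size // 2
    while mask >= stop_mask:
        vals = [combine(vals[i], vals[i ^ mask]) for i in range(len(vals))]
        mask //= 2
    return vals

random.seed(0)
vals = [random.random() for _ in range(64)]
reduced = shfl_xor_reduce(vals, max)
for lane in range(64):
    group = [vals[j] for j in range(64) if j % 4 == lane % 4]
    assert reduced[lane] == max(group)
# Swapping `max` for `lambda a, b: a + b` gives the exp_sum variant
# (up to floating-point reassociation).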
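
For the fp8 KV-cache path, the kernels multiply the logits by *k_scale_ptr after the QK mfma and the accumulator by *v_scale_ptr after the softmax-V mfma rather than dequantizing every cache element up front, and the updated tests pass k_scale and v_scale as float32 tensors that the kernel reads through those pointers. Because the matmuls are linear in K and V, scaling after the product is equivalent to dequantizing first (the fp8 quantization error itself aside). A quick NumPy check of that identity, with plain float arrays standing in for the fp8 cache values:

import numpy as np

rng = np.random.default_rng(1)
q = rng.standard_normal(64)                # one query row
k_quant = rng.standard_normal((16, 64))    # stand-in for an fp8 K block
k_scale = 0.37                             # per-cache dequantization scale

dequant_first = (k_scale * k_quant) @ q    # dequantize, then QK
scale_after = k_scale * (k_quant @ q)      # QK on raw values, scale the logits
np.testing.assert_allclose(dequant_first, scale_after, rtol=1e-12, atol=1e-12)
# The same linearity argument applies to v_scale applied after the SV mfma.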
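
The launcher fixes PARTITION_SIZE at 256 and dispatches the reduction kernel on npar_loops = DIVIDE_ROUND_UP(max_num_partitions, 64), with specializations for 1 through 8 loops, so the longest context the custom path covers is 8 * 64 * 256 = 131072 tokens. A small sketch of that sizing arithmetic follows; the helper names are made up for this note and are not vLLM functions.

def divide_round_up(a: int, b: int) -> int:
    return (a + b - 1) // b

PARTITION_SIZE = 256   # fixed; both mfma4 and mfma16 kernels support it
WARP_SIZE = 64         # wavefront size assumed by the reduction kernel
MAX_NPAR_LOOPS = 8     # reduction kernel is specialized for 1..8 loops

def reduction_plan(max_context_len: int) -> tuple[int, int]:
    """Return (max_num_partitions, npar_loops) for a given context length."""
    max_num_partitions = divide_round_up(max_context_len, PARTITION_SIZE)
    npar_loops = divide_round_up(max_num_partitions, WARP_SIZE)
    if npar_loops > MAX_NPAR_LOOPS:
        raise ValueError(f"Unsupported npar_loops: {npar_loops}")
    return max_num_partitions, npar_loops

assert MAX_NPAR_LOOPS * WARP_SIZE * PARTITION_SIZE == 128 * 1024
print(reduction_plan(4096))          # (16, 1)
print(reduction_plan(128 * 1024))    # (512, 8)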
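
Patches 6 through 8 put the logits narrowing behind a LOGITS_RTZ_CONVERSION flag and leave it off by default, preferring the regular from_floatx4 path over from_floatx4_rtz because of the accuracy impact noted in the final patch. The sketch below contrasts round-toward-zero with round-to-nearest-even on a bfloat16-style grid (8 significand bits) to illustrate that concern; it models only the rounding step, with no subnormal/NaN/Inf handling, and is not the kernel's conversion code.

import math

def quantize_bf16_like(x: float, mode: str) -> float:
    """Quantize x to an 8-significand-bit (bfloat16-style) grid.
    mode: 'rtz' truncates toward zero, 'rne' rounds to nearest, ties to even."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)            # x = m * 2**e with 0.5 <= |m| < 1
    scaled = m * 256.0              # 8 significand bits -> integer grid
    q = math.trunc(scaled) if mode == "rtz" else round(scaled)
    return math.ldexp(q / 256.0, e)

vals = [0.1 + 0.001 * i for i in range(100)]
for mode in ("rtz", "rne"):
    mean_err = sum(quantize_bf16_like(v, mode) - v for v in vals) / len(vals)
    print(f"{mode}: mean rounding error = {mean_err:+.2e}")
# RTZ always rounds these positive softmax-like weights down, giving a
# systematic negative bias; RNE keeps the mean error close to zero.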