diff --git a/CMakeLists.txt b/CMakeLists.txt index d530646cd78..45bacd044c0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -301,8 +301,52 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # are not supported by Machete yet. cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_ARCHS) + + # + # For the Marlin kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + set(MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) + file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) + + message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH} + OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=$PYTHONPATH + ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} + RESULT_VARIABLE marlin_generation_result + OUTPUT_VARIABLE marlin_generation_result + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log + ) + + if (NOT marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin generation failed." + " Result: \"${marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") + else() + set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH} + CACHE STRING "Last run Marlin generate script hash" FORCE) + message(STATUS "Marlin generation completed successfully.") + endif() + else() + message(STATUS "Marlin generation script has not changed, skipping generation.") + endif() + + file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_ARCHS}") + + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) + set(MARLIN_SRCS - "csrc/quantization/fp8/fp8_marlin.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" @@ -644,7 +688,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + PYTHONPATH=$PYTHONPATH ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} RESULT_VARIABLE moe_marlin_generation_result OUTPUT_VARIABLE moe_marlin_generation_output diff --git a/csrc/moe/marlin_moe_wna16/.gitignore b/csrc/moe/marlin_moe_wna16/.gitignore new file mode 100644 index 00000000000..77088552b85 --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/.gitignore @@ -0,0 +1 @@ +kernel_*.cu \ No newline at end of file diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index d1c0d92f681..902bcd9dfd2 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -25,15 +25,13 @@ "{{thread_k_blocks}}, " "{{'true' if m_block_size_8 else 'false'}}, " "{{stages}}, " - "{{'true' if has_act_order else 'false'}}, " - "{{'true' if has_zp else 'false'}}, " "{{group_blocks}}, " "{{'true' if is_zp_float else 'false'}}>" "( MARLIN_KERNEL_PARAMS );") # int8 with zero point 
case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. -SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"] +SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] @@ -52,21 +50,29 @@ def remove_old_kernels(): def generate_new_kernels(): for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): - has_zp = "B" not in scalar_type all_template_str_list = [] for group_blocks, m_blocks, thread_configs in itertools.product( GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): - has_act_order = group_blocks == 0 - if has_zp and has_act_order: + # act order case only support gptq-int4 and gptq-int8 + if group_blocks == 0 and scalar_type not in [ + "vllm::kU4B8", "vllm::kU8B128" + ]: continue if thread_configs[2] == 256: + # for small batch (m_blocks == 1), we only need (128, 128, 256) + # for large batch (m_blocks > 1), we only need (64, 256, 256) if m_blocks <= 1 and thread_configs[0] != 128: continue if m_blocks > 1 and thread_configs[0] != 64: continue + # we only support channelwise quantization and group_size == 128 + # for fp8 + if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: + continue + k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 threads = thread_configs[2] @@ -82,8 +88,6 @@ def generate_new_kernels(): thread_k_blocks=k_blocks, m_block_size_8=m_blocks == 0.5, stages="pipe_stages", - has_act_order=has_act_order, - has_zp=has_zp, group_blocks=group_blocks, is_zp_float=False, ) diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 3d92660e802..c40c33d01f3 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -18,7 +18,7 @@ const float *__restrict__ topk_weights_ptr, int top_k, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ int prob_n, int prob_k, int *locks, bool use_atomic_add, \ - bool use_fp32_reduce + bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin(MARLIN_KERNEL_PARAMS); diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index 3705216cada..c9e199bcea1 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -25,6 +25,7 @@ #include "quantization/gptq_marlin/marlin.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "quantization/gptq_marlin/dequant.h" #include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -48,11 +49,9 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? 
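Aside on the template-parameter change above and in generate_kernels.py: `has_act_order` and `has_zp` are no longer independent template arguments; the kernel derives them from `w_type` and `group_blocks`, which is what lets the generator drop those two fields from its instantiation template. The compile-time sketch below is illustrative only — it uses a stand-in enum rather than the real vllm::ScalarType ids, and the names are local to this note.

// Illustrative sketch (stand-in enum, not vllm::ScalarType): deriving
// has_zp / has_act_order at compile time instead of passing them as
// separate template parameters.
#include <cstdio>

enum class WTypeId { kU4, kU4B8, kU8, kU8B128, kFE4M3fn };

template <WTypeId w_type, int group_blocks>
struct DerivedTraits {
  // zero points only exist for the unbiased integer types (AWQ-style)
  static constexpr bool has_zp =
      w_type == WTypeId::kU4 || w_type == WTypeId::kU8;
  // group_blocks == 0 is reserved for act-order (g_idx) layouts
  static constexpr bool has_act_order = group_blocks == 0;
};

int main() {
  std::printf("kU4,   group_blocks=8: has_zp=%d has_act_order=%d\n",
              DerivedTraits<WTypeId::kU4, 8>::has_zp,
              DerivedTraits<WTypeId::kU4, 8>::has_act_order);
  std::printf("kU4B8, group_blocks=0: has_zp=%d has_act_order=%d\n",
              DerivedTraits<WTypeId::kU4B8, 0>::has_zp,
              DerivedTraits<WTypeId::kU4B8, 0>::has_act_order);
  return 0;
}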
+ const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -77,8 +76,8 @@ __global__ void Marlin( int prob_k, // reduction dimension k int* locks, // extra global storage for barrier synchronization bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce // whether to use fp32 global reduce -) {} + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) {} } // namespace MARLIN_NAMESPACE_NAME @@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType::FragA& frag_a, } } -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline typename ScalarType::FragB dequant( - int q, typename ScalarType::FragB& frag_b); - -// -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 -// -template <> -__device__ inline typename ScalarType::FragB dequant( - int q, typename ScalarType::FragB& frag_b) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q, - typename ScalarType::FragB& frag_b) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// -// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -// -template <> -__device__ inline typename ScalarType::FragB dequant( - int q, typename ScalarType::FragB& frag_b) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q, - typename ScalarType::FragB& frag_b) { - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. template @@ -429,11 +290,9 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
> __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -458,8 +317,8 @@ __global__ void Marlin( int prob_k, // reduction dimension k int* locks, // extra global storage for barrier synchronization bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce // whether to use fp32 global reduce -) { + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM @@ -481,6 +340,8 @@ __global__ void Marlin( extern __shared__ int4 sh[]; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool has_act_order = group_blocks == 0; constexpr int pack_factor = 32 / w_type.size_bits(); static_assert(thread_m_blocks == 1 || !m_block_size_8); @@ -534,13 +395,20 @@ __global__ void Marlin( int64_t B_expert_off = 0; int4* sh_block_sorted_ids_int4 = sh; + int4* sh_rd_block_sorted_ids_int4 = + sh_block_sorted_ids_int4 + moe_block_size / 4; + int4* sh_block_topk_weights_int4 = + sh_rd_block_sorted_ids_int4 + moe_block_size / 4; + // sh_block_topk_weights_int4 only need (moe_block_size / 4); + // but we pad to align to 256 bytes + int4* sh_new = + sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size; int32_t* sh_block_sorted_ids = reinterpret_cast(sh_block_sorted_ids_int4); - int4* sh_block_topk_weights_int4 = - sh_block_sorted_ids_int4 + moe_block_size / 4; + int32_t* sh_rd_block_sorted_ids = + reinterpret_cast(sh_rd_block_sorted_ids_int4); scalar_t2* sh_block_topk_weights = reinterpret_cast(sh_block_topk_weights_int4); - int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4; int32_t block_num_valid_tokens = 0; int32_t locks_off = 0; @@ -584,6 +452,11 @@ __global__ void Marlin( sh_block_sorted_ids_int4[tid4] = reinterpret_cast( sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; + #pragma unroll + for (int i = 0; i < 4; i++) + sh_rd_block_sorted_ids[tid4 * 4 + i] = + sh_block_sorted_ids[tid4 * 4 + i] / top_k; + if (mul_topk_weights) { #pragma unroll for (int i = 0; i < 4; i++) { @@ -743,6 +616,7 @@ __global__ void Marlin( constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; // constexpr int act_s_row_stride = 1; // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; int tb_n_warps = thread_n_blocks / 4; @@ -758,9 +632,9 @@ __global__ void Marlin( int zp_gl_rd_delta = zp_gl_stride; // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; + int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o; + int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; + // Shared write index of current thread. 
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); @@ -774,8 +648,8 @@ __global__ void Marlin( (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; // For act_order constexpr int k_iter_size = tb_k / b_sh_wr_iters; @@ -794,7 +668,7 @@ __global__ void Marlin( s_sh_stride * slice_col + threadIdx.x; } } - int s_sh_wr = threadIdx.x; + auto s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; // Zero-points @@ -807,7 +681,7 @@ __global__ void Marlin( zp_sh_stride * slice_col + threadIdx.x; } } - int zp_sh_wr = threadIdx.x; + auto zp_sh_wr = threadIdx.x; bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; // We use a different scale layout for grouped and column-wise quantization as @@ -851,7 +725,7 @@ __global__ void Marlin( // each warp must also write a consecutive memory segment? auto transform_a = [&](int i) { int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); }; // Since the computation of this remapping is non-trivial and, due to our main // loop unrolls, all shared memory accesses are static, we simply precompute @@ -879,12 +753,28 @@ __global__ void Marlin( B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; // Shared memory storage for global fetch pipelines. - int4* sh_a = sh_new; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh_new; + int4* sh_red = sh_new; + int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) + : (stages * s_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage); - int4* sh_red = sh_b; + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= + stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + constexpr int shm_size_used = + moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // all remaining shared memory is used to cache A (input) + // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` + int sh_a_max_row = + ((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2); // Register storage for double buffer of shared memory reads. 
FragA frag_a[2][thread_m_blocks]; @@ -905,15 +795,14 @@ __global__ void Marlin( int sh_first_group_id = -1; int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; + if (sh_num_groups < act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; } if (sh_first_group_id + sh_num_groups > num_groups) { @@ -940,27 +829,31 @@ __global__ void Marlin( } } }; + // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. - int a_remaining_load_count_in_slice = stages; - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + bool should_load_a = true; + int max_num_stage_groups = + ((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages; + max_num_stage_groups = max(max_num_stage_groups, 1); + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true, + int pipe_a = 0) { if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || - a_remaining_load_count_in_slice > 0) { - a_remaining_load_count_in_slice--; + if (should_load_a) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; + int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; int64_t sorted_row = 0; if (!m_block_size_8 || row < 8) - sorted_row = sh_block_sorted_ids[row] / top_k; - int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + sorted_row = sh_rd_block_sorted_ids[row]; + int64_t true_idx = + sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], row < block_num_valid_tokens); } } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { @@ -1063,8 +956,8 @@ __global__ void Marlin( // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. 
- auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; + auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm( @@ -1109,12 +1002,17 @@ __global__ void Marlin( } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } } else { - int warp_id = threadIdx.x / 32; + auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; @@ -1152,7 +1050,7 @@ __global__ void Marlin( // Determine "position" inside the thread-block (based on warp and // thread-id) - int warp_id = threadIdx.x / 32; + auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N @@ -1161,7 +1059,7 @@ __global__ void Marlin( cur_k += warp_row * 16; - int th_id = threadIdx.x % 32; + auto th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix int s_col_shift = @@ -1222,15 +1120,18 @@ __global__ void Marlin( } } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } } } else { - int warp_id = threadIdx.x / 32; + auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; @@ -1251,6 +1152,7 @@ __global__ void Marlin( sh_zp_stage += cur_group_id * zp_sh_stride; + #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; @@ -1263,12 +1165,16 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd]; + } } else { - int warp_id = threadIdx.x / 32; + auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; @@ -1292,6 +1198,25 @@ __global__ void Marlin( } }; + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { + if constexpr (has_zp && is_zp_float || !has_zp) { + dequant(q, frag_b_ptr); + } else { + static_assert(has_zp && !is_zp_float); 
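  // -------------------------------------------------------------------------
  // Aside (standalone host-side sketch, not part of this kernel): why reusing
  // the unbiased kU4/kU8 dequant for the zero-point path stays correct. If the
  // weight and the zero point go through the same mapping dq(x) = x - C for
  // any fixed C, the offset cancels in (dq(w) - dq(z)) * s, so the result
  // matches the plain (w - z) * s the quantization scheme expects. `dq` below
  // is a stand-in for the real dequant routines, written only for this note.
  #include <cstdio>

  static float dq(int x, int c) { return static_cast<float>(x - c); }

  int main() {
    const int w = 11, z = 7;   // a 4-bit weight and its zero point
    const float s = 0.125f;    // group scale
    for (int c : {0, 8}) {     // kU4 applies no offset; kU4B8 fuses a -8 offset
      std::printf("offset %d: (dq(w) - dq(z)) * s = %g\n", c,
                  (dq(w, c) - dq(z, c)) * s);  // both iterations print 0.5
    }
    return 0;
  }
  // -------------------------------------------------------------------------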
+ static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); + // If (has_zp && !is_zp_float), + // we use not-zp version `dequant` function + // to improve numerical accuracy. + // Since both weight and zero point are dequanted using this logic, + // the final dequanted weight would be correct. + if constexpr (w_type_id == vllm::kU4.id()) { + dequant(q, frag_b_ptr); + } else if constexpr (w_type_id == vllm::kU8.id()) { + dequant(q, frag_b_ptr); + } + } + }; + // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; auto matmul = [&](int k) { @@ -1315,15 +1240,17 @@ __global__ void Marlin( zp_quant_1 = frag_qzp[k2][1]; } - dequant(zp_quant_0, frag_zp_0); - dequant(zp_quant_1, frag_zp_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); } } + if constexpr (has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = + reinterpret_cast(&frag_zpf[k2])[0]; + } + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -1342,8 +1269,8 @@ __global__ void Marlin( b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } - dequant(b_quant_0, frag_b0); - dequant(b_quant_1, frag_b1); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -1351,8 +1278,7 @@ __global__ void Marlin( scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k][2][j], act_frag_s[k2][3][j], 1); - + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( @@ -1361,18 +1287,12 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && !is_zp_float && group_blocks != -1) { + } else if constexpr (has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y); - } else if constexpr (has_zp && is_zp_float && group_blocks != -1) { - if (is_new_zp) - frag_zpf[k2][j] = __hmul2( - frag_zpf[k2][j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x); - scale_and_sub(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); } else if constexpr (group_blocks != -1) { scale(frag_b0, frag_s[k2][j], 0); scale(frag_b1, frag_s[k2][j], 1); @@ -1397,7 +1317,7 @@ __global__ void Marlin( auto thread_block_reduce = [&]() { constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; + auto red_idx = threadIdx.x / b_sh_stride_threads; constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + @@ -1731,7 +1651,7 @@ 
__global__ void Marlin( fetch_col_scale_to_shared(); } } - fetch_to_shared(i, i, i < slice_iters); + fetch_to_shared(i, i, i < slice_iters, i); } zero_accums(); @@ -1740,8 +1660,10 @@ __global__ void Marlin( fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); + a_gl_rd_col += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } }; if (slice_iters) { start_pipes(); @@ -1754,43 +1676,56 @@ __global__ void Marlin( // have even length meaning that the next iteration will always start at // index 0. + for (int stage_group_id = 0; stage_group_id < max_num_stage_groups; + stage_group_id++) { #pragma unroll - for (int pipe = 0; pipe < stages;) { + for (int pipe = 0; pipe < stages;) { #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); + for (int k = 0; k < b_sh_wr_iters; k++) { + int idx = + (pipe >= stages && stage_group_id == max_num_stage_groups - 1) + ? (pipe - stages) + : (pipe + stage_group_id * stages); + fetch_to_registers(k + 1, pipe % stages, idx); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1) + ? (pipe - 1) + : (pipe + (stage_group_id + 1) * stages - 1); + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages, idx); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; } - } - a_remaining_load_count_in_slice = 0; - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; + a_gl_rd_col += a_gl_rd_delta_o * stages; - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); + if (slice_iters == 0) { + break; } } @@ -1877,15 +1812,30 @@ __global__ void Marlin( if (last || use_atomic_add) // only the last block in a slice actually writes the result write_result(); - if (slice_row) a_remaining_load_count_in_slice = stages; + int old_slice_row = slice_row; slice_row = 0; slice_col_par++; slice_col++; is_first_matmul_in_slice = true; init_slice(); + + // Should we load A matrix in next slice? 
+ // `slice_col == 0`: when move to a new moe block + // `old_slice_row > 0`: + // when the last slice is not starting from k_index == 0 + // (only happen when it is the first slice of a threadblock) + // `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`: + // when the required shared memory size is larger than + // the remaining shared memory + if (slice_col == 0 || old_slice_row || + prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) { + should_load_a = true; + } else { + should_load_a = false; + } + if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o); #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; @@ -1900,12 +1850,10 @@ __global__ void Marlin( slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; - } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; } - start_pipes(); } } diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index a16e955a325..00b4e934cc3 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -116,7 +116,7 @@ __global__ void permute_cols_kernel( int base_k = 0; for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; + auto cur_k = base_k + threadIdx.x; int src_pos = perm_int_ptr[cur_k]; out_half[cur_k] = a_row_half[src_pos]; @@ -126,7 +126,7 @@ __global__ void permute_cols_kernel( if (rest) { if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; + auto cur_k = base_k + threadIdx.x; int src_pos = perm_int_ptr[cur_k]; out_half[cur_k] = a_row_half[src_pos]; @@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; - } else { int tb_scales = tb_groups * tb_n * 2; @@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, } } -int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int has_zp, int is_zp_float) { +int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, + int thread_m_blocks, int prob_m, int prob_n, + int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full, int has_zp, + int is_zp_float) { int pack_factor = 32 / num_bits; // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; - int tb_m = thread_m_blocks * 16; + int tb_m = thread_m_blocks * (m_block_size_8 ? 
8 : 16); - // shm size for block_sorted_ids/block_topk_weights + // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) - int sh_block_meta_size = tb_m * 4 * 2; + int sh_block_meta_size = tb_m * 4; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); @@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, sh_zp_size = sh_s_size / 2; } - int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size + - sh_g_idx_size + sh_block_meta_size; + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + + sh_zp_size + sh_g_idx_size + sh_block_meta_size; return total_size; } -bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int has_zp, int is_zp_float, int max_shared_mem) { +bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, + int thread_m_blocks, int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, bool has_act_order, + bool is_k_full, int has_zp, int is_zp_float, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -266,143 +268,113 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, // Check that pipeline fits into cache int cache_size = get_kernel_cache_size( - th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, has_zp, is_zp_float); + th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); return cache_size <= max_shared_mem; } - #define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - kernel = Marlin; \ + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin; \ } - #define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \ - false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \ - NUM_THREADS, false) 
\ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ - NUM_THREADS, false) - - #define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ - NUM_THREADS, false) - - #define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \ - false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ - false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \ - false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ - NUM_THREADS, false) - - #define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 4, 
N_BLOCKS, K_BLOCKS, false, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ - NUM_THREADS, false) + // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) + // this is the most common cases + // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) + // FZP: cases for float-zero-point (is_zp_float = true) + // ACT: cases for act order case (group_blocks == 0) + #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) + + #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) // We currently 
have 4-bit models only with group_blocks == 4 - #define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ - true) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, true) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, true) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, true) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, true) + #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) + + // We currently have 4-bit models only with group_blocks == 4 + #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) template MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, @@ -415,23 +387,15 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, auto kernel = MarlinDefault; if (false) { } - GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256) - GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128) - GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256) - GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128) + COMMON_GET_IF(vllm::kU4) + COMMON_GET_IF(vllm::kU4B8) + COMMON_GET_IF(vllm::kU8B128) - GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256) - GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128) + BIGGROUP_GET_IF(vllm::kFE4M3fn) - GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256) - GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128) - - AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256) - AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128) - - AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256) - AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128) + ACT_GET_IF(vllm::kU4B8) + ACT_GET_IF(vllm::kU8B128) return kernel; } @@ -457,19 +421,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, for (int i = 0; i < thread_configs_size; i++) { thread_config_t th_config = thread_configs[i]; - if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, has_zp, - is_zp_float, max_shared_mem)) { + if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem)) { continue; } int cache_size = get_kernel_cache_size( - th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, - 
group_size, has_act_order, is_k_full, has_zp, is_zp_float); + th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); int group_blocks = 0; if (!has_act_order) { - group_blocks = group_size == -1 ? -1 : group_size / 16; + group_blocks = group_size == -1 ? -1 : (group_size / 16); } auto kernel = get_marlin_kernel( @@ -515,14 +479,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, bool m_block_size_8 = moe_block_size == 8; if (has_zp) { - TORCH_CHECK( - q_type == vllm::kU4 || q_type == vllm::kU8, - "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + TORCH_CHECK(q_type == vllm::kU4, + "q_type must be u4 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128, - "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", - q_type.str()); + TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn, + "q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " + "False. Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -631,18 +595,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; - TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n, - prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem), - "Invalid thread config: thread_m_blocks = ", thread_m_blocks, - ", thread_k = ", thread_tfg.thread_k, - ", thread_n = ", thread_tfg.thread_n, - ", num_threads = ", thread_tfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, - ", max_shared_mem = ", max_shared_mem); + TORCH_CHECK( + is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", + prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", has_act_order = ", has_act_order, + ", is_k_full = ", is_k_full, ", has_zp = ", has_zp, + ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); auto kernel = get_marlin_kernel( q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, @@ -666,7 +630,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, - prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce); + prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); // clang-format on } @@ -841,10 +805,11 @@ torch::Tensor moe_wna16_marlin_gemm( b_q_type == vllm::kU4, "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); } else { - TORCH_CHECK( - b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", - b_q_type.str()); + TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || + b_q_type == vllm::kFE4M3fn, + "b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " + "False. Got = ", + b_q_type.str()); } if (has_zp && is_zp_float) { diff --git a/csrc/quantization/gptq_marlin/.gitignore b/csrc/quantization/gptq_marlin/.gitignore new file mode 100644 index 00000000000..77088552b85 --- /dev/null +++ b/csrc/quantization/gptq_marlin/.gitignore @@ -0,0 +1 @@ +kernel_*.cu \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h new file mode 100644 index 00000000000..3c0d77ac345 --- /dev/null +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -0,0 +1,291 @@ + +#include "marlin_dtypes.cuh" + +namespace MARLIN_NAMESPACE_NAME { + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. 
+ const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, 
fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << 
(BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py new file mode 100644 index 00000000000..8b4b951f3d8 --- /dev/null +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import itertools +import os +import subprocess + +import jinja2 + +FILE_HEAD = """ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { +""".strip() + +TEMPLATE = ("template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );") + +# int8 with zero point case (vllm::kU8) is also supported, +# we don't add it to reduce wheel size. +SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), + (128, 64, 128)] + +THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] +# group_blocks: +# = 0 : act order case +# = -1 : channelwise quantization +# > 0 : group_size=16*group_blocks +GROUP_BLOCKS = [0, -1, 2, 4, 8] +DTYPES = ["fp16", "bf16"] + + +def remove_old_kernels(): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + subprocess.call(["rm", "-f", filename]) + + +def generate_new_kernels(): + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + all_template_str_list = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): + + # act order case only support gptq-int4 and gptq-int8 + if group_blocks == 0 and scalar_type not in [ + "vllm::kU4B8", "vllm::kU8B128" + ]: + continue + if thread_configs[2] == 256: + # for small batch (m_blocks == 1), we only need (128, 128, 256) + # for large batch (m_blocks > 1), we only need (64, 256, 256) + if m_blocks <= 1 and thread_configs[0] != 128: + continue + if m_blocks > 1 and thread_configs[0] != 64: + continue + + # we only support channelwise quantization and group_size == 128 + # for fp8 + if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: + continue + + k_blocks = thread_configs[0] // 16 + n_blocks = thread_configs[1] // 16 + threads = thread_configs[2] + + c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + + is_zp_float_list = [False] + if dtype == "fp16" and scalar_type == "vllm::kU4" and \ + group_blocks == 4: + # HQQ (is_zp_float = true) only supports + # 4bit quantization and fp16 + is_zp_float_list.append(True) + + for is_zp_float in is_zp_float_list: + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 1), + thread_n_blocks=n_blocks, + 
thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages="pipe_stages", + group_blocks=group_blocks, + is_zp_float=is_zp_float, + ) + + all_template_str_list.append(template_str) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) + + +if __name__ == "__main__": + remove_old_kernels() + generate_new_kernels() diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index a974c881eb8..02527a48166 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -19,10 +19,11 @@ * Adapted from https://github.com/IST-DASLab/marlin */ -#include "marlin.cuh" -#include "marlin_dtypes.cuh" -#include "core/scalar_type.hpp" +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif +#include "kernel.h" #include "core/registration.h" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -30,13 +31,12 @@ std::is_same::value, \ "only float16 and bfloat16 is supported"); -template -inline std::string str(T x) { - return std::to_string(x); -} - namespace marlin { +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -44,46 +44,17 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int4* __restrict__ out_int4_ptr, int size_m, int size_k, int lda, int block_rows) {} -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? 
- > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks, // extra global storage for barrier synchronization - bool use_fp32_reduce // whether to use fp32 global reduce -) {} - } // namespace marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, - vllm::ScalarTypeId const b_q_type_id, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, bool is_zp_float) { +torch::Tensor gptq_marlin_gemm( + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -91,369 +62,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, #else -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. 
-template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline typename ScalarType::FragB dequant(int q); - -// -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 -// -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC300C300; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// -// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -// -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - 
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388608.f; - fp32_intermediates[1] -= 8388608.f; - fp32_intermediates[2] -= 8388608.f; - fp32_intermediates[3] -= 8388608.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; - - scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. 
- asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -510,1304 +118,19 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, } } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int lda, // A.stride(0), equal to prob_k is A is contiguous - int* locks, // extra global storage for barrier synchronization - bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce // whether to use fp32 global reduce -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. 
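The striped assignment described in the comment above can be reproduced with plain integer arithmetic. The sketch below is standalone, illustrative C++ (not taken from this patch); it uses a local div_ceil and the same iters / slice_row / slice_col arithmetic as the kernel code that follows, and reprints the 3x3-tile, 5-SM layout from the comment (each block then walks iters consecutive k-tiles, wrapping into the next column slice):

#include <cstdio>

static int div_ceil(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Toy problem: 3 k-tiles x 3 column slices of work, 5 threadblocks (SMs).
  int k_tiles = 3, n_tiles = 3, parallel = 1, num_blocks = 5;
  int iters = div_ceil(k_tiles * n_tiles * parallel, num_blocks);  // = 2
  for (int block = 0; block < num_blocks; ++block) {
    int slice_row = (iters * block) % k_tiles;  // first k-tile of the stripe
    int slice_col = (iters * block) / k_tiles;  // column slice it starts in
    printf("block %d starts at k-tile %d of column slice %d\n",
           block, slice_row, slice_col);
  }
  // Blocks 0..4 start at (k-tile, column): (0,0) (2,0) (1,1) (0,2) (2,2),
  // i.e. the 0/0/1, 1/2/2, 3/3/4 column pattern drawn in the comment above.
}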
- using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; - - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - div_ceil(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - int par_id = 0; - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. 
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * lda / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - par_id++; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = lda / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = is_zp_float ? 
prob_n / 8 : (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = is_zp_float - ? 16 * thread_n_blocks / 8 - : ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; - auto b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - auto zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - if constexpr (is_zp_float) { - if constexpr (group_blocks != -1) { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } - } else { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. 
- bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - int4* sh_red = sh_s + (stages * s_sh_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ - - // Zero accumulators. 
- auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - // Only fetch scales if this tile starts a new group - if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) { - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are 
winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - auto warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - auto th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - 
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); - - if constexpr (has_zp && !is_zp_float) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - - else if constexpr (has_zp && is_zp_float) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; - } else { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - int cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - reinterpret_cast(&frag_zpf[k % 2])[0] = - sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; - } - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. 
- auto matmul = [&](int k) { - if constexpr (has_zp && !is_zp_float) { - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) { - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = frag_qzp[k % 2][1]; - } - - frag_zp_0 = dequant(zp_quant_0); - frag_zp_1 = dequant(zp_quant_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - int b_quant_0, b_quant_1; - - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - frag_b0 = dequant(b_quant_0); - frag_b1 = dequant(b_quant_1); - - // Apply zero-point to frag_b0 - if constexpr (has_zp && !is_zp_float) { - sub_zp(frag_b0, frag_zp[j], 0); - } - - else if constexpr (has_zp && is_zp_float && group_blocks != -1) { - sub_zp(frag_b0, frag_zpf[k % 2][j], 0); - } - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp && !is_zp_float) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - else if constexpr (has_zp && is_zp_float && group_blocks != -1) { - sub_zp(frag_b1, frag_zpf[k % 2][j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
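The reduction pattern described in the comment above halves the number of live partial sums every round, so red_off partials need only log2(red_off) rounds of shared-memory traffic. A minimal, illustrative host-side C++ sketch of that halving (not taken from this patch):

#include <cstdio>

int main() {
  float partial[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // one partial sum per "warp"
  for (int i = 8 / 2; i > 0; i /= 2) {          // 8 -> 4 -> 2 -> 1 live slots
    for (int j = 0; j < i; ++j) partial[j] += partial[j + i];
  }
  // When only two partials remain, a single write and a single read finish
  // the reduction, as the comment above points out.
  printf("%.0f\n", partial[0]);  // 36, the full sum
}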
- - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = reinterpret_cast( - &sh_red[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh_red[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce_fp16 = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh_red[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Globally reduce over threadblocks that compute the same column block. - // We use a tmp C buffer to reduce in full fp32 precision. 
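The fp32 path that follows amounts to a two-phase protocol: every block except the last accumulates its partial tile into a temporary fp32 buffer, and only the last block folds the buffer into its registers before the final write-out. A minimal host-side C++ sketch of that first/last protocol (illustrative only, not taken from this patch; the lock-based ordering used elsewhere in this file is omitted):

#include <cstdio>
#include <vector>

// One output tile reduced across several cooperating blocks.
void reduce_block(std::vector<float>& c_tmp, const std::vector<float>& frag_c,
                  bool first, bool last, std::vector<float>& c_out) {
  for (size_t i = 0; i < frag_c.size(); ++i) {
    float acc = first ? frag_c[i] : c_tmp[i] + frag_c[i];
    if (last) c_out[i] = acc;  // final result (fp16 conversion omitted)
    else      c_tmp[i] = acc;  // keep accumulating in fp32
  }
}

int main() {
  std::vector<float> c_tmp(4, 0), c_out(4, 0);
  std::vector<float> part = {1, 1, 1, 1};
  for (int b = 0; b < 3; ++b)  // 3 blocks contribute to the same tile
    reduce_block(c_tmp, part, /*first=*/b == 0, /*last=*/b == 2, c_out);
  printf("%.0f\n", c_out[0]);  // 3
}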
- auto global_reduce_fp32 = [&](bool first = false, bool last = false) { - constexpr int tb_m = thread_m_blocks * 16; - constexpr int tb_n = thread_n_blocks * 16; - - constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - - constexpr int active_threads = 32 * thread_n_blocks / 4; - bool is_th_active = threadIdx.x < active_threads; - - int par_offset = c_size * n_tiles * par_id; - int slice_offset = c_size * slice_col; - - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; - constexpr int th_size = num_floats * sizeof(float) / 16; - - int c_cur_offset = par_offset + slice_offset; - - if (!is_th_active) { - return; - } - - if (!first) { - float* frag_c_ptr = reinterpret_cast(&frag_c); - #pragma unroll - for (int k = 0; k < th_size; k++) { - sh_red[threadIdx.x] = - C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; - - float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); - #pragma unroll - for (int f = 0; f < 4; f++) { - frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; - } - } - } - - if (!last) { - int4* frag_c_ptr = reinterpret_cast(&frag_c); - #pragma unroll - for (int k = 0; k < th_size; k++) { - C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((scalar_t2*)sh_red)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - if (use_atomic_add && slice_count > 1) { - scalar_t2* C_half2 = 
reinterpret_cast(&C[c_gl_wr]); - scalar_t2* sh_red_half2 = - reinterpret_cast(&sh_red[c_sh_rd]); - #pragma unroll - for (int a = 0; a < 4; a++) { - atomicAdd(&C_half2[a], sh_red_half2[a]); - } - } else { - C[c_gl_wr] = sh_red[c_sh_rd]; - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && !is_zp_float && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
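// ---------------------------------------------------------------------------
// Editor's note (illustrative aside, not part of this patch): the main loop
// just above is a classic multi-stage software pipeline. The host-only toy
// below mirrors its shape with print-only stand-ins for the kernel's
// fetch/matmul lambdas, to make the interleaving easier to follow: the
// global -> shared prefetch runs `stages` deep, shared -> register loads run
// one sub-tile ahead, and the pipeline index advances only once per sub-tile
// pass. The constants are arbitrary example values.

#include <cstdio>

int main() {
  const int stages = 4, b_sh_wr_iters = 2;
  int slice_iters = 7;

  auto fetch_to_shared = [](int pipe, int iter) {
    printf("  gmem->smem  pipe=%d issued at iter=%d\n", pipe, iter);
  };
  auto fetch_to_registers = [](int k, int pipe) {
    printf("  smem->reg   k=%d pipe=%d\n", k, pipe);
  };
  auto matmul = [](int k) { printf("  mma         k=%d\n", k); };

  // Prologue: fill `stages - 1` shared-memory buffers before computing.
  for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i);
  fetch_to_registers(0, 0);

  while (slice_iters) {
    for (int pipe = 0; pipe < stages;) {
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        if (k == b_sh_wr_iters - 2) {
          // Issue the global fetch that will be consumed `stages - 1`
          // iterations from now, then advance the pipeline index.
          fetch_to_shared((pipe + stages - 1) % stages, pipe);
          pipe++;
        }
        matmul(k);
      }
      if (--slice_iters == 0) break;
    }
  }
}
// ---------------------------------------------------------------------------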
- if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - if (last || use_atomic_add) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last || use_atomic_add) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float( - reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( - reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float( - reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( - reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1 && !use_atomic_add) { - // only globally reduce if there is more than one block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - if (use_fp32_reduce) { - global_reduce_fp32(slice_idx == 0, last); - } else { - global_reduce_fp16(slice_idx == 0, last); - } - barrier_release(&locks[slice_col], last); - } - if (last || use_atomic_add) - // only the last block in a slice actuallywrites the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - - start_pipes(); - } - } - } -} - - #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \ - IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - 
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - if constexpr (!IS_ZP_FLOAT || std::is_same::value) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ - num_groups, prob_m, prob_n, prob_k, lda, locks, \ - part_use_atomic_add, use_fp32_reduce); \ - } \ - } - typedef struct { int thread_k; int thread_n; int num_threads; } thread_config_t; -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads {128, 128, 256}, {64, 128, 128}, - {128, 64, 128}, -}; + {128, 64, 128}}; thread_config_t large_batch_thread_configs[] = { // Ordered by priority @@ -1815,9 +138,12 @@ thread_config_t large_batch_thread_configs[] = { // thread_k, thread_n, num_threads {64, 256, 256}, {64, 128, 128}, - {128, 64, 128}, + {128, 64, 128}}; -}; +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, @@ -1842,7 +168,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; - } else { int tb_scales = tb_groups * tb_n * 2; @@ -1850,49 +175,43 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, } } -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { +int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int has_zp, int is_zp_float) { int pack_factor = 32 / num_bits; // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } + int tb_m = thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? 
pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; } - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - float reduce_size = max(th_config.num_threads * 32 * 4, - (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2); + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + + sh_zp_size + sh_g_idx_size; - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size); + return total_size; } -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, +bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { + int has_zp, int is_zp_float, int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1914,163 +233,204 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, return false; } - // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float); + return cache_size <= max_shared_mem; } -int determine_reduce_max_m(int prob_m, int max_par) { - constexpr int tile_m_size = 16; + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin; \ + } + + // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) + // this is the most common cases + // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) + // FZP: cases for float-zero-point (is_zp_float = true) + // ACT: cases for act order case (group_blocks == 0) + #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + 
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + + #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + + // We currently have 4-bit models only with group_blocks == 4 + #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + + // We currently have 4-bit models only with group_blocks == 4 + #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, 
NUM_THREADS, false) + + #define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) - if (prob_m <= tile_m_size) { - return tile_m_size; +template +MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, + int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool m_block_size_8, + bool has_act_order, bool has_zp, + int group_blocks, int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } - } else if (prob_m <= tile_m_size * 2) { - return tile_m_size * 2; + COMMON_GET_IF(vllm::kU4) + COMMON_GET_IF(vllm::kU4B8) + COMMON_GET_IF(vllm::kU8B128) - } else if (prob_m <= tile_m_size * 3) { - return tile_m_size * 3; + BIGGROUP_GET_IF(vllm::kFE4M3fn) - } else if (prob_m <= tile_m_size * 4) { - return tile_m_size * 4; + ACT_GET_IF(vllm::kU4B8) + ACT_GET_IF(vllm::kU8B128) - } else { - int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par); - return tile_m_size * 4 * cur_par; + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(vllm::kU4) } + + return kernel; } -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - bool has_act_order, bool is_k_full, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } +template +exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, + bool m_block_size_8, int num_bits, + int group_size, bool has_act_order, + bool is_k_full, bool has_zp, + bool is_zp_float, int max_shared_mem, + int sms) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 + ? large_batch_thread_configs + : small_batch_thread_configs; + int thread_configs_size = + thread_m_blocks > 1 + ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, + is_zp_float, max_shared_mem)) { + continue; } - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return exec_config_t{0, {-1, -1, -1}}; -} + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : group_size / 16; + } - #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ - false) - - #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false) + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, th_config.thread_n / 16, + th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, + group_blocks, th_config.num_threads, is_zp_float); - // We currently have 4-bit models only with group_blocks == 4 - #define 
HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true) + if (kernel == MarlinDefault) continue; + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; +} template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, @@ -2078,78 +438,24 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int prob_n, int prob_k, int lda, void* workspace, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, - int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { + int dev, cudaStream_t stream, int thread_k_init, + int thread_n_init, int sms, bool use_atomic_add, + bool use_fp32_reduce, bool is_zp_float) { if (has_zp) { TORCH_CHECK( q_type == vllm::kU4 || q_type == vllm::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128, - "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", - q_type.str()); + TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn, + "q_type must be uint4b8, uint8b128 or float8_e4m3fn when " + "has_zp = False. Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); - // TODO: remove alias when we start supporting other 8bit types - int num_bits = q_type.size_bits(); - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); - } - - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 
16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - int group_blocks = 0; if (has_act_order) { if (is_k_full) { @@ -2161,7 +467,6 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, TORCH_CHECK(group_size == 0); group_blocks = 0; } - } else { if (group_size == -1) { group_blocks = -1; @@ -2172,6 +477,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, } } + int num_bits = q_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; @@ -2186,106 +492,138 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, if (has_act_order) { // Permute A columns - int block_rows = div_ceil(prob_m, blocks); - permute_cols_kernel<<>>( + int block_rows = div_ceil(prob_m, sms); + // avoid ">>>" being formatted to "> > >" + // clang-format off + permute_cols_kernel<<>>( A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + // clang-format on A_ptr = a_tmp_ptr; lda = prob_k; - } - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by having - // a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; } - // Main loop - for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) max_par = 16 * 8; + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) par_count = max_par; + int prob_m_split = + par_count > 0 ? 
(par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem, sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } } - // atomic add reduce have better performance only when m * n is small - bool part_use_atomic_add = - use_atomic_add && div_ceil(prob_m, 64) * prob_n <= 2048; + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; - if (false) { - } - GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256) - GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256) - GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128) - GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128) - GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256) - GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256) - GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128) - GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128) - - AWQ_CALL_IF(vllm::kU4, 16, 4, 256) - AWQ_CALL_IF(vllm::kU4, 8, 8, 256) - AWQ_CALL_IF(vllm::kU4, 8, 4, 128) - AWQ_CALL_IF(vllm::kU4, 4, 8, 128) - AWQ_CALL_IF(vllm::kU8, 16, 4, 256) - AWQ_CALL_IF(vllm::kU8, 8, 8, 256) - AWQ_CALL_IF(vllm::kU8, 8, 4, 128) - AWQ_CALL_IF(vllm::kU8, 4, 8, 128) - - HQQ_CALL_IF(vllm::kU4, 16, 4, 256) - HQQ_CALL_IF(vllm::kU4, 8, 8, 256) - HQQ_CALL_IF(vllm::kU4, 8, 4, 128) - HQQ_CALL_IF(vllm::kU4, 4, 8, 128) - else { + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + TORCH_CHECK( + is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n, + prob_k, num_bits, group_size, has_act_order, is_k_full, + has_zp, is_zp_float, max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", prob_m_split = ", prob_m_split, ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", max_shared_mem_new = ", max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, + m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]", ", has_act_order = ", has_act_order, ", num_groups = ", num_groups, ", group_size = ", group_size, + ", prob_m_split = ", prob_m_split, ", thread_m_blocks = ", thread_m_blocks, ", 
thread_n_blocks = ", thread_n_blocks, ", thread_k_blocks = ", thread_k_blocks, - ", num_bits = ", num_bits); + ", num_threads = ", num_threads, ", num_bits = ", num_bits); } - A_ptr += 16 * thread_m_blocks * (lda / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem_new); + + bool part_use_atomic_add = + use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, + prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, + use_fp32_reduce, max_shared_mem_new); + // clang-format on + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; } } } // namespace marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, - vllm::ScalarTypeId const& b_q_type_id, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, bool use_atomic_add, - bool use_fp32_reduce, bool is_zp_float) { +torch::Tensor gptq_marlin_gemm( + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - if (has_zp) { - TORCH_CHECK( - b_q_type == vllm::kU4 || b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); - } else { - TORCH_CHECK( - b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", - b_q_type.str()); - } - - if (has_zp && is_zp_float) { - TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, - "Computation type must be float16 (half) when using float zero " - "points."); - } - int pack_factor = 32 / b_q_type.size_bits(); // Verify A @@ -2295,15 +633,19 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", marlin::tile_size); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK( + size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", marlin::tile_size); - int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; + ", size_k = ", size_k, + ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + b_q_weight.size(1) % MARLIN_NAMESPACE_NAME::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + int actual_size_n = + (b_q_weight.size(1) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); @@ -2320,63 +662,47 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); - TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); - - TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); - TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); - - TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); - TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel + int sms = -1; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c; - if (use_atomic_add) { - c = torch::zeros({size_m, size_n}, options); + if (c_or_none.has_value()) { + c = c_or_none.value(); + TORCH_CHECK(c.device().is_cuda(), "c is not on GPU"); + TORCH_CHECK(c.is_contiguous(), "c is not contiguous"); + TORCH_CHECK(c.size(0) == size_m, "Shape mismatch: c.size(0) = ", c.size(0), + ", size_m = ", size_m); + TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), + ", size_n = ", size_n); } else { c = torch::empty({size_m, size_n}, options); } - - torch::Tensor a_tmp; - bool has_act_order = g_idx.size(0) != 0; - if (has_act_order) { - a_tmp = torch::empty({size_m, size_k}, options); - } else { - a_tmp = torch::empty({0}, options); - } + if (size_m == 
0) return c; // Alloc C tmp buffer that is going to be used for the global reduce torch::Tensor c_tmp; - int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par); - int reduce_n = size_n; auto options_fp32 = torch::TensorOptions().dtype(at::kFloat).device(a.device()); if (use_fp32_reduce) { - c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32); + int max_m_block_size = (size_m + 16 - 1) / 16 * 16; + max_m_block_size = min(max_m_block_size, 64); + int max_c_tmp_size = + sms * max_m_block_size * MARLIN_NAMESPACE_NAME::max_thread_n; + c_tmp = torch::empty({max_c_tmp_size}, options_fp32); } else { - reduce_max_m = 0; - reduce_n = 0; c_tmp = torch::empty({0}, options_fp32); } - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Verify g_idx and perm - TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || - (g_idx.size(0) == size_k && perm.size(0) == size_k), - "Unexpected g_idx.size(0) = ", g_idx.size(0), - " and perm.size(0) = ", perm.size(0), - ", where size_k = ", size_k); - // Detect groupsize and act_order int num_groups = -1; int group_size = -1; @@ -2387,7 +713,31 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, " is not size_n = ", size_n); num_groups = b_scales.size(0); + torch::Tensor g_idx, perm, a_tmp; + if (g_idx_or_none.has_value() && perm_or_none.has_value()) { + g_idx = g_idx_or_none.value(); + perm = perm_or_none.value(); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) || + (g_idx.size(-1) == size_k && perm.size(-1) == size_k), + "Unexpected g_idx.size(-1) = ", g_idx.size(-1), + " and perm.size(-1) = ", perm.size(-1), + ", where size_k = ", size_k); + } else { + g_idx = torch::empty({0}, options); + perm = torch::empty({0}, options); + a_tmp = torch::empty({0}, options); + } + bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0; + if (has_act_order) { + a_tmp = torch::empty({size_m, size_k}, options); if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, @@ -2398,6 +748,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } } else { + a_tmp = torch::empty({0}, options); if (num_groups > 1) { TORCH_CHECK( size_k % num_groups == 0, "size_k = ", size_k, @@ -2408,6 +759,33 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } } + torch::Tensor b_zeros; + if (b_zeros_or_none.has_value()) { + b_zeros = b_zeros_or_none.value(); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + } else { + b_zeros = torch::empty({0}, options); + } + bool has_zp = b_zeros.size(-1) > 0; + if (has_zp) { + TORCH_CHECK( + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. 
Got = ", b_q_type.str()); + } else { + TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || + b_q_type == vllm::kFE4M3fn, + "b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when " + "has_zp = False. Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + // Verify b_zeros if (has_zp) { int rank = b_zeros.sizes().size(); @@ -2431,9 +809,11 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } // Verify workspace size - TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", marlin::min_thread_n); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; + TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0, + "size_n = ", size_n, ", is not divisible by min_thread_n = ", + MARLIN_NAMESPACE_NAME::min_thread_n); + + int min_workspace_size = sms; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); @@ -2447,8 +827,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_atomic_add, - use_fp32_reduce, is_zp_float); + thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), @@ -2458,7 +837,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par, use_atomic_add, use_fp32_reduce, is_zp_float); + use_atomic_add, use_fp32_reduce, is_zp_float); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h new file mode 100644 index 00000000000..eb2700c95e8 --- /dev/null +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -0,0 +1,37 @@ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" +#include "core/scalar_type.hpp" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \ + int prob_k, int lda, int *locks, bool use_atomic_add, \ + bool use_fp32_reduce, int max_shared_mem + +namespace MARLIN_NAMESPACE_NAME { +template shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h new file mode 100644 index 00000000000..ca05b8a25f8 --- /dev/null +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -0,0 +1,1678 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" +#include "dequant.h" +#include "core/scalar_type.hpp" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace marlin + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
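// ---------------------------------------------------------------------------
// Editor's note (illustrative aside, not part of this patch): functionally,
// one m16n8k16 mma computes D = A(16x16) x B(16x8) + C(16x8) with fp16
// inputs and fp32 accumulation, spread across the 32 threads of a warp
// (4 A registers, 2 B registers and 4 fp32 accumulators per thread, as the
// inline asm below shows). The ".row.col" in the PTX names the fragment
// layouts produced by ldmatrix, not different math. The host-only reference
// below spells out the math only; values are arbitrary.

#include <array>
#include <cstdio>

int main() {
  constexpr int M = 16, N = 8, K = 16;
  std::array<float, M * K> A{};  // fp16 on the device
  std::array<float, K * N> B{};  // fp16 on the device
  std::array<float, M * N> C{};  // fp32 accumulators

  for (int i = 0; i < M * K; i++) A[i] = 0.01f * i;
  for (int i = 0; i < K * N; i++) B[i] = 0.02f * i;

  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++)
      for (int k = 0; k < K; k++)
        C[m * N + n] += A[m * K + k] * B[k * N + n];

  printf("C[0][0] = %f\n", C[0]);
}
// ---------------------------------------------------------------------------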
+template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(a[0]) + : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. 
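// ---------------------------------------------------------------------------
// Editor's note (illustrative aside, not part of this patch): scale() below
// works on packed pairs: num2num2 broadcasts one fp16/bf16 scale into both
// lanes of a 2-element vector, and __hmul2 multiplies both lanes at once, so
// the four dequantized values of a FragB are scaled with two instructions.
// Plain-float host analog (the names here are only stand-ins):

#include <cstdio>

struct pair2 { float x, y; };  // stand-in for half2 / nv_bfloat162

int main() {
  pair2 frag_b[2] = {{1.f, 2.f}, {3.f, 4.f}};  // 4 dequantized weights
  float s = 0.125f;                            // one group scale
  pair2 s2 = {s, s};                           // "num2num2" broadcast
  for (auto& v : frag_b) { v.x *= s2.x; v.y *= s2.y; }  // "__hmul2" per pair
  printf("%g %g %g %g\n", frag_b[0].x, frag_b[0].y, frag_b[1].x, frag_b[1].y);
}
// ---------------------------------------------------------------------------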
+template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub( + typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, + typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool has_act_order = group_blocks == 0; + constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > m_block_size) { + parallel = prob_m / m_block_size; + prob_m = m_block_size; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. 
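      // Editor's note (worked example, not part of this patch): with, say,
      // group_size = 128 (group_blocks = 8) and thread_k = 64
      // (thread_k_blocks = 4), each scale group spans 8 / 4 = 2 k-tiles, so
      // the line below rounds iters up to the next multiple of 2,
      // e.g. iters = 7 -> 2 * div_ceil(7, 2) = 8. The numbers are only
      // illustrative.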
+ iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&](bool first_init = false) { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = + div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) m_per_thread = div_ceil(8, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) { + int col = slice_col * 16 * thread_n_blocks / 8 + + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
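+      // For example, with slice_count == 3 the lock starts at -2; the two
+      // other threadblocks of this slice each wait for a negative value and
+      // atomicAdd 1 (see wait_negative_and_add), bringing it back to 0.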
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; + } + }; + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = lda / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * m_block_size; + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 
2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? 
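+  // Concretely, `transform_a` below XORs the flattened int4 index with
+  // (row % 8), so the 16-byte accesses of 8 consecutive rows are spread
+  // over different shared memory banks.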
+ auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh; + int4* sh_red = sh; + int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) + : (stages * s_sh_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= + stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + // constexpr int shm_size_used = + // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. 
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure 
that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm( + frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N 
+ + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % 
b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { + if constexpr (has_zp && is_zp_float || !has_zp) { + dequant(q, frag_b_ptr); + } else { + static_assert(has_zp && !is_zp_float); + static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); + // If (has_zp && !is_zp_float), + // we use not-zp version `dequant` function + // to improve numerical accuracy. + // Since both weight and zero point are dequanted using this logic, + // the final dequanted weight would be correct. + if constexpr (w_type_id == vllm::kU4.id()) { + dequant(q, frag_b_ptr); + } else if constexpr (w_type_id == vllm::kU8.id()) { + dequant(q, frag_b_ptr); + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = + ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = + reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. 
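+    // Note: on the integer zero-point path the weight is reconstructed as
+    // dq(q) * s - dq(zp) * s rather than (dq(q) - dq(zp)) * s; the zero-point
+    // fragment is pre-multiplied by the scale once per group (guarded by
+    // `is_new_zp` above) and then folded in through `scale_and_sub`.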
+ #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (has_zp && group_blocks != -1) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], + *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 
2 : 1)) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast( + &sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float* c_rd = + reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if constexpr (m_block_size_8) { + cp_async4_pred(&sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } else { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + row < prob_m) || + (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? 
-2 : 0; + } + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + sh_red[threadIdx.x] = + C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); + #pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
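+  // Note that when `use_atomic_add` is set and several threadblocks share a
+  // slice, the staged half2 results are accumulated directly into C with
+  // atomicAdd (see below) instead of the barrier-based global reduction.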
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4 && !has_zp) { + res = __hmul2(res, s[0]); + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], + frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], + frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + if (use_atomic_add && slice_count > 1) { + scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + #pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. 
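+  // `start_pipes` below pre-issues the first `stages - 1` asynchronous
+  // global->shared fetches and loads the first sub-tile into registers, so
+  // the main loop always has prefetched data to overlap with compute.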
+ auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], + g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + fetch_col_scale_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + #pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8 && !has_zp) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) + wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 5ed33097672..f59b42d88c6 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -291,12 
+291,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( - "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " - "int b_q_type, " + "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " + "Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, " + "Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " - "bool has_zp, bool use_atomic_add, bool use_fp32_reduce, " - "bool is_zp_float) -> Tensor", + "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor", {stride_tag}); // conditionally compiled so impl registration is in source file @@ -341,14 +340,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM - // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. - ops.def( - "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor", - {stride_tag}); - // conditionally compiled so impl registration is in source file - // marlin_qqq_gemm for QQQ. ops.def( "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index f2cca65ae42..abf3e3667a7 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -11,19 +11,20 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, - torch_moe_single) +from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( awq_marlin_quantize, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( quantize_weights) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types +from vllm.scalar_type import ScalarType, scalar_types NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] @@ -285,7 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, atol=mixtral_moe_tol[dtype]) -@pytest.mark.parametrize("m", [1, 33, 123]) +@pytest.mark.parametrize("m", [1, 123, 666]) @pytest.mark.parametrize("n", [128, 1024]) @pytest.mark.parametrize("k", [256, 2048]) @pytest.mark.parametrize("e", [4, 12]) @@ -294,8 +295,10 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("group_size", [-1, 32, 128]) @pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) -@pytest.mark.parametrize("has_zp", [True, False]) +@pytest.mark.parametrize("quant_type", [ + scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, + scalar_types.float8_e4m3fn +]) @pytest.mark.parametrize("is_k_full", [True, False]) 
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( @@ -308,14 +311,22 @@ def test_fused_marlin_moe( dtype: torch.dtype, group_size: int, act_order: bool, - num_bits: int, - has_zp: bool, + quant_type: ScalarType, is_k_full: bool, ): - current_platform.seed_everything(7) + torch.cuda.manual_seed(0) + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + + if quant_type == scalar_types.float8_e4m3fn: + if group_size not in [-1, 128]: + return + if act_order: + return # Filter act_order if act_order: + if quant_type == scalar_types.float8_e4m3fn: + return if group_size == -1: return if group_size in (k, n): @@ -326,17 +337,9 @@ def test_fused_marlin_moe( if not is_k_full: return - if has_zp: - # we don't build kernel for int8 with zero - if num_bits == 8: - return - quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 - else: - quant_type = scalar_types.uint4b8 \ - if num_bits == 4 else scalar_types.uint8b128 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 if ep_size > 1: local_e = e // ep_size @@ -364,17 +367,23 @@ def test_fused_marlin_moe( qweight1_l.append(qweight1) scales1_l.append(scales1) zeros1_l.append(zeros1) - else: + elif quant_type != scalar_types.float8_e4m3fn: test_perm = torch.randperm(k) - quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ + marlin_quantize(w1[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) g_idx1_l.append(g_idx1) sort_indices1_l.append(sort_indices1) + else: + w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( + w1[i], group_size) + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) w_ref1 = stack_and_dev(w_ref1_l) qweight1 = stack_and_dev(qweight1_l).contiguous() @@ -399,17 +408,23 @@ def test_fused_marlin_moe( qweight2_l.append(qweight2) scales2_l.append(scales2) zeros2_l.append(zeros2) - else: + elif quant_type != scalar_types.float8_e4m3fn: test_perm = torch.randperm(n) - quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ + marlin_quantize(w2[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) g_idx2_l.append(g_idx2) sort_indices2_l.append(sort_indices2) + else: + w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( + w2[i], group_size) + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) w_ref2 = stack_and_dev(w_ref2_l) qweight2 = stack_and_dev(qweight2_l).contiguous() @@ -442,102 +457,10 @@ def test_fused_marlin_moe( sort_indices2=sort_indices2, w1_zeros=zeros1, w2_zeros=zeros2, - num_bits=num_bits, + quant_type_id=quant_type.id, is_k_full=is_k_full) - torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) - - -@pytest.mark.skip("This test is here for the sake of debugging, " - "don't run it in 
automated tests.") -@pytest.mark.parametrize("m", [1, 33, 123]) -@pytest.mark.parametrize("n", [128, 1024]) -@pytest.mark.parametrize("k", [256, 2048]) -@pytest.mark.parametrize("e", [4, 12]) -@pytest.mark.parametrize("topk", [2, 3]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("group_size", [-1, 32, 128]) -@pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) -@pytest.mark.parametrize("has_zp", [True, False]) -@pytest.mark.parametrize("is_k_full", [True, False]) -def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype, group_size: int, - act_order: bool, num_bits: int, - has_zp: bool, is_k_full: bool): - # Filter act_order - if act_order: - if group_size == -1: - return - if group_size in (k, n): - return - if has_zp: - return - else: - if not is_k_full: - return - - if has_zp: - quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 - else: - quant_type = scalar_types.uint4b8 \ - if num_bits == 4 else scalar_types.uint8b128 - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 - - w_ref_l = [] - qweight_l = [] - scales_l = [] - zeros_l = [] - g_idx_l = [] - sort_indices_l = [] - - for i in range(w.shape[0]): - if has_zp: - w_ref, qweight, scales, zeros = awq_marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size) - - w_ref_l.append(w_ref.T) - qweight_l.append(qweight) - scales_l.append(scales) - zeros_l.append(zeros) - else: - test_perm = torch.randperm(k) - w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size, act_order, - test_perm) - - w_ref_l.append(w_ref.T) - qweight_l.append(qweight) - scales_l.append(scales) - g_idx_l.append(g_idx) - sort_indices_l.append(sort_indices) - - w_ref = stack_and_dev(w_ref_l) - qweight = stack_and_dev(qweight_l).contiguous() - scales = stack_and_dev(scales_l) - g_idx = stack_and_dev(g_idx_l) if g_idx_l else None - zeros = stack_and_dev(zeros_l) if zeros_l else None - sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None - - score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = torch.ops.vllm.single_marlin_moe( - a, - qweight, - scales, - score, - topk, - renormalize=False, - g_idx=g_idx, - sort_indices=sort_indices, - w_zeros=zeros, - num_bits=num_bits, - is_k_full=is_k_full, - ) - - torch_output = torch_moe_single(a, w_ref, score, topk) - - torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) def test_moe_align_block_size_opcheck(): diff --git a/tests/kernels/quantization/test_awq_marlin.py b/tests/kernels/quantization/test_awq_marlin.py deleted file mode 100644 index c30fe60becd..00000000000 --- a/tests/kernels/quantization/test_awq_marlin.py +++ /dev/null @@ -1,164 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Test AWQ with fused MoE Marlin kernels. - -Run `pytest tests/kernels/test_awq_marlin.py`. 
-""" -import pytest -import torch - -import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe, - torch_moe_single) -from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - awq_marlin_quantize) -from vllm.scalar_type import scalar_types - -NUM_EXPERTS = [8, 64] -TOP_KS = [2, 6] -GROUP_SIZES = [-1, 32, 128] - - -@pytest.mark.parametrize("m", [1, 33, 64, 222]) -@pytest.mark.parametrize("n", [128, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("group_size", GROUP_SIZES) -@pytest.mark.skipif(not (ops.supports_moe_ops - and hasattr(torch.ops._moe_C, "marlin_gemm_moe")), - reason="Marlin is not supported on this GPU type.") -def test_fused_marlin_moe_awq( - m: int, - n: int, - k: int, - e: int, - topk: int, - group_size: int, -): - torch.manual_seed(7) - - num_bits = 4 - quant_type = scalar_types.uint4 - dtype = torch.float16 - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - w_ref1_l = [] - qweights1_l = [] - scales1_l = [] - zp1_l = [] - - for i in range(w1.shape[0]): - w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size) - w_ref1_l.append(w_ref1) - qweights1_l.append(qweight1) - scales1_l.append(scales1) - zp1_l.append(zp1) - - w_ref1 = stack_and_dev(w_ref1_l) - qweight1 = stack_and_dev(qweights1_l).contiguous() - scales1 = stack_and_dev(scales1_l) - zp1 = stack_and_dev(zp1_l) - - w_ref2_l = [] - qweights2_l = [] - scales2_l = [] - zp2_l = [] - - for i in range(w2.shape[0]): - w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size) - w_ref2_l.append(w_ref2) - qweights2_l.append(qweight2) - scales2_l.append(scales2) - zp2_l.append(zp2) - - w_ref2 = stack_and_dev(w_ref2_l) - qweight2 = stack_and_dev(qweights2_l).contiguous() - scales2 = stack_and_dev(scales2_l) - zp2 = stack_and_dev(zp2_l) - - score = torch.randn((m, e), device="cuda", dtype=dtype) - - topk_weights, topk_ids, token_expert_indices = fused_topk( - a, score, topk, False) - marlin_output = torch.ops.vllm.fused_marlin_moe( - a, - qweight1, - qweight2, - scales1, - scales2, - score, - topk_weights, - topk_ids, - w1_zeros=zp1, - w2_zeros=zp2, - num_bits=num_bits, - ) - - torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2), - score, topk, None) - - assert compute_max_diff(marlin_output, torch_output) < 4e-2 - - -@pytest.mark.skip("This test is here for the sake of debugging, " - "don't run it in automated tests.") -@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) -@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [8, 64]) -@pytest.mark.parametrize("topk", [2, 6]) -@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -def test_single_marlin_moe_multiply_awq( - m: int, - n: int, - k: int, - e: int, - topk: int, - group_size: int, -): - torch.manual_seed(7) - - num_bits = 4 - quant_type = scalar_types.uint4 - dtype = torch.float16 - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 - - w_ref_l = [] - qweights_l 
= [] - scales_l = [] - zp_l = [] - - for i in range(w.shape[0]): - w_ref, qweight, scales, zp = awq_marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size) - w_ref_l.append(w_ref) - qweights_l.append(qweight) - scales_l.append(scales) - zp_l.append(zp) - - w_ref = stack_and_dev(w_ref_l) - qweight = stack_and_dev(qweights_l).contiguous() - scales = stack_and_dev(scales_l).contiguous() - zp = stack_and_dev(zp_l).contiguous() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - - marlin_output = torch.ops.vllm.single_marlin_moe(a, - qweight, - scales, - score, - topk, - renormalize=False, - w_zeros=zp, - num_bits=num_bits) - - torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) - - assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 3165201aa35..c125e0b5ec7 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -18,9 +18,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, - marlin_permute_scales, query_marlin_supported_quant_types) + marlin_make_workspace_new, marlin_permute_scales, + query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - pack_fp8_to_int32) + marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize, marlin_weights) @@ -73,7 +74,7 @@ def rand_data(shape, dtype=torch.float16): @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False)) + query_marlin_supported_quant_types(False, False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @@ -138,7 +139,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False)) + query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @@ -220,38 +221,50 @@ def test_gptq_marlin_gemm( if group_size == size_k: return + if size_k % group_size != 0: + return + a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, act_order) + if quant_type == scalar_types.float8_e4m3fn: + if group_size not in [-1, 128]: + return + if act_order: + return + w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch( + b_weight.T, group_size) + g_idx = None + sort_indices = None + else: + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, quant_type, group_size, act_order) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = marlin_make_workspace_new(w_ref.device) - 
opcheck(torch.ops._C.gptq_marlin_gemm, - (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace.scratch, quant_type.id, a_input.shape[0], - b_weight.shape[1], a_input.shape[1], is_k_full, False, - use_atomic_add, use_fp32_reduce, False), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck( + torch.ops._C.gptq_marlin_gemm, + (a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, + workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], + a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), + test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.gptq_marlin_gemm( a_input, + None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace.scratch, + workspace, quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full=is_k_full, - has_zp=False, use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce, is_zp_float=False, @@ -326,80 +339,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", [8]) -@pytest.mark.parametrize("group_size", [-1]) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -@pytest.mark.parametrize("dtype", DTYPES) -def test_fp8_marlin_gemm( - k_chunk, - n_chunk, - num_bits, - group_size, - mnk_factors, - dtype, -): - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k), dtype=dtype) - b_weight = rand_data((size_k, size_n), dtype=dtype) - - # WEIGHTS - fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None) - # Repack weights to gptq format (packed int32 elements) - packed_gptq_qweight = pack_fp8_to_int32(fp8_weight) - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=packed_gptq_qweight, - perm=torch.empty(0, dtype=torch.int, device="cuda"), - size_k=size_k, - size_n=size_n, - num_bits=8, - ) - - # WEIGHT SCALES - # Currently Marlin doesn't support per-tensor scales, so we - # expand it to channelwise - scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda") - # Permute scales - marlin_scales = marlin_permute_scales(s=scales, - size_k=size_k, - size_n=size_n, - group_size=-1) - - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) - - opcheck(torch.ops._C.fp8_marlin_gemm, - (a_input, marlin_qweight, marlin_scales, workspace.scratch, - num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1])) - - output = ops.fp8_marlin_gemm( - a=a_input, - b_q_weight=marlin_qweight, - b_scales=marlin_scales, - workspace=workspace.scratch, - num_bits=num_bits, - size_m=a_input.shape[0], - size_n=b_weight.shape[1], - size_k=a_input.shape[1], - ) - output_ref = torch.matmul(a_input, b_weight) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @@ -432,25 +371,23 @@ def test_awq_marlin_gemm( g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) is_k_full = True - has_zp = True 
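# --- Editorial aside (illustrative sketch, not part of the patch) ----------
# The deleted test_fp8_marlin_gemm exercised the dedicated fp8_marlin_gemm op.
# FP8 weights now go through the unified gptq_marlin_gemm entry point, with the
# weight type carried by b_q_type and the zero-point/g_idx/perm slots left as
# None. A hedged sketch of the replacement call (argument order taken from the
# updated wrapper in vllm/_custom_ops.py below; `fp8_marlin_mm` is hypothetical):
import torch
from vllm import _custom_ops as ops
from vllm.scalar_type import scalar_types

def fp8_marlin_mm(a: torch.Tensor, marlin_qweight: torch.Tensor,
                  marlin_scales: torch.Tensor, workspace: torch.Tensor,
                  size_n: int) -> torch.Tensor:
    return ops.gptq_marlin_gemm(
        a, None, marlin_qweight, marlin_scales,
        None, None, None,  # b_zeros, g_idx, perm are optional now
        workspace, scalar_types.float8_e4m3fn,
        a.shape[0], size_n, a.shape[1],
        use_fp32_reduce=True)
# ---------------------------------------------------------------------------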
- workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = marlin_make_workspace_new(a_input.device) output = ops.gptq_marlin_gemm( a_input, + None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace.scratch, + workspace, quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full=is_k_full, - has_zp=has_zp, use_fp32_reduce=use_fp32_reduce, is_zp_float=False, ) @@ -508,23 +445,22 @@ def test_hqq_marlin_gemm( g_idx = marlin_make_empty_g_idx(dev) g_idx_sort_indices = marlin_make_empty_g_idx(dev) - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = marlin_make_workspace_new(b_weight.device) output = ops.gptq_marlin_gemm( a_input, + None, marlin_w_q, marlin_s, marlin_zp, g_idx, g_idx_sort_indices, - workspace.scratch, + workspace, quant_type, a_input.shape[0], b_weight.shape[0], a_input.shape[1], is_k_full=True, - has_zp=True, use_fp32_reduce=use_fp32_reduce, is_zp_float=True, ) @@ -621,23 +557,22 @@ def test_marlin_gemm_subset_input(): b_weight, quant_type, group_size, False) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = marlin_make_workspace_new(a_input.device) output = ops.gptq_marlin_gemm( a_input, + None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace.scratch, + workspace, quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full=True, - has_zp=False, use_atomic_add=False, use_fp32_reduce=True, is_zp_float=False, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 64f4310151c..44377ccb295 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -325,18 +325,18 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @register_fake("_C::gptq_marlin_gemm") def _gptq_marlin_gemm_fake(a: torch.Tensor, + c: Optional[torch.Tensor], b_q_weight: torch.Tensor, b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], workspace: torch.Tensor, - b_q_type: ScalarType, + b_q_type_id: int, size_m: torch.SymInt, size_n: torch.SymInt, size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, + is_k_full: bool = True, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: @@ -407,14 +407,6 @@ def _aqlm_dequant_fake( dtype=codebooks.dtype, device=codebooks.device) - @register_fake("_C::fp8_marlin_gemm") - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake("_C::machete_mm") def machete_mm_fake( a: torch.Tensor, @@ -815,35 +807,26 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], b_q_weight: torch.Tensor, b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], workspace: torch.Tensor, b_q_type: ScalarType, size_m: int, size_n: int, size_k: int, - is_k_full: bool, - has_zp: bool = False, + is_k_full: bool = True, use_atomic_add: 
bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace, b_q_type.id, size_m, size_n, size_k, is_k_full, - has_zp, use_atomic_add, - use_fp32_reduce, is_zp_float) - - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) + use_atomic_add, use_fp32_reduce, + is_zp_float) # machete diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 238808b226f..b96d34ec2db 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -7,163 +7,13 @@ import vllm._custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size, try_get_optimal_moe_config) -from vllm.scalar_type import scalar_types + moe_align_block_size, try_get_optimal_moe_config) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, maybe_warn_marlin_atomic_add) +from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import direct_register_custom_op -def get_scalar_type(num_bits: int, has_zp: bool): - if has_zp: - return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 - else: - return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 - - -def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - g_idx: Optional[torch.Tensor] = None, - sort_indices: Optional[torch.Tensor] = None, - w_zeros: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - num_bits: int = 8, - is_k_full: bool = True, -) -> torch.Tensor: - """ - This function computes the multiplication of hidden_states with expert - weights used in Marlin MoE, using weights w and top-k gating mechanism. - Its purpose is testing and debugging the fused MoE kernel. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the Marlin Mul. - - w (torch.Tensor): The set of expert weights. - - scales (torch.Tensor): The quantization scales. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - g_idx (Optional[torch.Tensor]): Optional act_order indices. - - sort_indices (Optional[torch.Tensor]): Optional act_order input - permutation. - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. - - num_bits (bool): The number of bits in expert weights quantization. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. 
- assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" - assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w.is_contiguous(), "Expert weights must be contiguous" - assert hidden_states.dtype in [torch.float16, torch.bfloat16] - assert num_bits in [4, 8] - - M, K = hidden_states.shape - E = w.shape[0] - N = w.shape[2] // (num_bits // 2) - - topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, gating_output, topk, renormalize) - - # This might not be an optimal config for a single MMM - get_config_func = functools.partial(try_get_optimal_moe_config, - w.shape, - w.shape, - topk_ids.shape[1], - None, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - if global_num_experts == -1: - global_num_experts = E - sorted_token_ids, expert_ids, num_tokens_post_padded = \ - moe_align_block_size(topk_ids, block_size_m, E, expert_map) - - if workspace is None: - max_workspace_size = (max(2 * N, K) // 64) * \ - (sorted_token_ids.size(0) // block_size_m) - device = hidden_states.device - sms = torch.cuda.get_device_properties(device).multi_processor_count - max_workspace_size = min(max_workspace_size, sms) - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=device, - requires_grad=False) - - scalar_type = get_scalar_type(num_bits, w_zeros is not None) - intermediate_cache = torch.empty( - (M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - ops.moe_wna16_marlin_gemm(hidden_states, - intermediate_cache, - w, - scales, - w_zeros, - g_idx, - sort_indices, - workspace, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - topk_weights, - moe_block_size=block_size_m, - top_k=topk, - mul_topk_weights=False, - is_ep=expert_map is not None, - b_q_type=scalar_type, - size_m=M, - size_n=N, - size_k=K, - is_k_full=is_k_full, - use_atomic_add=False, - use_fp32_reduce=True, - is_zp_float=False) - intermediate_cache = intermediate_cache.view(-1, topk, N) - - return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) - - -def single_marlin_moe_fake( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - g_idx: Optional[torch.Tensor] = None, - sort_indices: Optional[torch.Tensor] = None, - w_zeros: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - num_bits: int = 8, - is_k_full: bool = True, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -direct_register_custom_op( - op_name="single_marlin_moe", - op_func=single_marlin_moe, - mutates_args=[], - fake_impl=single_marlin_moe_fake, -) - - def fused_marlin_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -172,6 +22,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + quant_type_id: int, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None, @@ -181,7 +32,6 @@ def fused_marlin_moe(hidden_states: torch.Tensor, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, - num_bits: 
int = 8, is_k_full: bool = True, inplace: bool = False) -> torch.Tensor: """ @@ -211,6 +61,15 @@ def fused_marlin_moe(hidden_states: torch.Tensor, Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ + quant_type = ScalarType.from_id(quant_type_id) + assert quant_type in [ + scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, + scalar_types.float8_e4m3fn + ] + + int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8] + num_bits = 4 if quant_type in int4_scalar_types else 8 + # Check constraints. assert hidden_states.shape[0] == gating_output.shape[ 0], "Number of tokens mismatch" @@ -248,18 +107,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, expert_map) if workspace is None: - max_workspace_size = (max(2 * N, K) // 64) * \ - (sorted_token_ids.size(0) // block_size_m) - device = hidden_states.device - sms = torch.cuda.get_device_properties(device).multi_processor_count - max_workspace_size = min(max_workspace_size, sms * 4) - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=device, - requires_grad=False) - - scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) - scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + workspace = marlin_make_workspace_new(hidden_states.device, 4) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -276,6 +124,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] intermediate_cache3 = intermediate_cache3.view(-1, K) + maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) use_atomic_add = hidden_states.dtype == torch.half or \ torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 @@ -296,7 +145,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, top_k=topk, mul_topk_weights=False, is_ep=expert_map is not None, - b_q_type=scalar_type1, + b_q_type=quant_type, size_m=M, size_n=2 * N, size_k=K, @@ -328,7 +177,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, top_k=1, mul_topk_weights=True, is_ep=expert_map is not None, - b_q_type=scalar_type2, + b_q_type=quant_type, size_m=M * topk, size_n=K, size_k=N, @@ -351,6 +200,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, gating_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + quant_type_id: int, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None, @@ -360,7 +210,6 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, - num_bits: int = 8, is_k_full: bool = True, inplace: bool = False) -> torch.Tensor: return torch.empty_like(hidden_states) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index f7c885c2baa..556166f19f2 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -22,9 +22,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, check_marlin_supports_layer, check_moe_marlin_supports_layer, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, moe_awq_to_marlin_zero_points, - verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace_new, + 
marlin_moe_permute_scales, marlin_permute_scales, + moe_awq_to_marlin_zero_points, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -267,8 +268,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False) # Allocate marlin workspace - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) + layer.workspace = marlin_make_workspace_new(device) # Repack weights from AWQ format to marlin format. marlin_qweight = ops.awq_marlin_repack( @@ -322,6 +322,9 @@ class AWQMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: AWQMarlinConfig): self.quant_config = quant_config + if self.quant_config.weight_bits != 4: + raise ValueError("AWQMoEMethod only supports 4bit now.") + self.quant_type = scalar_types.uint4 def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -396,11 +399,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, set_weight_attrs(w2_qzeros, extra_weight_attrs) device = layer.w13_qweight.device - sms = torch.cuda.get_device_properties(device).multi_processor_count - layer.workspace = torch.zeros((sms * 4, ), - dtype=torch.int, - device=device, - requires_grad=False) + layer.workspace = marlin_make_workspace_new(device, 4) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] @@ -511,10 +510,9 @@ def apply( router_logits, topk_weights, topk_ids, + quant_type_id=self.quant_type.id, global_num_experts=global_num_experts, expert_map=expert_map, w1_zeros=layer.w13_qzeros, w2_zeros=layer.w2_qzeros, - workspace=layer.workspace, - num_bits=self.quant_config.weight_bits, - ) + workspace=layer.workspace) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index 5c826190873..1b54e154ecb 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -55,7 +55,7 @@ def process_weights_after_loading(self, layer) -> None: # required by torch.compile to be torch.nn.Parameter layer.input_scale = torch.nn.Parameter(layer.input_scale.data, requires_grad=False) - prepare_fp8_layer_for_marlin(layer, strategy="channel") + prepare_fp8_layer_for_marlin(layer) def create_weights(self, layer: torch.nn.Module, input_size: int, output_partition_sizes: List[int], @@ -68,6 +68,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.orig_dtype = params_dtype + layer.weight_block_size = None # WEIGHT weight = ModelWeightParameter(data=torch.empty( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5515ba27ea1..f7056016fe8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -21,19 +21,21 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from 
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, + prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, all_close_1d, convert_to_channelwise, - cutlass_block_fp8_supported, cutlass_fp8_supported, - maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, - per_tensor_dequantize, requantize_with_max_scale) + Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, + cutlass_fp8_supported, maybe_create_device_identity, + normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, + requantize_with_max_scale) from vllm.model_executor.parameter import (BlockQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -181,10 +183,6 @@ def __init__(self, quant_config: Fp8Config): self.use_marlin = False self.block_quant = self.quant_config.weight_block_size is not None - if self.block_quant: - # Marlin doesn't support block-wise fp8 - self.use_marlin = False - self.fp8_linear = Fp8LinearOp( # Default to using per_token quantization if cutlass is supported use_per_token_if_dynamic=cutlass_fp8_supported()) @@ -203,10 +201,16 @@ def create_weights( output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + layer.weight_block_size = None if self.block_quant: tp_size = get_tensor_model_parallel_world_size() assert self.quant_config.weight_block_size is not None + layer.weight_block_size = self.quant_config.weight_block_size block_n, block_k = ( self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[1], @@ -229,12 +233,6 @@ def create_weights( f"{output_partition_size} is not divisible by " f"weight quantization block_n = {block_n}.") - layer.logical_widths = output_partition_sizes - - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.orig_dtype = params_dtype - # WEIGHT weight_dtype = (torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else @@ -303,9 +301,11 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: return weight def process_weights_after_loading(self, layer: Module) -> None: + size_k_first = True # TODO(rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" + size_k_first = False if current_platform.is_fp8_fnuz(): weight, weight_scale_inv, _ = \ normalize_e4m3fn_to_e4m3fnuz( @@ -321,21 +321,12 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = Parameter(weight, requires_grad=False) layer.weight_scale_inv = Parameter(weight_scale_inv, requires_grad=False) - return # If checkpoint not serialized fp8, quantize the weights. 
- if not self.quant_config.is_checkpoint_fp8_serialized: + elif not self.quant_config.is_checkpoint_fp8_serialized: qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) - # If using marlin (w8a16), kernel uses channelwise weights, - # so extend the weight scales to be channelwise. - if self.use_marlin: - assert weight_scale.numel() == 1 - weight_scale = convert_to_channelwise( - weight_scale.expand(len(layer.logical_widths)), - layer.logical_widths) - # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) @@ -349,20 +340,14 @@ def process_weights_after_loading(self, layer: Module) -> None: if self.quant_config.activation_scheme == "static": layer.input_scale = torch.nn.Parameter(layer.input_scale.data, requires_grad=False) - # If using marlin (w8a16), kernel uses channelwise weights, - # so extend the weight scales to be channelwise. - if self.use_marlin: - weight = layer.weight - weight_scale = convert_to_channelwise(layer.weight_scale, - layer.logical_widths) + + weight = layer.weight + weight_scale = layer.weight_scale # If using w8a8, torch._scaled_mm needs per tensor, so # requantize the logical shards as a single weight. - else: + if not self.use_marlin: # Dequant -> Quant with max scale so we can run per tensor. - weight = layer.weight - weight_scale = layer.weight_scale - if current_platform.is_fp8_fnuz(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( @@ -388,7 +373,7 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) if self.use_marlin: - prepare_fp8_layer_for_marlin(layer) + prepare_fp8_layer_for_marlin(layer, size_k_first) # Activations not quantized for marlin. del layer.input_scale @@ -444,6 +429,14 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + # Check for DeepGemm support. self.allow_deep_gemm = False if envs.VLLM_USE_DEEP_GEMM: @@ -461,10 +454,17 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn if self.block_quant: assert self.quant_config.weight_block_size is not None + layer.weight_block_size = self.quant_config.weight_block_size tp_size = get_tensor_model_parallel_world_size() block_n, block_k = ( self.quant_config.weight_block_size[0], @@ -630,10 +630,8 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight_scale_inv = \ dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() - return - # If checkpoint is fp16, quantize in place. 
- if not self.quant_config.is_checkpoint_fp8_serialized: + elif not self.quant_config.is_checkpoint_fp8_serialized: fp8_dtype = current_platform.fp8_dtype() w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) @@ -677,8 +675,6 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - return - # If checkpoint is fp8, we need to handle that the # MoE kernels require single activation scale and single weight # scale for w13 per expert. @@ -766,7 +762,12 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) - return + + if self.use_marlin: + prepare_moe_fp8_layer_for_marlin(layer, False) + # Activations not quantized for marlin. + del layer.w13_input_scale + del layer.w2_input_scale def apply( self, @@ -801,6 +802,20 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) + if self.use_marlin: + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float8_e4m3fn.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + return fused_experts( x, layer.w13_weight, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 703d54b3bee..56aafca87e9 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -21,8 +21,8 @@ get_linear_quant_method) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, check_moe_marlin_supports_layer, - marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks, - verify_marlin_supported) + marlin_make_workspace_new, marlin_moe_permute_scales, + marlin_repeat_scales_on_all_ranks, verify_marlin_supported) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedColumnParameter, @@ -350,6 +350,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config + if self.quant_config.quant_type.size_bits == 4: + self.quant_type = scalar_types.uint4b8 + elif self.quant_config.quant_type.size_bits == 8: + self.quant_type = scalar_types.uint8b128 + else: + raise ValueError( + "GPTQMarlinMoEMethod only supports int4 and int8 now.") def create_weights( self, @@ -498,11 +505,7 @@ def create_weights( set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) device = layer.w13_qweight.device - sms = torch.cuda.get_device_properties(device).multi_processor_count - layer.workspace = torch.zeros((sms * 4, ), - dtype=torch.int, - device=device, - requires_grad=False) + layer.workspace = marlin_make_workspace_new(device, 4) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -633,12 +636,12 @@ def apply( router_logits, topk_weights, topk_ids, + quant_type_id=self.quant_type.id, global_num_experts=global_num_experts, expert_map=expert_map, g_idx1=layer.w13_g_idx, g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.quant_config.quant_type.size_bits, workspace=layer.workspace, is_k_full=self.is_k_full) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py 
b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index 7bd824ff9e5..97fcde1618c 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -8,7 +8,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, - marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) @@ -53,8 +53,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) # Allocate marlin workspace. - self.workspace = marlin_make_workspace(c.partition_weight_shape[1], - device) + self.workspace = marlin_make_workspace_new(device) # Default names since marlin requires empty parameters for these, # TODO: remove this requirement from marlin (allow optional tensors) @@ -127,6 +126,5 @@ def apply_weights(self, wtype=c.weight_type, input_size_per_partition=c.partition_weight_shape[0], output_size_per_partition=c.partition_weight_shape[1], - has_zp=self.config.zero_points, is_k_full=self.is_k_full, bias=bias) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 4a190480d35..a2b1b7cb0e1 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -7,12 +7,15 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types from .quant_utils import pack_cols, unpack_cols +logger = init_logger(__name__) + GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_K = 128 @@ -29,9 +32,11 @@ # For binary size and compile time, we don't support the same types for with and # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
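# --- Editorial aside (illustrative sketch, not part of the patch) ----------
# With the include_fp_type flag added below, the supported type sets become
# the following (on a GPU with compute capability >= 8.0, per the function
# body in this patch):
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)
from vllm.scalar_type import scalar_types

assert query_marlin_supported_quant_types(True) == [scalar_types.uint4]
assert query_marlin_supported_quant_types(False) == [
    scalar_types.uint4b8, scalar_types.uint8b128, scalar_types.float8_e4m3fn]
assert query_marlin_supported_quant_types(False, include_fp_type=False) == [
    scalar_types.uint4b8, scalar_types.uint8b128]
# ---------------------------------------------------------------------------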
# TODO: we may want to move this into the C++ so it's closer to the actual impl -def query_marlin_supported_quant_types(has_zp: bool, - device_capability: Optional[int] = None - ): +def query_marlin_supported_quant_types( + has_zp: bool, + include_fp_type: bool = True, + device_capability: Optional[int] = None, +): if device_capability is None: capability_tuple = current_platform.get_device_capability() device_capability = (-1 if capability_tuple is None else @@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool, if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn] + return res def _check_marlin_supported( @@ -62,7 +68,7 @@ def _check_marlin_supported( capability_tuple.to_int()) supported_types = query_marlin_supported_quant_types( - has_zp, device_capability) + has_zp, True, device_capability) if quant_type not in supported_types: return (False, f"Marlin does not support weight_bits = {quant_type}. " @@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int, requires_grad=False) +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the number of threadblocks as the workspace + # size. The number of threadblocks is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) + + def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) @@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running the Marlin kernel with bf16 on GPUs before SM90. " + "You can consider changing to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + # disable atomicAdd reduce by default, # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 - if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda": + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() return False # sm8x doesn't support atomicAdd + bfloat16 natively device_capability = torch.cuda.get_device_capability(device) if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) return False - # the performance of atomicAdd is better than global reduce - # only when m*n is small and k is large - return n < 2048 and k >= 2048 + return True def apply_gptq_marlin_linear( @@ -332,7 +378,6 @@ def apply_gptq_marlin_linear( wtype: ScalarType, output_size_per_partition: int, input_size_per_partition: int, - has_zp: bool, is_k_full: bool, bias: Optional[torch.Tensor] = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: @@ -346,6 +391,7 @@ def apply_gptq_marlin_linear( dtype=input.dtype) output = ops.gptq_marlin_gemm(reshaped_x, + None, weight, weight_scale, weight_zp, @@ -358,7 +404,6 @@ def apply_gptq_marlin_linear( size_k=input_size_per_partition, is_k_full=is_k_full, use_atomic_add=use_atomic_add, - has_zp=has_zp, use_fp32_reduce=use_fp32_reduce, is_zp_float=False) @@ -391,6 +436,7 @@ def apply_awq_marlin_linear( dtype=input.dtype) output = ops.gptq_marlin_gemm(reshaped_x, + None, weight, weight_scale, weight_zp, @@ -401,8 +447,6 @@ def apply_awq_marlin_linear( size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce, is_zp_float=False) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 6120a8e66ae..1e0078e246b 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -6,9 +6,11 @@ import vllm._custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) from vllm.platforms import current_platform - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -18,30 +20,40 @@ def is_fp8_marlin_supported(): def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = 
input.shape[:-1] + (size_n, ) - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add @@ -50,7 +62,7 @@ def apply_fp8_marlin_linear( def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, - strategy: str = "tensor") -> None: + size_k_first: bool = True) -> None: logger.warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " @@ -60,51 +72,234 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, part_size_n = layer.output_size_per_partition part_size_k = layer.input_size_per_partition + if size_k_first: + assert layer.weight.shape == (part_size_k, part_size_n) + else: + assert layer.weight.shape == (part_size_n, part_size_k) + device = layer.weight.device # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) + layer.workspace = marlin_make_workspace_new(device) # WEIGHT # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32( - layer.weight), - perm=torch.empty(0, - dtype=torch.int, - device=device), + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = pack_fp8_to_int32(layer.weight, size_k_first) + if not size_k_first: + qweight = qweight.T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, size_k=part_size_k, size_n=part_size_n, num_bits=8) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) # Permute scales + if "weight_scale" in dir(layer): + scales = layer.weight_scale.to(layer.orig_dtype) + elif "weight_scale_inv" in dir(layer): + scales = layer.weight_scale_inv.to(layer.orig_dtype) + del layer.weight_scale_inv + + if layer.weight_block_size is None: + group_size = -1 + else: + group_size = layer.weight_block_size[1] + + # marlin kernel only supports channel-wise and group-wise quantization + # we need to convert the scales + if layer.weight_block_size is None: + if scales.nelement() == 1: + # tensor-wise quantization -> channel-wise quantization + # (1, 1) =>(repeat)=> (1, size_n) + scales = scales.view(1, 1).repeat_interleave(part_size_n, 1) + elif scales.nelement() > 1 and scales.nelement() != part_size_n: + assert part_size_n % scales.nelement() == 0 + s_size = scales.nelement() + # tensor-wise quantization (for gate-up proj) + # -> channel-wise quantization + # (1, s_size) =>(repeat)=> (1, size_n) + scales = scales.view(1, s_size) + scales = scales.repeat_interleave(part_size_n // s_size, 1) + else: + # channel-wise quantization + # (1, size_n) + scales = scales.view(1, part_size_n) + else: + # block-wise quantization -> group-wise quantization + # (size_k // block_size[1], ceil(size_n / block_size[0])) + # =>(repeat)=> (size_k // block_size[1], size_n) + block_n = layer.weight_block_size[0] + scales = 
scales.T.repeat_interleave(block_n, 1) + # size_n may not be divisible by block_size[0] + scales = scales[:, :part_size_n] + marlin_scales = marlin_permute_scales(s=scales, size_k=part_size_k, size_n=part_size_n, - group_size=-1) + group_size=group_size) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, + size_k_first: bool = True) -> None: + logger.warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + if size_k_first: + assert weight.shape == (e, size_k, size_n) + else: + assert weight.shape == (e, size_n, size_k) + + for i in range(e): + qweight = pack_fp8_to_int32(weight[i], size_k_first) + if not size_k_first: + qweight = qweight.T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=8) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + if layer.weight_block_size is None: + group_size = -1 + else: + group_size = layer.weight_block_size[1] + + for name in ["w13", "w2"]: + if name + "_weight_scale" in dir(layer): + new_name = name + "_weight_scale" + scales = getattr(layer, new_name).to(layer.orig_dtype) + delattr(layer, new_name) + elif name + "_weight_scale_inv" in dir(layer): + new_name = name + "_weight_scale_inv" + scales = getattr(layer, new_name).to(layer.orig_dtype) + delattr(layer, new_name) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + # marlin kernel only supports channel-wise and group-wise quantization + # we need to convert the scales + if layer.weight_block_size is None: + if scales.nelement() == e: + # tensor-wise quantization -> channel-wise quantization + # (e, 1, 1) =>(repeat)=> (e, 1, size_n) + scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2) + elif scales.nelement() > e and scales.nelement() != e * size_n: + assert (e * size_n) % scales.nelement() == 0 + s_size = scales.nelement() // e + # tensor-wise quantization (for gate-up proj) + # -> channel-wise quantization + # (e, 1, s_size) =>(repeat)=> (e, 1, size_n) + scales = scales.view(e, 1, s_size) + scales = scales.repeat_interleave(size_n // s_size, 2) + else: + # channel-wise quantization + # (e, 1, size_n) + scales = scales.view(e, 1, size_n) + else: + # block-wise quantization -> group-wise quantization + # (e, size_k // block_size[1], ceil(size_n / block_size[0])) + # =>(repeat)=> (e, size_k // block_size[1], size_n) + block_n = layer.weight_block_size[0] + scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2) + # size_n may not be divisible by 
block_size[0] + scales = scales[..., :size_n].contiguous() + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i], + size_k=size_k, + size_n=size_n, + group_size=group_size) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + + setattr(layer, name + "_weight_scale", scales) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and has shape (N, K) now + # with `.view(torch.int32)`, it becomes (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Pack 4 uint8 values into one int32 - packed = (byte_tensor[:, 0].to(torch.int32) | - (byte_tensor[:, 1].to(torch.int32) << 8) | - (byte_tensor[:, 2].to(torch.int32) << 16) | - (byte_tensor[:, 3].to(torch.int32) << 24)) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - return packed.view(fp8_tensor.shape[0] // 4, - *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index 1d7675dda43..5d893a3a586 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -6,6 +6,8 @@ from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -158,6 +160,8 @@ def or_and_advance(member, bit_width): assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -295,6 +299,13 @@ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exist.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types 
(leading f) the scheme is: