diff --git a/CMakeLists.txt b/CMakeLists.txt
index a0fd346c6c15..aa6edb5effd7 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -228,7 +228,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
 
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
-  set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
+  # Please keep this in sync with FetchContent_Declare line below.
+  set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
 
   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -245,6 +246,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
+        # Please keep this in sync with CUTLASS_REVISION line above.
         GIT_TAG v3.7.0
         GIT_PROGRESS TRUE
@@ -265,7 +267,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/permute_cols.cu"
       "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
       "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-      "csrc/sparse/cutlass/sparse_compressor_entry.cu"
       "csrc/cutlass_extensions/common.cpp")
 
   set_gencode_flags_for_srcs(
@@ -358,8 +359,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
   # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
-             "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+    set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -458,7 +458,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
-  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
   USE_SABI 3
   WITH_SOABI)
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
index c590c66a6665..583fa3c45511 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -16,6 +16,30 @@ namespace vllm::c3x {
 
 using namespace cute;
 
+template <typename T>
+struct identity {
+  CUTLASS_HOST_DEVICE
+  T operator()(T lhs) const { return lhs; }
+};
+
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct TrivialEpilogue {
+ private:
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+  using Compute = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  template <typename... Args>
+  static ArgumentType prepare_args(Args... args) {
+    return {};
+  }
+};
+
 /*
  * This class provides the common load descriptors for the
  * ScaledEpilogue[...] classes
@@ -174,6 +198,49 @@ struct ScaledEpilogueBias
   }
 };
 
+/*
+ * This epilogue performs the same operation as ScaledEpilogueBias, but the
+ * bias is a column vector instead of a row vector. Useful e.g. if we are
+ * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
+ */
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueColumnBias
+    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template ColLoad<ElementD>;
+
+  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
+
+  using ArgumentType = typename EVTCompute::Arguments;
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args, bias_args};
+  }
+};
+
 /*
  * This epilogue directly supports per-tensor azp in int32 form.
  * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
@@ -314,4 +381,4 @@ struct ScaledEpilogueBiasAzpToken
   }
 };
 
-};  // namespace vllm::c3x
\ No newline at end of file
+};  // namespace vllm::c3x
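For reference, the EVT tree built by ScaledEpilogueColumnBias above evaluates D(i, j) = a_scale(i) * (b_scale(j) * acc(i, j)) + bias(i), with the bias broadcast down each column (one value per output row), which is what the C^T += B^T A^T note in the comment describes. Below is a minimal host-side sketch of that arithmetic with plain float buffers; it is illustrative only, and none of its names come from the patch.

#include <cstdio>
#include <vector>

// Scalar reference for the ScaledEpilogueColumnBias math:
//   D(i, j) = a_scale(i) * (b_scale(j) * acc(i, j)) + bias(i)
// a_scale and bias are per-row (column vectors), b_scale is per-column.
int main() {
  int const m = 2, n = 3;
  std::vector<float> acc = {1, 2, 3, 4, 5, 6};         // m x n accumulator, row-major
  std::vector<float> a_scale = {0.5f, 2.0f};           // one scale per row
  std::vector<float> b_scale = {1.0f, 10.0f, 100.0f};  // one scale per column
  std::vector<float> bias = {0.25f, -0.25f};           // column bias, one value per row

  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float d = a_scale[i] * (b_scale[j] * acc[i * n + j]) + bias[i];
      std::printf("%8.2f ", d);
    }
    std::printf("\n");
  }
  return 0;
}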
diff --git a/csrc/ops.h b/csrc/ops.h
index e39d4ef3188a..45e0a1afe481 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -176,8 +176,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b_scales,
                               std::optional<torch::Tensor> const& bias);
 
-bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
-                                   torch::Tensor& e, torch::Tensor const& a);
+std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
index 9227ebb73524..d2f43e2b7a89 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
@@ -53,12 +53,17 @@ struct cutlass_3x_gemm {
 
   using EVTCompute = typename Epilogue::EVTCompute;
 
+  // These are the minimum alignments needed for the kernels to compile
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD = 4;
+
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
           ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
-          EpilogueSchedule, EVTCompute>::CollectiveOp;
+          ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
+          AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
 
   static constexpr size_t CEStorageSize =
       sizeof(typename CollectiveEpilogue::SharedStorage);
@@ -69,8 +74,8 @@ struct cutlass_3x_gemm {
 
   using CollectiveMainloop =
       typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
-          ElementAB, cutlass::layout::RowMajor, 16,
-          ElementAB, cutlass::layout::ColumnMajor, 16,
+          ElementAB, cutlass::layout::RowMajor, AlignmentAB,
+          ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
          ElementAcc, TileShape, ClusterShape,
          Stages, KernelSchedule>::CollectiveOp;
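The new AlignmentAB constant encodes the usual CUTLASS 128-bit-access rule, 128 / sizeof_bits(element), so the previously hard-coded 16 (correct only for 8-bit A/B types) disappears, while AlignmentCD stays a fixed 4 for the C/D operands. A small standalone sketch of that computation, using sizeof() * 8 as a stand-in for cutlass::sizeof_bits (an assumption for illustration only):

#include <cstdint>
#include <cstdio>

// Elements per 128-bit (16-byte) access, mirroring the AlignmentAB expression
// in cutlass_3x_gemm / cutlass_2x_gemm.
template <typename Element>
constexpr int alignment_ab() {
  return 128 / (8 * static_cast<int>(sizeof(Element)));
}

int main() {
  std::printf("1-byte elements (int8/fp8)  : %d\n", alignment_ab<int8_t>());    // 16
  std::printf("2-byte elements (fp16/bf16) : %d\n", alignment_ab<uint16_t>());  // 8
  std::printf("4-byte elements (fp32)      : %d\n", alignment_ab<float>());     // 4
  return 0;
}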
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
index f2fae4b66d65..ce7cf2f35282 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -103,14 +103,19 @@ struct cutlass_2x_gemm {
 
   using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
 
+  // These are the minimum alignments needed for the kernels to compile
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD = 4;
+
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
   using KernelType = ArchGuard
diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cu b/csrc/sparse/cutlass/sparse_compressor_c3x.cu
deleted file mode 100644
--- a/csrc/sparse/cutlass/sparse_compressor_c3x.cu
+++ /dev/null
-
-#if defined CUDA_VERSION && CUDA_VERSION >= 12020
-#include "sparse_scaled_mm_c3x.cuh"
-
-#include "cutlass/numeric_conversion.h"
-#include "cutlass/transform/device/transform_universal_adapter.hpp"
-#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
-#include "cutlass/epilogue/collective/default_epilogue.hpp"
-
-#include "cutlass/util/host_tensor.h"
-#include "cutlass/util/packed_stride.hpp"
-// clang-format on
-
-using namespace cute;
-using namespace vllm;
-
-/// Make A structured sparse by replacing elements with 0 and compress it
-template <typename ElementA_, typename ElementAcc_>
-bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
-                             torch::Tensor const& a) {
-  // Checks for conformality
-  TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
-              a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
-  TORCH_CHECK(a.dim() == 2)
-  // Check for strides and alignment
-  TORCH_CHECK(a.stride(0) % 4 == 0)  // Required for semi-structured sparsity
-  TORCH_CHECK(a.stride(1) == 1)
-
-  int m = a.size(0);
-  int k = a.size(1);
-
-  // Sparse kernel setup; this kernel is not used for matmul,
-  // but just for setting up the compressor utility
-  // A matrix configuration
-  using ElementA = ElementA_;
-  using LayoutTagA = cutlass::layout::RowMajor;
-  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
-  // B matrix configuration
-  using ElementB = ElementA;
-  using LayoutTagB = cutlass::layout::ColumnMajor;
-  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
-  // C/D matrix configuration
-  using ElementC = float;
-  using LayoutTagC = cutlass::layout::ColumnMajor;
-  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
-  // Core kernel configurations
-  using ElementAccumulator = ElementAcc_;
-  using TileShape = Shape<_128, _128, _128>;
-  using TileShapeRef = Shape<_128, _128, _64>;
-  using ClusterShape = Shape<_1, _2, _1>;
-  using KernelSchedule = typename std::conditional<
-      std::is_same_v<ElementA_, cutlass::float_e4m3_t>,
-      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
-      cutlass::gemm::KernelTmaWarpSpecialized>::type;
-
-  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
-  using ProblemShape = Shape<int, int, int, int>;
-
-  using CollectiveEpilogue =
-      typename cutlass::epilogue::collective::CollectiveBuilder<
-          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
-          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
-          AlignmentC, ElementC, LayoutTagC, AlignmentC,
-          EpilogueSchedule>::CollectiveOp;
-
-  using CollectiveMainloop =
-      typename cutlass::gemm::collective::CollectiveBuilder<
-          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
-          LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
-          ElementAccumulator, TileShape, ClusterShape,
-          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
-              sizeof(typename CollectiveEpilogue::SharedStorage))>,
-          KernelSchedule>::CollectiveOp;
-
-  using GemmKernel =
-      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
-                                           CollectiveEpilogue>;
-
-  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-
-  using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
-  using StrideE = StrideA;
-
-  using StrideA = Stride<int64_t, Int<1>, int64_t>;
-
-  // The n (=1) dimension does not matter for the compressor
-  typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
-
-  using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
-  using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
-
-  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
-  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
-
-  // Offline compressor kernel
-  using CompressorUtility =
-      cutlass::transform::kernel::StructuredSparseCompressorUtility<
-          ProblemShape, ElementA, LayoutTagA, SparseConfig>;
-
-  using CompressorKernel =
-      cutlass::transform::kernel::StructuredSparseCompressor<
-          ProblemShape, ElementA, LayoutTagA, SparseConfig,
-          cutlass::arch::Sm90>;
-
-  using Compressor =
-      cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
-
-  auto [M, N, K, L] = prob_shape;
-
-  StrideA stride_A;
-  stride_A =
-      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
-
-  CompressorUtility compressor_utility(prob_shape, stride_A);
-
-  int ME = compressor_utility.get_metadata_m_physical();
-  int KE = compressor_utility.get_metadata_k_physical();
-  int KC = compressor_utility.get_tensorA_k_physical();
-
-  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
-
-  auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
-  auto a_meta_ptr = static_cast<ElementE*>(
-      a_meta.data_ptr());
-
-  cutlass::KernelHardwareInfo hw_info;
-  hw_info.device_id = 0;
-  hw_info.sm_count =
-      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
-          hw_info.device_id);
-  typename Compressor::Arguments arguments{
-      prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
-
-  Compressor compressor_op;
-  size_t workspace_size = Compressor::get_workspace_size(arguments);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-
-  CUTLASS_CHECK(compressor_op.can_implement(arguments));
-  CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
-  CUTLASS_CHECK(compressor_op.run());
-  CUDA_CHECK(cudaDeviceSynchronize());
-
-  return true;
-}
-
-bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
-                                  torch::Tensor const& a) {
-  if (a.dtype() == torch::kBFloat16) {
-    return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
-                                                               a);
-  } else if (a.dtype() == torch::kFloat16) {
-    return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
-  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
-    return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
-                                                                 a);
-  } else if (a.dtype() == torch::kInt8) {
-    return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
-  }
-  return false;
-}
-#endif
diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cuh b/csrc/sparse/cutlass/sparse_compressor_c3x.cuh
new file mode 100644
index 000000000000..2cc235f3a68a
--- /dev/null
+++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cuh
@@ -0,0 +1,90 @@
+#pragma once
+
+// clang-format will break include orders
+// clang-format off
+#include <cudaTypedefs.h>
+
+#if defined CUDA_VERSION && CUDA_VERSION >= 12020
+#include "sparse_scaled_mm_c3x.cuh"
+
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/transform/device/transform_universal_adapter.hpp"
+#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+
+// clang-format on
+
+using namespace cute;
+using namespace vllm;
+
+using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
+/// Make A structured sparse by replacing elements with 0 and compress it
+template <typename Gemm>
+CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
+  // Checks for conformality
+  TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
+              a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
+  TORCH_CHECK(a.dim() == 2)
+  // Check for strides and alignment
+  TORCH_CHECK(a.stride(0) % 4 == 0)  // Required for semi-structured sparsity
+  TORCH_CHECK(a.stride(1) == 1)
+
+  using GemmKernel = typename Gemm::KernelType;
+  using ElementA = typename Gemm::ElementAB;
+  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
+
+  int m = a.size(0);
+  int k = a.size(1);
+  using ProblemShape = typename GemmKernel::ProblemShape;
+  ProblemShape prob_shape{m, 1, k, 1};
+
+  int64_t lda = a.stride(0);
+  using StrideA = Stride<int64_t, Int<1>, int64_t>;
+  StrideA a_stride{lda, Int<1>{}, 0};
+
+  using CompressorUtility = typename Gemm::CompressorUtility;
+  CompressorUtility compressor_utility(prob_shape, a_stride);
+
+  // Allocate buffers for the metadata E and the compressed matrix A
+  int ME = compressor_utility.get_metadata_m_physical();
+  int KE = compressor_utility.get_metadata_k_physical();
+  int MC = compressor_utility.get_tensorA_m_physical();
+  int KC = compressor_utility.get_tensorA_k_physical();
+
+  auto const a_meta_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto const a_nzs_options =
+      torch::TensorOptions().dtype(a.dtype()).device(a.device());
+
+  auto a_meta = torch::zeros({ME, KE}, a_meta_options);
+  auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
+
+  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
+  auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
+  auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
+
+  cutlass::KernelHardwareInfo hw_info;
+  hw_info.device_id = a.device().index();
+  hw_info.sm_count =
+      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+          hw_info.device_id);
+
+  using Compressor = typename Gemm::Compressor;
+  typename Compressor::Arguments arguments{
+      prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
+
+  Compressor compressor_op;
+  size_t workspace_size = Compressor::get_workspace_size(arguments);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  CUTLASS_CHECK(compressor_op.can_implement(arguments));
+  CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
+  CUTLASS_CHECK(compressor_op.run());
+  CUDA_CHECK(cudaDeviceSynchronize());
+
+  return {a_meta, a_nzs};
+}
+
+#endif
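As a rough model of the buffer shapes involved: for an m x k row-major operand with 2:4 structured sparsity, a_nzs keeps k/2 values per row and a_meta packs four 2-bit selectors per byte, i.e. k/8 metadata bytes per row -- the same relationship enforced by the checks in sparse_compressor_entry.cu (deleted below): a_nzs.size(1) * 2 == a.size(1) and a_meta.size(1) * 2 * 4 == a.size(1). The ME/KE/MC/KC values returned by the CUTLASS CompressorUtility may be padded beyond these logical sizes. A simplified CPU-only sketch of the logical bookkeeping (not of CUTLASS's actual interleaved metadata layout):

#include <cstdio>

// Logical (unpadded) 2:4 compression sizes for an m x k operand. The real code
// queries get_tensorA_*_physical() / get_metadata_*_physical(), which may
// round these up to satisfy the kernel's layout requirements.
struct SparseSizes {
  int nzs_cols;    // non-zero values kept per row: k / 2
  int meta_bytes;  // metadata bytes per row: four 2-bit indices per byte -> k / 8
};

SparseSizes logical_24_sizes(int k) { return {k / 2, k / 8}; }

int main() {
  int const m = 16, k = 64;
  SparseSizes s = logical_24_sizes(k);
  std::printf("dense A : %d x %d values\n", m, k);
  std::printf("a_nzs   : %d x %d values\n", m, s.nzs_cols);
  std::printf("a_meta  : %d x %d bytes (uint8)\n", m, s.meta_bytes);
  return 0;
}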
"cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +// clang-format on + +using namespace cute; +using namespace vllm; + +using CompressorResult = std::tuple; +/// Make A structured sparse by replacing elements with 0 and compress it +template +CompressorResult cutlass_sparse_compress(torch::Tensor const& a) { + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn || + a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16); + TORCH_CHECK(a.dim() == 2) + // Check for strides and alignment + TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity + TORCH_CHECK(a.stride(1) == 1) + + using GemmKernel = typename Gemm::KernelType; + using ElementA = typename Gemm::ElementAB; + using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; + + int m = a.size(0); + int k = a.size(1); + using ProblemShape = typename GemmKernel::ProblemShape; + ProblemShape prob_shape{m, 1, k, 1}; + + int64_t lda = a.stride(0); + using StrideA = Stride, int64_t>; + StrideA a_stride{lda, Int<1>{}, 0}; + + using CompressorUtility = typename Gemm::CompressorUtility; + CompressorUtility compressor_utility(prob_shape, a_stride); + + // Allocate buffers for the metadata E and the compressed matrix A + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int MC = compressor_utility.get_tensorA_m_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + auto const a_meta_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto const a_nzs_options = + torch::TensorOptions().dtype(a.dtype()).device(a.device()); + + auto a_meta = torch::zeros({ME, KE}, a_meta_options); + auto a_nzs = torch::zeros({MC, KC}, a_nzs_options); + + auto a_ptr = static_cast(a.data_ptr()); + auto a_nzs_ptr = static_cast(a_nzs.data_ptr()); + auto a_meta_ptr = static_cast(a_meta.data_ptr()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = a.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + using Compressor = typename Gemm::Compressor; + typename Compressor::Arguments arguments{ + prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return {a_meta, a_nzs}; +} + +#endif diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu deleted file mode 100644 index 3401761c1b70..000000000000 --- a/csrc/sparse/cutlass/sparse_compressor_entry.cu +++ /dev/null @@ -1,42 +0,0 @@ -#include - -#include -#include - -#include "cutlass_extensions/common.hpp" - -#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X -bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, - torch::Tensor const& a); -#endif - -bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta, - torch::Tensor const& a) { - // Checks for conformality - 
TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2); - TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) && - a_nzs.size(1) * 2 == a.size(1) && - a_meta.size(1) * 2 * 4 == a.size(1)); - // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4 - - // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 && - a_meta.stride(1) == 1); // Row-major - TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression - - at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); - int32_t version_num = get_sm_version_num(); - - // Guard against compilation issues for sm90 kernels -#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X - if (version_num >= 90) { - return cutlass_sparse_compress_sm90(a_nzs, a_meta, a); - } -#endif - - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled cutlass_scaled_sparse_mm for a compute capability less than " - "CUDA device capability: ", - version_num); -} diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu index 5a1879787c32..3dcaa6373f11 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -9,17 +9,30 @@ using namespace cute; using namespace vllm; +struct GemmCallerTraits { + using return_type = void; + + template + static return_type invoke(Args&&... args) { + return cutlass_sparse_gemm_caller(std::forward(args)...); + } +}; + +struct GemmCompressorTraits { + using return_type = CompressorResult; + + template + static return_type invoke(Args&&... args) { + return cutlass_sparse_compress(std::forward(args)...); + } +}; + template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& bt_nzs, - torch::Tensor const& bt_meta, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); - TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn); + typename DispatchFunc, typename... Args> +typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch( + uint32_t m, uint32_t n, Args&&... 
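GemmCallerTraits and GemmCompressorTraits let one set of shape heuristics drive two different operations: the dispatch templates below choose a gemm configuration and forward to DispatchFunc::invoke, which either launches the GEMM (returning void) or runs the compressor (returning CompressorResult). A self-contained sketch of that dispatch-traits pattern with stand-in types -- none of these names come from the patch:

#include <cstdio>
#include <string>
#include <utility>

// Stand-ins for the CUTLASS gemm config types picked by the heuristic.
struct ConfigSmallM { static constexpr const char* name = "small-M config"; };
struct ConfigDefault { static constexpr const char* name = "default config"; };

// One traits type per operation; both expose the same invoke<Config>() shape.
struct RunTraits {
  using return_type = void;
  template <typename Config, typename... Args>
  static return_type invoke(Args&&...) {
    std::printf("running GEMM with %s\n", Config::name);
  }
};

struct DescribeTraits {
  using return_type = std::string;
  template <typename Config, typename... Args>
  static return_type invoke(Args&&...) {
    return std::string("would use ") + Config::name;
  }
};

// A single heuristic serves both operations via the traits parameter.
template <typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type dispatch(int m, Args&&... args) {
  if (m <= 64) {
    return DispatchFunc::template invoke<ConfigSmallM>(std::forward<Args>(args)...);
  }
  return DispatchFunc::template invoke<ConfigDefault>(std::forward<Args>(args)...);
}

int main() {
  dispatch<RunTraits>(32);                                     // void path
  std::printf("%s\n", dispatch<DescribeTraits>(256).c_str());  // value-returning path
  return 0;
}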
@@ -49,122 +62,87 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   using Cutlass3xGemm8 =
       typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
 
-  uint32_t const n = bt_nzs.size(0);
-  uint32_t const m = a.size(0);  // Batch size
   uint32_t const mp2 =
       std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
 
   if (mp2 <= 64) {
     if (n == 28672) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     } else if (n == 4096 || n == 6144) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     }
   } else if (mp2 <= 128) {
     if (n == 4096) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     } else if (n == 28672) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     } else if (n == 6144) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     }
   } else if (mp2 <= 256) {
     if (n == 4096) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     } else if (n == 28672) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     } else if (n == 6144) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     }
   } else {
     if (n == 6144 || n == 28672) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
    } else if (n == 4096) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
    }
  }
 
   // Otherwise the default heuristic
   if (mp2 <= 64) {
     // n in [1, 64]
-    return cutlass_sparse_gemm_caller(
-        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+    return DispatchFunc::template invoke(
+        std::forward<Args>(args)...);
   } else if (mp2 <= 128) {
     // n in (64, 128]
-    return cutlass_sparse_gemm_caller(
-        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+    return DispatchFunc::template invoke(
+        std::forward<Args>(args)...);
   } else if (mp2 <= 256) {
     // n in (128, 256]
-    return cutlass_sparse_gemm_caller(
-        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+    return DispatchFunc::template invoke(
+        std::forward<Args>(args)...);
   } else {
     // n in (256, inf)
-    return cutlass_sparse_gemm_caller(
-        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+    return DispatchFunc::template invoke(
+        std::forward<Args>(args)...);
   }
 }
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
-void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& bt_nzs,
-                                     torch::Tensor const& bt_meta,
-                                     EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, cutlass::half_t>());
-  TORCH_CHECK(a.dtype() == torch::kFloat16);
-  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
-  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
-
-  using Cutlass3xGemmDefault =
-      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
-
-  // m in (128, inf)
-  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
-      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
-}
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
-void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& bt_nzs,
-                                     torch::Tensor const& bt_meta,
-                                     EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, cutlass::bfloat16_t>());
-  TORCH_CHECK(a.dtype() == torch::kBFloat16);
-  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
-  TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
-
+          typename DispatchFunc, typename... Args>
+typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch(
+    uint32_t m, uint32_t n, Args&&... args) {
   using Cutlass3xGemmDefault =
       typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
 
-  // m in (128, inf)
-  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
-      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+  return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
+      std::forward<Args>(args)...);
 }
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue,
-          typename... EpilogueArgs>
-void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& bt_nzs,
-                                     torch::Tensor const& bt_meta,
-                                     EpilogueArgs&&... args) {
-  static_assert(std::is_same<InType, int8_t>());
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
-  TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
+          typename DispatchFunc, typename... Args>
+typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
+    uint32_t m, uint32_t n, Args&&... args) {
+  static_assert(std::is_same_v<InType, int8_t>);
 
   using Cutlass3xGemmDefault =
       typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
@@ -179,37 +157,35 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
       typename sm90_int8_config_M32_NSmall<InType, OutType, Epilogue>::Cutlass3xGemm;
 
-  uint32_t const n = out.size(1);
   bool const is_small_n = n < 8192;
-
-  uint32_t const m = a.size(0);
   uint32_t const mp2 =
       std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
 
   if (mp2 <= 32) {
     // m in [1, 32]
     if (is_small_n) {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     } else {
-      return cutlass_sparse_gemm_caller(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+      return DispatchFunc::template invoke(
+          std::forward<Args>(args)...);
     }
   } else if (mp2 <= 64) {
     // m in (32, 64]
-    return cutlass_sparse_gemm_caller(
-        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+    return DispatchFunc::template invoke(
+        std::forward<Args>(args)...);
   } else if (mp2 <= 128) {
     // m in (64, 128]
-    return cutlass_sparse_gemm_caller(
-        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+    return DispatchFunc::template invoke(
+        std::forward<Args>(args)...);
   } else {
     // m in (128, inf)
-    return cutlass_sparse_gemm_caller(
-        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
+    return DispatchFunc::template invoke(
+        std::forward<Args>(args)...);
   }
 }
 
+// Dispatch to GEMM implementations based on element types
 template