32 changes: 24 additions & 8 deletions CMakeLists.txt
@@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")

# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
@@ -297,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)
set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
@@ -335,7 +335,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
@@ -369,7 +369,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
@@ -394,7 +394,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# 2:4 Sparse Kernels

# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
# require CUDA 12.2 or later (and only work on Hopper and Blackwell).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
@@ -419,8 +419,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
)
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
@@ -433,6 +432,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(FP4_ARCHS)
endif()

# FP8 Blackwell Archs
cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${BLACKWELL_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
else()
# clear BLACKWELL_ARCHS
set(BLACKWELL_ARCHS)
endif()

#
# Machete kernels

@@ -514,6 +529,7 @@ define_gpu_extension_target(
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

@@ -537,7 +553,7 @@ set_gencode_flags_for_srcs(
CUDA_ARCHS "${CUDA_ARCHS}")

if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
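The arch-list additions above (10.0, 10.1, 12.0) only control which SASS/PTX variants are compiled into the extension. Picking the new SM100 FP8 path at runtime still requires a device-capability check. The following is a minimal, illustrative sketch of such a check, not vLLM's actual dispatch code; the function name is a placeholder.

#include <cuda_runtime.h>

// Returns true for devices whose compute capability matches the new
// "10.0;10.1;12.0" entries in CUDA_SUPPORTED_ARCHS (Blackwell-class GPUs).
static bool device_is_blackwell(int device_id) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;
  }
  return prop.major == 10 || prop.major == 12;
}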
48 changes: 23 additions & 25 deletions csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -22,7 +22,7 @@ struct identity {
T operator()(T lhs) const { return lhs; }
};

template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
template <typename ElementAcc, typename ElementD, typename TileShape>
struct TrivialEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<1>, Int<0>, Int<0>>>;
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;

template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
Stride<Int<0>, Int<1>, Int<0>>>;
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;

// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
128 / sizeof_bits_v<T>, EnableNullPtr>;

// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
128 / sizeof_bits_v<T>, EnableNullPtr>;

// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -160,11 +158,11 @@ struct ScaledEpilogue
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueColumnBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -314,11 +312,11 @@ struct ScaledEpilogueBiasAzp
*
* This epilogue also supports bias, which remains per-channel.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
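With this change the ScaledEpilogue* templates are parameterized directly on the CTA tile shape rather than on a full EpilogueDescriptor, so a caller only has to supply the tile it is building for. A minimal sketch of the new instantiation, assuming an arbitrary 128x128x128 tile (not a shape taken from any actual vLLM kernel configuration):

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

// Epilogue computing D = scale_a * scale_b * Acc with fp32 accumulation and
// bf16 output. Previously the third argument had to be an EpilogueDescriptor,
// from which TileShape was extracted internally.
using ExampleTile = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ExampleEpilogue =
    vllm::c3x::ScaledEpilogue<float, cutlass::bfloat16_t, ExampleTile>;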
32 changes: 20 additions & 12 deletions csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
@@ -16,6 +16,7 @@
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"

#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
@@ -64,33 +65,40 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementC = typename Gemm::ElementC;
using ElementD = typename Gemm::ElementD;
using GemmKernel = typename Gemm::GemmKernel;

int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);

using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;

StrideA a_stride{lda, cute::Int<1>{}, 0};
StrideB b_stride{ldb, cute::Int<1>{}, 0};
StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = StrideC;
using StrideAux = StrideC;

typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
auto [M, N, K, L] = prob_shape;

StrideA a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
StrideB b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
StrideC c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
StrideD d_stride =
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
StrideAux aux_stride = d_stride;

auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
b_stride};

auto c_ptr = static_cast<ElementD*>(out.data_ptr());
// auto d_ptr = static_cast<ElementC*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, c_stride};
c_ptr, c_stride, c_ptr, d_stride};

cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);
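The caller now derives operand strides from the kernel's own stride types and the problem shape via cutlass::make_cute_packed_stride, instead of hand-building them from the torch tensor strides; note that packed strides assume densely packed operands. A small sketch of what this produces for a row-major (M, K) operand with batch count L; the helper function name is illustrative:

#include "cutlass/util/packed_stride.hpp"

using ExampleStrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;  // row-major

// For shape (M, K, L) the runtime modes are filled in as: leading dimension
// = K, batch stride = M * K (0 when L == 1); the compile-time _1 is kept.
ExampleStrideA make_example_a_stride(int M, int K, int L) {
  return cutlass::make_cute_packed_stride(ExampleStrideA{},
                                          cute::make_shape(M, K, L));
}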
68 changes: 62 additions & 6 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
@@ -40,12 +40,7 @@ struct cutlass_3x_gemm {
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;

using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;

using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void;
@@ -88,4 +83,65 @@ struct cutlass_3x_gemm {
struct GemmKernel : public KernelType {};
};

template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm_sm100 {
using ElementAB = ElementAB_;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementAB>::value;

using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementAB>::value;

using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<ElementD_>::value;

using ElementD = ElementD_;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = AlignmentC;

using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

// MMA type
using ElementAccumulator = float;

// Epilogue types
using ElementBias = cutlass::half_t;
using ElementCompute = float;
using ElementAux = ElementD;
using LayoutAux = LayoutD;
using ElementAmax = float;

using EVTCompute = typename Epilogue::EVTCompute;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
EVTCompute>::CollectiveOp;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};

} // namespace vllm
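A hedged sketch of how the new cutlass_3x_gemm_sm100 helper might be specialized into a concrete FP8 kernel type. The tile shape, cluster shape, and schedules below are placeholders chosen only for illustration; the configurations actually used live in scaled_mm_sm100_fp8_dispatch.cuh, which is not part of this diff.

#include "scaled_mm.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

using ExampleTile    = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ExampleCluster = cute::Shape<cute::_1, cute::_1, cute::_1>;

// FP8 e4m3 inputs, bf16 output, per-tensor/per-token scaling epilogue.
using ExampleSm100Fp8Gemm = vllm::cutlass_3x_gemm_sm100<
    cutlass::float_e4m3_t, cutlass::bfloat16_t, vllm::c3x::ScaledEpilogue,
    ExampleTile, ExampleCluster,
    cutlass::gemm::collective::KernelScheduleAuto,
    cutlass::epilogue::collective::EpilogueScheduleAuto>;

// The launchable kernel type is then ExampleSm100Fp8Gemm::GemmKernel.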
6 changes: 6 additions & 0 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
@@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);

void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);

} // namespace vllm
24 changes: 24 additions & 0 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}

} // namespace vllm
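A hedged sketch of invoking the new entry point from host code. The shapes, dtypes, and scale layouts below are illustrative assumptions; the real caller is vLLM's scaled_mm dispatch layer, which is not shown in this diff.

#include <optional>
#include <torch/all.h>
#include "scaled_mm_kernels.hpp"

void example_sm100_fp8_call() {
  const int64_t M = 16, N = 4096, K = 4096;
  auto opts = torch::TensorOptions().device(torch::kCUDA);

  // Row-major FP8 activations and a column-major (K x N) FP8 weight matrix.
  torch::Tensor a =
      torch::empty({M, K}, opts.dtype(at::ScalarType::Float8_e4m3fn));
  torch::Tensor b =
      torch::empty({N, K}, opts.dtype(at::ScalarType::Float8_e4m3fn)).t();
  torch::Tensor out = torch::empty({M, N}, opts.dtype(torch::kBFloat16));

  // Scales must be contiguous (see the TORCH_CHECK above); here per-token
  // for A and per-channel for B.
  torch::Tensor a_scales = torch::ones({M, 1}, opts.dtype(torch::kFloat32));
  torch::Tensor b_scales = torch::ones({1, N}, opts.dtype(torch::kFloat32));

  vllm::cutlass_scaled_mm_sm100_fp8(out, a, b, a_scales, b_scales,
                                    /*bias=*/std::nullopt);
}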