10 changes: 5 additions & 5 deletions CMakeLists.txt
@@ -228,7 +228,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
 
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
-  set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
+  # Please keep this in sync with FetchContent_Declare line below.
+  set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
 
   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -245,6 +246,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_Declare(
     cutlass
     GIT_REPOSITORY https://github.com/nvidia/cutlass.git
+    # Please keep this in sync with CUTLASS_REVISION line above.
     GIT_TAG v3.7.0
     GIT_PROGRESS TRUE
 
@@ -265,7 +267,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-    "csrc/sparse/cutlass/sparse_compressor_entry.cu"
     "csrc/cutlass_extensions/common.cpp")
 
   set_gencode_flags_for_srcs(
@@ -358,8 +359,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
   # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
-             "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+    set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -458,7 +458,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
-  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
   USE_SABI 3
   WITH_SOABI)
 
69 changes: 68 additions & 1 deletion csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -16,6 +16,30 @@ namespace vllm::c3x {
 
 using namespace cute;
 
+template <typename T>
+struct identity {
+  CUTLASS_HOST_DEVICE
+  T operator()(T lhs) const { return lhs; }
+};
+
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct TrivialEpilogue {
+ private:
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+  using Compute = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  template <typename... Args>
+  static ArgumentType prepare_args(Args... args) {
+    return {};
+  }
+};
+
 /*
  * This class provides the common load descriptors for the
  * ScaledEpilogue[...] classes
@@ -174,6 +198,49 @@ struct ScaledEpilogueBias
   }
 };
 
+/*
+ * This epilogue performs the same operation as ScaledEpilogueBias, but the
+ * bias is a column vector instead of a row vector. Useful e.g. if we are
+ * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
+ */
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueColumnBias
+    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template ColLoad<ElementD>;
+
+  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
+
+  using ArgumentType = typename EVTCompute::Arguments;
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args, bias_args};
+  }
+};
+
 /*
  * This epilogue directly supports per-tensor azp in int32 form.
  * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
@@ -314,4 +381,4 @@ struct ScaledEpilogueBiasAzpToken
   }
 };
 
-}; // namespace vllm::c3x
+}; // namespace vllm::c3x
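For reference, here is a scalar sketch of what the two new epilogues compute. This is not part of the PR; the broadcast rules are read off the EVT node types above (ScaleA is a column-or-scalar load, so one scale per row of D; ScaleB is a row-or-scalar load, so one per column; Bias is a column load, so one bias per row), and the function name is made up for illustration.

// TrivialEpilogue: D = Acc (the accumulator is passed through unchanged).
// ScaledEpilogueColumnBias: D[m][n] = a_scale[m] * (b_scale[n] * acc[m][n]) + bias[m]
#include <vector>

void scaled_epilogue_column_bias_ref(
    int M, int N,
    std::vector<float> const& acc,      // M x N accumulator, row-major
    std::vector<float> const& a_scale,  // size M, or size 1 for a scalar scale
    std::vector<float> const& b_scale,  // size N, or size 1 for a scalar scale
    std::vector<float> const& bias,     // size M: a column vector
    std::vector<float>& d) {            // M x N output, row-major
  for (int m = 0; m < M; ++m) {
    float const sa = a_scale.size() == 1 ? a_scale[0] : a_scale[m];
    for (int n = 0; n < N; ++n) {
      float const sb = b_scale.size() == 1 ? b_scale[0] : b_scale[n];
      d[m * N + n] = sa * (sb * acc[m * N + n]) + bias[m];
    }
  }
}

Compared with ScaledEpilogueBias, only the bias broadcast changes (bias[m] instead of bias[n]); that is exactly what a GEMM evaluated in transposed form, C^T += B^T A^T, needs, since rows and columns swap roles.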
3 changes: 1 addition & 2 deletions csrc/ops.h
@@ -176,8 +176,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b_scales,
                               std::optional<torch::Tensor> const& bias);
 
-bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
-                                   torch::Tensor& e, torch::Tensor const& a);
+std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
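A hedged caller-side sketch of the signature change above, not taken from the PR: it assumes the returned vector holds the former out-parameters {a_compressed, e} in that order, with e being the 2:4 sparsity metadata.

#include <torch/torch.h>
#include <vector>

// Declaration from csrc/ops.h after this change.
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);

void compress_example(torch::Tensor const& a) {
  // Old API (removed): the caller allocated the outputs and passed them in:
  //   bool ok = cutlass_sparse_compress_entry(a_compressed, e, a);
  // New API: the op allocates and returns its outputs.
  auto outputs = cutlass_sparse_compress(a);
  torch::Tensor a_compressed = outputs[0];  // assumed ordering
  torch::Tensor e = outputs[1];             // assumed ordering (metadata)
  (void)a_compressed;
  (void)e;
}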
13 changes: 9 additions & 4 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
@@ -53,12 +53,17 @@ struct cutlass_3x_gemm {
 
   using EVTCompute = typename Epilogue::EVTCompute;
 
+  // These are the minimum alignments needed for the kernels to compile
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD = 4;
+
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
           ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
-          EpilogueSchedule, EVTCompute>::CollectiveOp;
+          ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
+          AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
 
   static constexpr size_t CEStorageSize =
       sizeof(typename CollectiveEpilogue::SharedStorage);
@@ -69,8 +74,8 @@ struct cutlass_3x_gemm {
   using CollectiveMainloop =
       typename cutlass::gemm::collective::CollectiveBuilder<
           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
-          ElementAB, cutlass::layout::RowMajor, 16,
-          ElementAB, cutlass::layout::ColumnMajor, 16,
+          ElementAB, cutlass::layout::RowMajor, AlignmentAB,
+          ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
           ElementAcc, TileShape, ClusterShape,
           Stages,
           KernelSchedule>::CollectiveOp;
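The alignment constants above are counted in elements, not bytes: 128 / sizeof_bits<ElementAB> is how many A/B elements fit in one 128-bit access, so the derived value matches the old hard-coded 16 only for 8-bit operands and becomes the smaller minimum for wider element types, while AlignmentCD keeps the previous constant 4 under a name. A minimal sketch of the arithmetic, independent of CUTLASS (the element widths named in the comments are assumptions about how the kernels are instantiated):

// Minimum alignment, in elements, for a 128-bit (16-byte) vectorized access.
constexpr int alignment_for(int element_bits) { return 128 / element_bits; }

static_assert(alignment_for(8) == 16,   // e.g. int8 / fp8 A and B operands
              "8-bit operands: 16-element alignment");
static_assert(alignment_for(16) == 8,   // e.g. fp16 / bf16 operands
              "16-bit operands: 8-element alignment");
static_assert(alignment_for(32) == 4,   // e.g. a float C/D operand
              "32-bit operands: 4-element alignment");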
11 changes: 8 additions & 3 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -103,14 +103,19 @@ struct cutlass_2x_gemm {
 
   using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
 
+  // These are the minimum alignments needed for the kernels to compile
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD = 4;
+
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
   using KernelType =
       ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-          ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
-          ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
-          float, cutlass::layout::RowMajor, 4,
+          ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
+          ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
+          float, cutlass::layout::RowMajor, AlignmentCD,
           ElementAcc, float, cutlass::arch::OpClassTensorOp,
           Arch,
           TileShape, WarpShape, InstructionShape,
165 changes: 0 additions & 165 deletions csrc/sparse/cutlass/sparse_compressor_c3x.cu

This file was deleted.
