Skip to content

Commit 00c11f4

Browse files
kushanamlulmer
authored andcommitted
add cutlass support for blackwell fp8 gemm (vllm-project#13798)
Signed-off-by: Louis Ulmer <[email protected]>
1 parent c744ee7 commit 00c11f4

File tree

11 files changed

+272
-65
lines changed

11 files changed

+272
-65
lines changed

CMakeLists.txt

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
3131
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
3232

3333
# Supported NVIDIA architectures.
34-
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
34+
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
3535

3636
# Supported AMD GPU architectures.
3737
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
@@ -297,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
297297
# Only build Marlin kernels if we are building for at least some compatible archs.
298298
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
299299
# are not supported by Machete yet.
300-
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
300+
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
301301
if (MARLIN_ARCHS)
302302
set(MARLIN_SRCS
303303
"csrc/quantization/fp8/fp8_marlin.cu"
@@ -335,7 +335,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
335335

336336
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
337337
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
338-
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
338+
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
339339
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
340340
set(SRCS
341341
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
@@ -369,7 +369,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
369369
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
370370
# kernels for the remaining archs that are not already built for 3x.
371371
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
372-
"7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
372+
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
373373
# subtract out the archs that are already built for 3x
374374
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
375375
if (SCALED_MM_2X_ARCHS)
@@ -394,7 +394,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
394394
# 2:4 Sparse Kernels
395395

396396
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
397-
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
397+
# require CUDA 12.2 or later (and only work on Hopper and Blackwell).
398398
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
399399
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
400400
set_gencode_flags_for_srcs(
@@ -419,8 +419,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
419419
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
420420
set(SRCS
421421
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
422-
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
423-
)
422+
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu")
424423
set_gencode_flags_for_srcs(
425424
SRCS "${SRCS}"
426425
CUDA_ARCHS "${FP4_ARCHS}")
@@ -433,6 +432,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
433432
set(FP4_ARCHS)
434433
endif()
435434

435+
# FP8 Blackwell Archs
436+
cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
437+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
438+
set(SRCS
439+
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
440+
)
441+
set_gencode_flags_for_srcs(
442+
SRCS "${SRCS}"
443+
CUDA_ARCHS "${BLACKWELL_ARCHS}")
444+
list(APPEND VLLM_EXT_SRC "${SRCS}")
445+
message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
446+
else()
447+
# clear BLACKWELL_ARCHS
448+
set(BLACKWELL_ARCHS)
449+
endif()
450+
436451
#
437452
# Machete kernels
438453

@@ -514,6 +529,7 @@ define_gpu_extension_target(
514529
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
515530
ARCHITECTURES ${VLLM_GPU_ARCHES}
516531
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
532+
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
517533
USE_SABI 3
518534
WITH_SOABI)
519535

@@ -537,7 +553,7 @@ set_gencode_flags_for_srcs(
537553
CUDA_ARCHS "${CUDA_ARCHS}")
538554

539555
if(VLLM_GPU_LANG STREQUAL "CUDA")
540-
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
556+
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
541557
if (MARLIN_MOE_ARCHS)
542558
set(MARLIN_MOE_SRC
543559
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"

csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct identity {
2222
T operator()(T lhs) const { return lhs; }
2323
};
2424

25-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
25+
template <typename ElementAcc, typename ElementD, typename TileShape>
2626
struct TrivialEpilogue {
2727
private:
2828
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
4444
* This class provides the common load descriptors for the
4545
* ScaledEpilogue[...] classes
4646
*/
47-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
47+
template <typename ElementAcc, typename ElementD, typename TileShape>
4848
struct ScaledEpilogueBase {
4949
protected:
5050
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
5151

5252
template <typename T>
5353
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
54-
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
55-
Stride<Int<1>, Int<0>, Int<0>>>;
54+
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
5655

5756
template <typename T>
5857
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
59-
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
60-
Stride<Int<0>, Int<1>, Int<0>>>;
58+
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
6159

6260
// Don't want to support nullptr by default
6361
template <typename T, bool EnableNullPtr = false>
6462
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
65-
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
66-
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
63+
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
64+
128 / sizeof_bits_v<T>, EnableNullPtr>;
6765

6866
// Don't want to support nullptr by default
6967
template <typename T, bool EnableNullPtr = false>
7068
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
71-
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
72-
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
69+
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
70+
128 / sizeof_bits_v<T>, EnableNullPtr>;
7371

7472
// This utility function constructs the arguments for the load descriptors
7573
// from a tensor. It can handle both row and column, as well as row/column or
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
116114
the A and B operands respectively. These scales may be either per-tensor or
117115
per row or column.
118116
*/
119-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
117+
template <typename ElementAcc, typename ElementD, typename TileShape>
120118
struct ScaledEpilogue
121-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
119+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
122120
private:
123-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
121+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
124122
using Accum = typename SUPER::Accum;
125123
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
126124
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -160,11 +158,11 @@ struct ScaledEpilogue
160158
* The bias tensor must be per-output channel.
161159
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
162160
*/
163-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
161+
template <typename ElementAcc, typename ElementD, typename TileShape>
164162
struct ScaledEpilogueBias
165-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
163+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
166164
private:
167-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
165+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
168166
using Accum = typename SUPER::Accum;
169167
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
170168
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
203201
* bias is a column vector instead of a row vector. Useful e.g. if we are
204202
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
205203
*/
206-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
204+
template <typename ElementAcc, typename ElementD, typename TileShape>
207205
struct ScaledEpilogueColumnBias
208-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
206+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
209207
private:
210-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
208+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
211209
using Accum = typename SUPER::Accum;
212210
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
213211
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
249247
*
250248
* This epilogue also supports bias, which remains per-channel.
251249
*/
252-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
250+
template <typename ElementAcc, typename ElementD, typename TileShape>
253251
struct ScaledEpilogueBiasAzp
254-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
252+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
255253
private:
256-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
254+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
257255
using Accum = typename SUPER::Accum;
258256
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
259257
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
@@ -314,11 +312,11 @@ struct ScaledEpilogueBiasAzp
314312
*
315313
* This epilogue also supports bias, which remains per-channel.
316314
*/
317-
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
315+
template <typename ElementAcc, typename ElementD, typename TileShape>
318316
struct ScaledEpilogueBiasAzpToken
319-
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
317+
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
320318
private:
321-
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
319+
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
322320
using Accum = typename SUPER::Accum;
323321
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
324322
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "cutlass/gemm/kernel/gemm_universal.hpp"
1717
#include "cutlass/epilogue/collective/collective_builder.hpp"
1818
#include "cutlass/gemm/collective/collective_builder.hpp"
19+
#include "cutlass/util/packed_stride.hpp"
1920

2021
#include "core/math.hpp"
2122
#include "cutlass_extensions/common.hpp"
@@ -64,33 +65,40 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
6465
torch::Tensor const& b,
6566
EpilogueArgs&&... epilogue_params) {
6667
using ElementAB = typename Gemm::ElementAB;
68+
using ElementC = typename Gemm::ElementC;
6769
using ElementD = typename Gemm::ElementD;
6870
using GemmKernel = typename Gemm::GemmKernel;
6971

70-
int64_t lda = a.stride(0);
71-
int64_t ldb = b.stride(1);
72-
int64_t ldc = out.stride(0);
73-
74-
using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
75-
using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
76-
using StrideC = typename Gemm::StrideC;
77-
78-
StrideA a_stride{lda, cute::Int<1>{}, 0};
79-
StrideB b_stride{ldb, cute::Int<1>{}, 0};
80-
StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
72+
using StrideA = typename Gemm::GemmKernel::StrideA;
73+
using StrideB = typename Gemm::GemmKernel::StrideB;
74+
using StrideC = typename Gemm::GemmKernel::StrideC;
75+
using StrideD = StrideC;
76+
using StrideAux = StrideC;
8177

8278
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
79+
auto [M, N, K, L] = prob_shape;
80+
81+
StrideA a_stride =
82+
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
83+
StrideB b_stride =
84+
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
85+
StrideC c_stride =
86+
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
87+
StrideD d_stride =
88+
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
89+
StrideAux aux_stride = d_stride;
8390

8491
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
8592
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
8693
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
8794
b_stride};
8895

8996
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
97+
// auto d_ptr = static_cast<ElementC*>(out.data_ptr());
9098
typename GemmKernel::EpilogueArguments epilogue_args{
9199
Gemm::Epilogue::prepare_args(
92100
std::forward<EpilogueArgs>(epilogue_params)...),
93-
c_ptr, c_stride, c_ptr, c_stride};
101+
c_ptr, c_stride, c_ptr, d_stride};
94102

95103
cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
96104
epilogue_args);

csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,7 @@ struct cutlass_3x_gemm {
4040
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
4141
float>::type;
4242

43-
using EpilogueDescriptor =
44-
cutlass::epilogue::collective::detail::EpilogueDescriptor<
45-
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
46-
ElementD, EpilogueSchedule>;
47-
48-
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
43+
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
4944

5045
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
5146
using ElementC = void;
@@ -88,4 +83,65 @@ struct cutlass_3x_gemm {
8883
struct GemmKernel : public KernelType {};
8984
};
9085

86+
template <typename ElementAB_, typename ElementD_,
87+
template <typename, typename, typename> typename Epilogue_,
88+
typename TileShape, typename ClusterShape, typename KernelSchedule,
89+
typename EpilogueSchedule>
90+
struct cutlass_3x_gemm_sm100 {
91+
using ElementAB = ElementAB_;
92+
using LayoutA = cutlass::layout::RowMajor;
93+
static constexpr int AlignmentA =
94+
128 / cutlass::sizeof_bits<ElementAB>::value;
95+
96+
using LayoutB = cutlass::layout::ColumnMajor;
97+
static constexpr int AlignmentB =
98+
128 / cutlass::sizeof_bits<ElementAB>::value;
99+
100+
using ElementC = void;
101+
using LayoutC = cutlass::layout::RowMajor;
102+
static constexpr int AlignmentC =
103+
128 / cutlass::sizeof_bits<ElementD_>::value;
104+
105+
using ElementD = ElementD_;
106+
using LayoutD = cutlass::layout::RowMajor;
107+
static constexpr int AlignmentD = AlignmentC;
108+
109+
using ElementAcc =
110+
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
111+
float>::type;
112+
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
113+
114+
// MMA type
115+
using ElementAccumulator = float;
116+
117+
// Epilogue types
118+
using ElementBias = cutlass::half_t;
119+
using ElementCompute = float;
120+
using ElementAux = ElementD;
121+
using LayoutAux = LayoutD;
122+
using ElementAmax = float;
123+
124+
using EVTCompute = typename Epilogue::EVTCompute;
125+
126+
using CollectiveEpilogue =
127+
typename cutlass::epilogue::collective::CollectiveBuilder<
128+
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
129+
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
130+
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
131+
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
132+
EVTCompute>::CollectiveOp;
133+
134+
using CollectiveMainloop =
135+
typename cutlass::gemm::collective::CollectiveBuilder<
136+
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
137+
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
138+
ElementAccumulator, TileShape, ClusterShape,
139+
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
140+
sizeof(typename CollectiveEpilogue::SharedStorage))>,
141+
KernelSchedule>::CollectiveOp;
142+
143+
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
144+
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
145+
};
146+
91147
} // namespace vllm

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
3030
torch::Tensor const& a_scales,
3131
torch::Tensor const& b_scales);
3232

33+
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
34+
torch::Tensor const& b,
35+
torch::Tensor const& a_scales,
36+
torch::Tensor const& b_scales,
37+
std::optional<torch::Tensor> const& bias);
38+
3339
} // namespace vllm
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "scaled_mm_kernels.hpp"
2+
#include "scaled_mm_sm100_fp8_dispatch.cuh"
3+
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
4+
5+
namespace vllm {
6+
7+
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
8+
torch::Tensor const& b,
9+
torch::Tensor const& a_scales,
10+
torch::Tensor const& b_scales,
11+
std::optional<torch::Tensor> const& bias) {
12+
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
13+
if (bias) {
14+
TORCH_CHECK(bias->dtype() == out.dtype(),
15+
"currently bias dtype must match output dtype ", out.dtype());
16+
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
17+
out, a, b, a_scales, b_scales, *bias);
18+
} else {
19+
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
20+
out, a, b, a_scales, b_scales);
21+
}
22+
}
23+
24+
} // namespace vllm

0 commit comments

Comments
 (0)