diff --git a/csrc/ops.h b/csrc/ops.h
index 9e2e977fa3c2..4cc11a7327d2 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -90,6 +90,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales);
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index 687f8efd8dc0..f4e582d780ad 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales);
 #endif
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
+  // CUTLASS FP8 kernels need at least
+  // CUDA 12.0 on SM90 systems (Hopper)
+  // CUDA 12.4 on SM89 systems (Lovelace)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
+  } else if (cuda_device_capability >= 89) {
+    return CUDA_VERSION >= 12040;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales) {
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 867bf438937c..dd0199629968 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -140,6 +140,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
+
+  // Check if cutlass scaled_mm is supported for CUDA devices of the given
+  // capability
+  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
+           &cutlass_scaled_mm_supports_fp8);
 #endif
 
   // Quantized GEMM for GPTQ.
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index ab2a67950bfe..87668ebb0ddf 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -212,6 +212,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # cutlass
+def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
 def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype]) -> torch.Tensor:
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e89fd65813c0..bbf3cde54782 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -20,19 +20,8 @@ def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
 
-    major, minor = torch.version.cuda.split(".")
-    version = int(major) * 10 + int(minor)
-
-    # CUTLASS FP8 kernels need at least
-    # CUDA 12.0 on SM90 systems (Hopper)
-    # CUDA 12.4 on SM89 systems (Lovelace)
-    gpu_is_supported = False
-    if capability >= 90:
-        gpu_is_supported = version > 120
-    elif capability >= 89:
-        gpu_is_supported = version > 124
-
-    return gpu_is_supported
+
+    return ops.cutlass_scaled_mm_supports_fp8(capability)
 
 
 class Fp8Config(QuantizationConfig):
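
For reference, a minimal usage sketch of the new binding from the Python side, mirroring the simplified cutlass_fp8_supported() in fp8.py above. It assumes a CUDA build of vLLM in which the _C extension was compiled with this change (otherwise torch.ops._C will not expose the op), and that a CUDA device is visible so get_device_capability() succeeds.

# Hedged sketch: assumes a CUDA-enabled vLLM build containing this PR.
import torch
from vllm import _custom_ops as ops

def cutlass_fp8_supported() -> bool:
    # Encode the (major, minor) compute capability as one integer, e.g. (8, 9) -> 89.
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    # The C++ side compares against the CUDA_VERSION the extension was built with:
    # SM90 (Hopper) needs CUDA >= 12.0, SM89 (Lovelace) needs CUDA >= 12.4,
    # otherwise it returns false.
    return ops.cutlass_scaled_mm_supports_fp8(capability)

if __name__ == "__main__":
    print("CUTLASS FP8 scaled_mm supported:", cutlass_fp8_supported())

Because the check now keys off the compile-time CUDA_VERSION of the extension rather than the runtime torch.version.cuda string, it reflects what the kernels were actually built against.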