
Commit 4007cd8

[Misc][Kernel]: Add GPTQAllSpark Quantization
Signed-off-by: wyj371990 <[email protected]>
1 parent: eb24dc4

16 files changed: +2053 −9 lines


CMakeLists.txt

(file mode changed: 100755 → 100644)
Lines changed: 16 additions & 0 deletions
@@ -298,6 +298,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
                    " in CUDA target architectures")
   endif()
 
+  # Only build AllSpark kernels if we are building for at least some compatible archs.
+  cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
+  if (ALLSPARK_ARCHS)
+    set(ALLSPARK_SRCS
+       "csrc/quantization/gptq_allspark/allspark_repack.cu"
+       "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${ALLSPARK_SRCS}"
+      CUDA_ARCHS "${ALLSPARK_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
+    message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
+  else()
+    message(STATUS "Not building AllSpark kernels as no compatible archs found"
+                   " in CUDA target architectures")
+  endif()
+
   # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
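Note: the gate above compiles the AllSpark kernels only when the build targets SM 8.0–8.9 (Ampere/Ada). A minimal runtime counterpart for checking whether the local GPU falls in that window, purely illustrative and not part of this commit (the helper name is made up):

import torch

def allspark_arch_supported(device_index: int = 0) -> bool:
    # Mirrors the CMake gate: AllSpark kernels are built for 8.0 <= SM < 9.0.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device_index)
    sm_version = major * 10 + minor
    return 80 <= sm_version < 90

print("AllSpark-compatible GPU:", allspark_arch_supported())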

benchmarks/kernels/benchmark_marlin.py

Lines changed: 47 additions & 2 deletions
@@ -10,6 +10,8 @@
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
+from vllm.model_executor.layers.quantization.utils.allspark_utils import (
+    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
@@ -18,12 +20,12 @@
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, gptq_quantize_weights, sort_weights)
+    gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
 from vllm.scalar_type import ScalarType
 from vllm.utils import FlexibleArgumentParser
 
 DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
-DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
+DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
 
 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
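The new import brings in ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, which is later handed to the GEMM as CUBLAS_M_THRESHOLD. The name suggests an M-size (batch) cutover between the fused W8A16 kernel and a cuBLAS path; the sketch below only illustrates that reading and is an assumption, not the kernel's actual dispatch code:

def choose_w8a16_backend(m: int, cublas_m_threshold: int) -> str:
    # Assumption: small-M (decode-like) shapes use the fused W8A16 kernel,
    # while large-M (prefill-like) shapes fall back to cuBLAS.
    return "allspark_w8a16" if m < cublas_m_threshold else "cublas"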
@@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
                                           GPTQ_MARLIN_24_MAX_PARALLEL)
         marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
 
+    # AllSpark W8A16 quant
+    as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
+                         and group_size == -1 and not act_order and is_k_full)
+    if as_supported_case:
+        properties = torch.cuda.get_device_properties(b.device.index)
+        sm_count = properties.multi_processor_count
+        sm_version = properties.major * 10 + properties.minor
+
+        supported_arch = (sm_version >= 80 and sm_version < 90)
+        as_supported_case = as_supported_case and supported_arch
+        if supported_arch:
+            has_zp = False
+            w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
+                                                has_zp)
+            qw = qw.to(torch.uint8)
+
+            qw_reorder, s_reorder, zp_reorder = \
+                ops.allspark_repack_weight(
+                    qw, s, zp, has_zp)
+            CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
+
     globals = {
         # Gen params
         "quant_type": quant_type,
@@ -109,10 +132,21 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         # GPTQ params
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
+        # AllSpark W8A16 params
+        "qw_reorder": qw_reorder if as_supported_case else None,
+        "s_reorder": s_reorder if as_supported_case else None,
+        "zp_reorder": zp_reorder if as_supported_case else None,
+        "sm_count": sm_count if as_supported_case else None,
+        "sm_version": sm_version if as_supported_case else None,
+        "CUBLAS_M_THRESHOLD":
+        CUBLAS_M_THRESHOLD if as_supported_case else None,
+        "weight_name_pattern":
+        f'model.layers.k{size_k}.m{size_m}.n{size_n}.qweight',
         # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
         "gptq_marlin_repack": ops.gptq_marlin_repack,
+        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
     }
 
     min_run_time = 1
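The weight_name_pattern entry is a plain f-string over the problem sizes; for example (sizes are illustrative):

size_k, size_m, size_n = 4096, 16, 11008
pattern = f'model.layers.k{size_k}.m{size_m}.n{size_n}.qweight'
# -> 'model.layers.k4096.m16.n11008.qweight'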
@@ -172,6 +206,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
                 description="gptq_marlin_repack",
             ).blocked_autorange(min_run_time=min_run_time))
 
+    if as_supported_case:
+        results.append(
+            benchmark.Timer(
+                stmt=
+                "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True, weight_name_pattern)",  # noqa: E501
+                globals=globals,
+                label=label,
+                sub_label=sub_label,
+                description="allspark_w8a16_gemm_fp32",
+            ).blocked_autorange(min_run_time=min_run_time))
+
 
 def main(args):
     print("Benchmarking models:")
