 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
+from vllm.model_executor.layers.quantization.utils.allspark_utils import (
+    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, gptq_quantize_weights, sort_weights)
+    gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
 from vllm.scalar_type import ScalarType
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
-DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
+DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
@@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         GPTQ_MARLIN_24_MAX_PARALLEL)
     marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

+    # AllSpark W8A16 quant
+    as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
+                         and group_size == -1 and not act_order and is_k_full)
+    if as_supported_case:
+        properties = torch.cuda.get_device_properties(b.device.index)
+        sm_count = properties.multi_processor_count
+        sm_version = properties.major * 10 + properties.minor
+
+        supported_arch = (sm_version >= 80 and sm_version < 90)
+        as_supported_case = as_supported_case and supported_arch
+        if supported_arch:
+            has_zp = False
+            w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
+                                                has_zp)
+            qw = qw.to(torch.uint8)
+
+            qw_reorder, s_reorder, zp_reorder = \
+                ops.allspark_repack_weight(
+                    qw, s, zp, has_zp)
+            CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
+
     globals = {
         # Gen params
         "quant_type": quant_type,
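The hunk above only exercises AllSpark when the quant type is supported, the weights are quantized per channel (group_size == -1), activation reordering is off, K is full, and the GPU is Ampere-class (SM 8.x). Collapsed into a single predicate, the gate reads roughly as the sketch below; the standalone helper and its name are illustrative, only the constants and conditions come from the diff.

import torch

from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_SUPPORTED_QUANT_TYPES)


def allspark_case_supported(quant_type, group_size, act_order, is_k_full,
                            device_index):
    # AllSpark W8A16 is benchmarked only for per-channel quantization,
    # without act_order, with a full K dimension, on SM 8.x (Ampere) GPUs.
    props = torch.cuda.get_device_properties(device_index)
    sm_version = props.major * 10 + props.minor
    return (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
            and group_size == -1 and not act_order and is_k_full
            and 80 <= sm_version < 90)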
@@ -109,10 +132,21 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         # GPTQ params
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
+        # AllSpark W8A16 params
+        "qw_reorder": qw_reorder if as_supported_case else None,
+        "s_reorder": s_reorder if as_supported_case else None,
+        "zp_reorder": zp_reorder if as_supported_case else None,
+        "sm_count": sm_count if as_supported_case else None,
+        "sm_version": sm_version if as_supported_case else None,
+        "CUBLAS_M_THRESHOLD":
+        CUBLAS_M_THRESHOLD if as_supported_case else None,
+        "weight_name_pattern":
+        f'model.layers.k{size_k}.m{size_m}.n{size_n}.qweight',
         # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
         "gptq_marlin_repack": ops.gptq_marlin_repack,
+        "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
     }

     min_run_time = 1
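For a concrete problem size, the weight_name_pattern entry added above expands to a plain string keyed on the GEMM shape. The example below assumes a 16 x 4096 x 4096 (m x k x n) case; that the kernel uses this string as a per-shape identifier is an inference from the name, not something stated in the diff.

size_m, size_k, size_n = 16, 4096, 4096  # assumed example shape
pattern = f'model.layers.k{size_k}.m{size_m}.n{size_n}.qweight'
print(pattern)  # model.layers.k4096.m16.n4096.qweight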
@@ -172,6 +206,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
             description="gptq_marlin_repack",
         ).blocked_autorange(min_run_time=min_run_time))

+    if as_supported_case:
+        results.append(
+            benchmark.Timer(
+                stmt=
+                "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True, weight_name_pattern)",  # noqa: E501
+                globals=globals,
+                label=label,
+                sub_label=sub_label,
+                description="allspark_w8a16_gemm_fp32",
+            ).blocked_autorange(min_run_time=min_run_time))
+

 def main(args):
     print("Benchmarking models:")
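Stripped of the Timer harness, the new measurement times the single call below. The argument order is copied from the stmt string, the tensor names refer to the entries of the globals dict built earlier in the diff, and `a` is assumed to be the (size_m, size_k) half-precision activation tensor that bench_run constructs elsewhere in the file; the inline comments read the parameter names at face value rather than documenting the kernel.

output = ops.allspark_w8a16_gemm(
    a,                    # fp16 activations, shape (size_m, size_k) (assumed)
    qw_reorder,           # uint8 weights after ops.allspark_repack_weight
    s_reorder,            # repacked per-channel scales
    zp_reorder,           # repacked zero points (has_zp is False here)
    size_n,               # output feature dimension
    group_size,           # -1, i.e. per-channel quantization
    sm_count,             # multiprocessor count of the target GPU
    sm_version,           # e.g. 80 or 86 on Ampere
    CUBLAS_M_THRESHOLD,   # ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD from allspark_utils
    False,                # first kernel flag, passed exactly as in the stmt string
    True,                 # second kernel flag, passed exactly as in the stmt string
    weight_name_pattern,  # shape-keyed name string from the globals dict
)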