
Commit 56fe4c2

[TPU][Quantization] TPU W8A8 (#11785)
Co-authored-by: Woosuk Kwon <[email protected]>
1 parent 47de882 commit 56fe4c2

18 files changed, +565 -190 lines changed

.buildkite/run-tpu-test.sh

Lines changed: 10 additions & 1 deletion
@@ -14,4 +14,13 @@ remove_docker_container
 # For HF_TOKEN.
 source /etc/environment
 # Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py"
+docker run --privileged --net host --shm-size=16G -it \
+    -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
+    vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
+    && python3 -m pip install pytest \
+    && python3 -m pip install lm_eval[api]==0.4.4 \
+    && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \
+    && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
+    && python3 /workspace/vllm/tests/tpu/test_compilation.py \
+    && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
+    && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py"

tests/tpu/test_quantization_accuracy.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+
+import lm_eval
+import pytest
+
+TASK = "gsm8k"
+FILTER = "exact_match,strict-match"
+RTOL = 0.03
+
+
+@dataclass
+class GSM8KAccuracyTestConfig:
+    model_name: str
+    excepted_value: float
+
+    def get_model_args(self) -> str:
+        return (f"pretrained={self.model_name},"
+                "max_model_len=4096,max_num_seqs=32")
+
+
+# NOTE: Accuracy scores measured on GPUs.
+ACCURACY_CONFIGS = [
+    GSM8KAccuracyTestConfig(
+        model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
+        excepted_value=0.76),  # no bias
+    # NOTE(rob): We cannot re-initialize VLLM in the same process for TPU,
+    # so only one of these tests can run in a single call to pytest. As
+    # a follow up, move this into the LM-EVAL section of the CI.
+    # GSM8KAccuracyTestConfig(
+    #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
+    #     excepted_value=0.66),  # bias in QKV layers
+]
+
+
+@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
+def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
+
+    results = lm_eval.simple_evaluate(
+        model="vllm",
+        model_args=config.get_model_args(),
+        tasks="gsm8k",
+        batch_size="auto",
+    )
+
+    EXPECTED_VALUE = config.excepted_value
+    measured_value = results["results"][TASK][FILTER]
+    assert (measured_value - RTOL < EXPECTED_VALUE
+            and measured_value + RTOL > EXPECTED_VALUE
+            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

Lines changed: 31 additions & 74 deletions
@@ -1,14 +1,13 @@
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Set
 
 import torch
 from compressed_tensors.quantization import QuantizationStrategy
-from torch.nn import Parameter
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_int8_linear, convert_to_channelwise)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
+    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
@@ -18,6 +17,7 @@
 
 
 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
+    _kernel_backends_being_used: Set[str] = set()
 
     def __init__(self, strategy: str, is_static_input_scheme: bool,
                  input_symmetric: bool):
@@ -30,74 +30,25 @@ def get_min_capability(cls) -> int:
         # turing and up
         return 75
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # WEIGHT
-        # Cutlass kernels need transposed weight.
-        weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-
-        # WEIGHT SCALE
-        # Cutlass kernels support only per-tensor and per-channel.
-        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
-        # scales being passed to the kernel), convert to the per-channel case.
-        is_fused_module = len(self.logical_widths) > 1
-        if is_fused_module and self.strategy == QuantizationStrategy.TENSOR:
-            ws_channelwise = convert_to_channelwise(layer.weight_scale,
-                                                    self.logical_widths)
-            layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
-        else:
-            layer.weight_scale = Parameter(layer.weight_scale.data,
-                                           requires_grad=False)
-        # INPUT SCALE
-        if self.is_static_input_scheme:
-            if self.input_symmetric:
-                layer.input_scale = Parameter(layer.input_scale.max(),
-                                              requires_grad=False)
-                layer.input_zero_point = None
-            else:
-                # reconstruct the ranges
-                int8_traits = torch.iinfo(torch.int8)
-                azps = layer.input_zero_point.to(dtype=torch.int32)
-                range_max = (layer.input_scale *
-                             (int8_traits.max - azps)).max()
-                range_min = (layer.input_scale *
-                             (int8_traits.min - azps)).min()
-
-                scale = (range_max - range_min) / (int8_traits.max -
-                                                   int8_traits.min)
-                layer.input_scale = Parameter(scale, requires_grad=False)
-
-                # AZP loaded as int8 but used as int32
-                azp = (int8_traits.min -
-                       range_min / scale).to(dtype=torch.int32)
-                layer.input_zero_point = Parameter(azp, requires_grad=False)
-
-        else:
-            layer.input_scale = None
-            layer.input_zero_point = None
-
-        # azp_adj is the AZP adjustment term, used to account for weights.
-        # It does not depend on scales or azp, so it is the same for
-        # static and dynamic quantization.
-        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
-        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
-        if not self.input_symmetric:
-            azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
-            if self.is_static_input_scheme:
-                # cutlass_w8a8 requires azp to be folded into azp_adj
-                # in the per-tensor case
-                azp_adj = layer.input_zero_point * azp_adj
-
-            layer.azp_adj = azp_adj
-        else:
-            layer.azp_adj = None
-
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-        self.logical_widths = output_partition_sizes
+        layer.logical_widths = output_partition_sizes
+
+        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
+            is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
+            is_static_input_scheme=self.is_static_input_scheme,
+            input_symmetric=self.input_symmetric)
+
+        kernel_type = choose_scaled_mm_linear_kernel(
+            scaled_mm_linear_kernel_config)
+
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for CompressedTensorsW8A8Int8",
+                        kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
 
         # WEIGHT
         weight = ModelWeightParameter(data=torch.empty(
@@ -140,12 +91,18 @@ def create_weights(self, layer: torch.nn.Module,
                 weight_loader=weight_loader)
             layer.register_parameter("input_zero_point", input_zero_point)
 
+        self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
+                                  w_q_param_name="weight",
+                                  w_s_param_name="weight_scale",
+                                  i_s_param_name="input_scale",
+                                  i_zp_param_name="input_zero_point",
+                                  azp_adj_param_name="azp_adj")
+
+    # Checkpoints are serialized in compressed-tensors format, which is
+    # different from the format the kernel may want. Handle repacking here.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        self.kernel.process_weights_after_loading(layer)
+
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return apply_int8_linear(input=x,
-                                 weight=layer.weight,
-                                 weight_scale=layer.weight_scale,
-                                 input_scale=layer.input_scale,
-                                 input_zero_point=layer.input_zero_point,
-                                 azp_adj=layer.azp_adj,
-                                 bias=bias)
+        return self.kernel.apply_weights(layer, x, bias)
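
For orientation, the scheme above no longer does the CUTLASS-specific weight handling itself; it describes the layer with a ScaledMMLinearLayerConfig and lets choose_scaled_mm_linear_kernel pick a backend, which is what makes the same scheme usable on TPU. A small sketch of that selection, using only names that appear in this hunk (the concrete kernel class that comes back is platform-dependent and is an assumption here):

from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)

# Describe a W8A8 linear layer: per-channel weight scales, dynamic
# activation quantization, symmetric activations (no zero points).
config = ScaledMMLinearLayerConfig(is_channelwise=True,
                                   is_static_input_scheme=False,
                                   input_symmetric=True)

# Pick the best backend available on the current platform, e.g. a
# CUTLASS-based kernel on CUDA GPUs or an XLA-based kernel on TPU
# (exact class names are not spelled out in this diff).
kernel_cls = choose_scaled_mm_linear_kernel(config)
print(kernel_cls.__name__)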

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.kernels import (
+from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_repeat_scales_on_all_ranks)

vllm/model_executor/layers/quantization/gptq_marlin.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
     set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.kernels import (
+from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (

vllm/model_executor/layers/quantization/kernels/__init__.py

Lines changed: 0 additions & 74 deletions
@@ -1,74 +0,0 @@
-from typing import List, Optional, Type
-
-import vllm.envs as envs
-from vllm.model_executor.layers.quantization.kernels.exllama import (
-    ExllamaLinearKernel)
-from vllm.model_executor.layers.quantization.kernels.machete import (
-    MacheteLinearKernel)
-from vllm.model_executor.layers.quantization.kernels.marlin import (
-    MarlinLinearKernel)
-from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
-    MPLinearKernel, MPLinearLayerConfig)
-from vllm.platforms import current_platform
-
-# in priority/performance order (when available)
-_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
-    MacheteLinearKernel,
-    MarlinLinearKernel,
-    ExllamaLinearKernel,
-]
-
-
-def choose_mp_linear_kernel(
-        config: MPLinearLayerConfig,
-        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
-    """
-    Choose an MPLinearKernel that can implement the given config for the given
-    compute capability. Attempts to choose the best kernel in terms of
-    performance.
-
-    Args:
-        config (MPLinearLayerConfig): Description of the linear layer to be
-          implemented.
-        compute_capability (Optional[int], optional): The compute capability of
-          the target device, if None uses `current_platform` to get the compute
-          capability. Defaults to None.
-
-    Raises:
-        ValueError: If no kernel can implement the given config.
-
-    Returns:
-        Type[MPLinearKernel]: Chosen kernel.
-    """
-    if compute_capability is None:
-        if current_platform is None:
-            raise ValueError("Cannot determine compute capability")
-        _cc = current_platform.get_device_capability()
-        compute_capability = _cc[0] * 10 + _cc[1]
-
-    failure_reasons = []
-    for kernel in _POSSIBLE_KERNELS:
-        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
-            failure_reasons.append(
-                f' {kernel.__name__} disabled by environment variable')
-            continue
-
-        if kernel.get_min_capability() > compute_capability:
-            failure_reasons.append(
-                f"{kernel.__name__} requires capability "
-                f"{kernel.get_min_capability()}, current compute capability "
-                f"is {compute_capability}")
-            continue
-
-        can_implement, failure_reason = kernel.can_implement(config)
-        if can_implement:
-            return kernel
-        else:
-            failure_reasons.append(
-                f' {kernel.__name__} cannot implement due to: {failure_reason}'
-            )
-
-    raise ValueError(
-        "Failed to find a kernel that can implement the "\
-        "WNA16 linear layer. Reasons: \n"
-        + '\n'.join(failure_reasons))
File renamed without changes.

vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+from typing import List, Optional, Type
+
+import vllm.envs as envs
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
+    ExllamaLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
+    MacheteLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
+    MarlinLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
+    MPLinearKernel, MPLinearLayerConfig)
+from vllm.platforms import current_platform
+
+# in priority/performance order (when available)
+_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
+    MacheteLinearKernel,
+    MarlinLinearKernel,
+    ExllamaLinearKernel,
+]
+
+
+def choose_mp_linear_kernel(
+        config: MPLinearLayerConfig,
+        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
+    """
+    Choose an MPLinearKernel that can implement the given config for the given
+    compute capability. Attempts to choose the best kernel in terms of
+    performance.
+
+    Args:
+        config (MPLinearLayerConfig): Description of the linear layer to be
+          implemented.
+        compute_capability (Optional[int], optional): The compute capability of
+          the target device, if None uses `current_platform` to get the compute
+          capability. Defaults to None.
+
+    Raises:
+        ValueError: If no kernel can implement the given config.
+
+    Returns:
+        Type[MPLinearKernel]: Chosen kernel.
+    """
+    if compute_capability is None:
+        if current_platform is None:
+            raise ValueError("Cannot determine compute capability")
+        _cc = current_platform.get_device_capability()
+        compute_capability = _cc[0] * 10 + _cc[1]
+
+    failure_reasons = []
+    for kernel in _POSSIBLE_KERNELS:
+        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
+            failure_reasons.append(
+                f' {kernel.__name__} disabled by environment variable')
+            continue
+
+        if kernel.get_min_capability() > compute_capability:
+            failure_reasons.append(
+                f"{kernel.__name__} requires capability "
+                f"{kernel.get_min_capability()}, current compute capability "
+                f"is {compute_capability}")
+            continue
+
+        can_implement, failure_reason = kernel.can_implement(config)
+        if can_implement:
+            return kernel
+        else:
+            failure_reasons.append(
+                f' {kernel.__name__} cannot implement due to: {failure_reason}'
+            )
+
+    raise ValueError(
+        "Failed to find a kernel that can implement the "\
+        "WNA16 linear layer. Reasons: \n"
+        + '\n'.join(failure_reasons))
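
One detail of the selector above worth spelling out: when compute_capability is not passed, the (major, minor) capability tuple from current_platform is flattened into a single integer, which is what get_min_capability() values such as the 75 ("turing and up") in the W8A8 scheme are compared against. A quick illustrative sketch (the helper below is hypothetical, not part of the commit):

# Capability encoding used by choose_mp_linear_kernel: major * 10 + minor.
def encode_capability(major: int, minor: int) -> int:
    return major * 10 + minor

assert encode_capability(7, 5) == 75  # Turing, the minimum for the int8 scheme
assert encode_capability(8, 0) == 80  # Ampere
assert encode_capability(9, 0) == 90  # Hopper
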
File renamed without changes.
File renamed without changes.
