
Commit 9e7722b

robertgshaw2-redhat (Robert Shaw) authored and committed
[ Misc ] Support Fp8 via llm-compressor (vllm-project#6110)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 670b72f commit 9e7722b

17 files changed: +603 -372 lines changed
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.752
+  - name: "exact_match,flexible-extract"
+    value: 0.752
+limit: 250
+num_fewshot: 5
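
The eval configs added here pair a model with its expected GSM8K scores. As a minimal sketch (not code from this commit), a config like the one above could be loaded and compared against an lm-eval run roughly like this; the RTOL value and the helper name are illustrative assumptions, though the results-dict layout follows lm-eval-harness conventions:

import yaml
import numpy as np

RTOL = 0.02  # assumed tolerance, not taken from this commit

def check_eval_config(config_path: str, results: dict) -> None:
    """Compare measured lm-eval metrics against the expected values in the YAML."""
    with open(config_path) as f:
        eval_config = yaml.safe_load(f)
    for task in eval_config["tasks"]:
        for metric in task["metrics"]:
            measured = results["results"][task["name"]][metric["name"]]
            assert np.isclose(metric["value"], measured, rtol=RTOL)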

.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
 model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
 tasks:
 - name: "gsm8k"
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.728
+  - name: "exact_match,flexible-extract"
+    value: 0.728
+limit: 250
+num_fewshot: 5
Lines changed: 2 additions & 0 deletions
@@ -1,2 +1,4 @@
 Meta-Llama-3-8B-Instruct.yaml
 Meta-Llama-3-8B-Instruct-FP8.yaml
+Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml

.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh

Lines changed: 1 addition & 1 deletion
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done

 lm_eval --model vllm \
-  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \
+  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
   --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
   --batch_size $BATCH_SIZE
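
For context, add_bos_token=true is forwarded to lm-eval's vLLM backend so that prompts are tokenized with a leading BOS token. A rough Python equivalent of the command above, driven through lm_eval.simple_evaluate (a sketch, not part of this commit; argument values mirror the -t 1 / -l 250 / -f 5 examples in the configs):

import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args="pretrained=neuralmagic/Meta-Llama-3-8B-Instruct-FP8,"
               "tensor_parallel_size=1,add_bos_token=true",
    tasks=["gsm8k"],
    num_fewshot=5,
    limit=250,
    batch_size="auto",
)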

.buildkite/lm-eval-harness/test_lm_eval_correctness.py

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@

 def launch_lm_eval(eval_config):
     model_args = f"pretrained={eval_config['model_name']}," \
-                 f"tensor_parallel_size={TP_SIZE}"
+                 f"tensor_parallel_size={TP_SIZE}," \
+                 f"add_bos_token=true"

     results = lm_eval.simple_evaluate(
         model="vllm",

tests/quantization/test_compressed_tensors.py

Lines changed: 27 additions & 5 deletions
@@ -9,7 +9,8 @@
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8, CompressedTensorsWNA16)
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationType)

@@ -37,12 +38,11 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
                       CompressedTensorsLinearMethod)
     assert isinstance(down_proj.quant_method,
                       CompressedTensorsLinearMethod)
-    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
+    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

     assert qkv_proj.scheme.strategy == strategy
     assert qkv_proj.scheme.is_static_input_scheme
-    expected_type = (torch.int8 if quant_type == QuantizationType.INT else
-                     torch.float8_e4m3fn)
+    expected_type = torch.int8

     assert qkv_proj.weight.dtype is expected_type
     assert o_proj.weight.dtype is expected_type
@@ -79,7 +79,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
     qkv_proj = layer.self_attn.qkv_proj

     assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
+    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
     assert not qkv_proj.scheme.is_static_input_scheme
     assert qkv_proj.scheme.strategy == strategy
     assert qkv_proj.weight.dtype is torch.int8
@@ -123,3 +123,25 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
     sampling_params = SamplingParams()
     output = llm.generate("Hello world!", sampling_params=sampling_params)
     assert output
+
+
+def test_compressed_tensors_fp8(vllm_runner):
+    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
+    with vllm_runner(model_path) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
+        assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+        assert qkv_proj.input_scale.dtype is torch.float32
+        assert qkv_proj.weight_scale.dtype is torch.float32
+        # should be scalars after processing
+        assert len(qkv_proj.input_scale.shape) == 0
+        assert len(qkv_proj.weight_scale.shape) == 0
+
+        sampling_params = SamplingParams()
+        output = llm.generate("Hello world!", sampling_params=sampling_params)
+        assert output
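
Outside the test harness, the same FP8 compressed-tensors checkpoint can be loaded through vLLM's public API. A minimal sketch (the model path comes from the test above; the sampling settings are arbitrary):

from vllm import LLM, SamplingParams

# Load the llm-compressor-produced FP8 checkpoint used by the test above.
llm = LLM(model="nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test")
outputs = llm.generate(["Hello world!"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)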

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 50 additions & 8 deletions
@@ -9,10 +9,11 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8, CompressedTensorsWNA16)
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    find_first_name_or_class_match)
+    QuantizationType, find_first_name_or_class_match)
 from vllm.platforms import current_platform

@@ -117,6 +118,40 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
         return is_8_bits and is_token and is_symmetric and is_dynamic

+    def _is_fp8_w8a8(self, weight_quant: BaseModel,
+                     input_quant: BaseModel) -> bool:
+        # Confirm weights and activations quantized.
+        if weight_quant is None or input_quant is None:
+            return False
+
+        # Confirm we have floating points.
+        if not (weight_quant.type == QuantizationType.FLOAT
+                and input_quant.type == QuantizationType.FLOAT):
+            return False
+
+        # Confirm weight scheme is supported.
+        is_symmetric_weight = weight_quant.symmetric
+        is_static_weight = not weight_quant.dynamic
+        is_per_tensor_weight = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR)
+        if not (is_symmetric_weight and is_static_weight
+                and is_per_tensor_weight):
+            return False
+
+        # Dynamic quantization is always supported if weights supported.
+        if input_quant.dynamic:
+            return True
+
+        # Confirm activation scheme is supported.
+        is_symmetric_activation = input_quant.symmetric
+        is_per_tensor_activation = (
+            input_quant.strategy == QuantizationStrategy.TENSOR)
+        if not (is_symmetric_activation and is_per_tensor_activation):
+            return False
+
+        # All conditions satisfied.
+        return True
+
     def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
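
To illustrate which checkpoints the new _is_fp8_w8a8 check accepts, here is a self-contained sketch that mirrors its logic with a throwaway stand-in for the quantization args (the QArgs class below is hypothetical, not vLLM's QuantizationArgs):

from dataclasses import dataclass

@dataclass
class QArgs:
    # Hypothetical stand-in for the fields _is_fp8_w8a8 reads.
    type: str = "float"        # "float" or "int"
    symmetric: bool = True
    dynamic: bool = False
    strategy: str = "tensor"   # "tensor", "channel", "token", ...

def is_fp8_w8a8(weight_quant: QArgs, input_quant: QArgs) -> bool:
    if weight_quant is None or input_quant is None:
        return False
    if not (weight_quant.type == "float" and input_quant.type == "float"):
        return False
    # Weights must be symmetric, static, and per-tensor.
    if not (weight_quant.symmetric and not weight_quant.dynamic
            and weight_quant.strategy == "tensor"):
        return False
    # Dynamic activations are always fine once the weights qualify.
    if input_quant.dynamic:
        return True
    # Otherwise activations must be symmetric and per-tensor.
    return input_quant.symmetric and input_quant.strategy == "tensor"

# Static per-tensor FP8 weights and activations -> accepted.
assert is_fp8_w8a8(QArgs(), QArgs())
# Per-channel FP8 weights -> rejected by this check.
assert not is_fp8_w8a8(QArgs(strategy="channel"), QArgs())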
@@ -147,14 +182,21 @@ def _get_schema(self, weight_quant: BaseModel,
                 strategy=weight_quant.strategy,
                 group_size=weight_quant.group_size)

-        if self.quant_format == CompressionFormat.int_quantized.value:
+        if (self.quant_format == CompressionFormat.int_quantized.value or
+                self.quant_format == CompressionFormat.float_quantized.value):
+            if self._is_fp8_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8Fp8(
+                    input_dynamic=input_quant.dynamic)
+
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
-                                             is_static_input_scheme=True)
+                return CompressedTensorsW8A8Int8(
+                    strategy=weight_quant.strategy,
+                    is_static_input_scheme=True)

             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
-                                             is_static_input_scheme=False)
+                return CompressedTensorsW8A8Int8(
+                    strategy=weight_quant.strategy,
+                    is_static_input_scheme=False)

         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
@@ -187,7 +229,7 @@ def __init__(self, quantization_config: CompressedTensorsConfig):
         self.quantization_config = quantization_config

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        return layer.scheme.process_weights_after_loading(layer)
+        layer.scheme.process_weights_after_loading(layer)

     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
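
process_weights_after_loading now simply delegates to the attached scheme; for the FP8 scheme this is where per-shard scales end up as the 0-dim weight_scale / input_scale that the new test asserts. The sketch below is conceptual only, not the CompressedTensorsW8A8Fp8 implementation: it shows one way per-shard FP8 scales for a fused weight (e.g. qkv_proj) could be folded into a single per-tensor scale by requantizing with the shared max scale; the function name and shard layout are made up for illustration.

import torch

def fold_to_per_tensor(weight_fp8: torch.Tensor,
                       shard_scales: torch.Tensor,
                       shard_sizes: list):
    """Fold per-shard FP8 scales into one per-tensor scale (illustrative)."""
    max_scale = shard_scales.max()
    requantized = torch.empty_like(weight_fp8)
    start = 0
    for scale, size in zip(shard_scales.tolist(), shard_sizes):
        end = start + size
        # Dequantize the shard with its own scale, then requantize with the
        # shared max scale, clamping to the fp8 e4m3 representable range.
        dequant = weight_fp8[start:end].to(torch.float32) * scale
        requantized[start:end] = torch.clamp(
            dequant / max_scale, min=-448.0, max=448.0).to(torch.float8_e4m3fn)
        start = end
    return requantized, max_scale

# e.g. a fused weight quantized as three shards of 8 rows each:
w = torch.randn(24, 16).to(torch.float8_e4m3fn)
w_new, scale = fold_to_per_tensor(w, torch.tensor([0.01, 0.02, 0.015]), [8, 8, 8])
assert scale.ndim == 0  # matches the scalar-scale assertions in the test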
Lines changed: 19 additions & 8 deletions
@@ -1,8 +1,19 @@
-from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
-from .compressed_tensors_unquantized import ( # noqa: F401
-    CompressedTensorsUnquantized)
-from .compressed_tensors_w4a16_24 import ( # noqa: F401
-    W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
-from .compressed_tensors_w8a8 import CompressedTensorsW8A8 # noqa: F401
-from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS # noqa: F401
-from .compressed_tensors_wNa16 import CompressedTensorsWNA16 # noqa: F401
+from .compressed_tensors_scheme import CompressedTensorsScheme
+from .compressed_tensors_unquantized import CompressedTensorsUnquantized
+from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
+                                          CompressedTensorsW4A16Sparse24)
+from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
+from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
+                                       CompressedTensorsWNA16)
+
+__all__ = [
+    "CompressedTensorsScheme",
+    "CompressedTensorsUnquantized",
+    "CompressedTensorsWNA16",
+    "CompressedTensorsW4A16Sparse24",
+    "CompressedTensorsW8A8Int8",
+    "CompressedTensorsW8A8Fp8",
+    "WNA16_SUPPORTED_BITS",
+    "W4A16SPARSE24_SUPPORTED_BITS",
+]
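
With the old compressed_tensors_w8a8 module split into separate Int8 and FP8 modules, both schemes are re-exported at the package level, so callers can import them from the schemes package rather than the per-scheme modules. A small usage sketch (the describe helper is hypothetical):

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8)

def describe(scheme) -> str:
    # Switch on the scheme attached to a quantized linear layer.
    if isinstance(scheme, CompressedTensorsW8A8Fp8):
        return "fp8 w8a8"
    if isinstance(scheme, CompressedTensorsW8A8Int8):
        return "int8 w8a8"
    return type(scheme).__name__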

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py

Lines changed: 0 additions & 109 deletions
This file was deleted.
