
Commit 28de2f4

Authored by SunMarc, younesbelkada, dacorvo, and ArthurZucker
[Quantization] Quanto quantizer (#29023)
* start integration
* fix
* add and debug tests
* update tests
* make pytorch serialization work
* compatible with device_map and offload
* fix tests
* make style
* add ref
* guard against safetensors
* add float8 and style
* fix is_serializable
* Fix shard_checkpoint compatibility with quanto
* more tests
* docs
* adjust memory
* better
* style
* pass tests
* Update src/transformers/modeling_utils.py Co-authored-by: Younes Belkada <[email protected]>
* add is_safe_serialization instead
* Update src/transformers/quantizers/quantizer_quanto.py Co-authored-by: Younes Belkada <[email protected]>
* add QbitsTensor tests
* fix tests
* simplify activation list
* Update docs/source/en/quantization.md Co-authored-by: David Corvoysier <[email protected]>
* better comment
* Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: David Corvoysier <[email protected]>
* Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: David Corvoysier <[email protected]>
* find and fix edge case
* Update docs/source/en/quantization.md Co-authored-by: Arthur <[email protected]>
* pass weights_only_kwarg instead
* fix shard_checkpoint loading
* simplify update_missing_keys
* Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: Arthur <[email protected]>
* recursion to get all tensors
* block serialization
* skip serialization tests
* fix
* change by cuda:0 for now
* fix regression
* update device_map
* fix doc
* add notebook
* update torch_dtype
* update doc
* typo
* typo
* remove comm

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: David Corvoysier <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
1 parent f02aea2 commit 28de2f4

File tree: 18 files changed (+885, −7 lines)

docker/transformers-quantization-latest-gpu/Dockerfile

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
 # Add autoawq for quantization testing
 RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl
 
+# Add quanto for quantization testing
+RUN python3 -m pip install --no-cache-dir quanto
+
 # When installing in editable mode, `transformers` is not recognized as a package.
 # this line must be added in order for python to be aware of transformers.
 RUN cd transformers && python3 setup.py develop

docs/source/en/main_classes/quantization.md

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
 </Tip>
 
+## QuantoConfig
+
+[[autodoc]] QuantoConfig
+
 ## AqlmConfig
 
 [[autodoc]] AqlmConfig

docs/source/en/quantization.md

Lines changed: 53 additions & 0 deletions
@@ -26,6 +26,59 @@ Interested in adding a new quantization method to Transformers? Read the [HfQuan
 
 </Tip>
 
+## Quanto
+
+<Tip>
+
+Try Quanto + transformers with this [notebook](https://colab.research.google.com/drive/16CXfVmtdQvciSh9BopZUDYcmXCDpvgrT?usp=sharing)!
+
+</Tip>
+
+
+The [🤗 Quanto](https://github.com/huggingface/quanto) library is a versatile pytorch quantization toolkit. The quantization method used is linear quantization. Quanto provides several unique features such as:
+
+- weights quantization (`float8`, `int8`, `int4`, `int2`)
+- activation quantization (`float8`, `int8`)
+- modality agnostic (e.g. CV, LLM)
+- device agnostic (e.g. CUDA, MPS, CPU)
+- compatibility with `torch.compile`
+- easy to add custom kernels for a specific device
+- supports quantization aware training
+<!-- Add link to the blogpost -->
+
+Before you begin, make sure the following libraries are installed:
+
+```bash
+pip install quanto
+pip install git+https://github.com/huggingface/accelerate.git
+pip install git+https://github.com/huggingface/transformers.git
+```
+
+Now you can quantize a model by passing a [`QuantoConfig`] object to the [`~PreTrainedModel.from_pretrained`] method. This works for any model in any modality, as long as it contains `torch.nn.Linear` layers.
+
+The integration with transformers only supports weights quantization. For more complex use cases such as activation quantization, calibration and quantization aware training, you should use the [quanto](https://github.com/huggingface/quanto) library instead.
+
+```py
+from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
+
+model_id = "facebook/opt-125m"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+quantization_config = QuantoConfig(weights="int8")
+quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", quantization_config=quantization_config)
+```
+
+Note that serialization is not supported yet with transformers, but it is coming soon! If you want to save the model, you can use the quanto library instead.
+
+The quanto library uses a linear quantization algorithm. Even though this is a basic quantization technique, it gives very good results! Have a look at the following benchmark (llama-2-7b, perplexity metric). You can find more benchmarks [here](https://github.com/huggingface/quanto/tree/main/bench/generation).
+
+<div class="flex gap-4">
+<div>
+<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/NousResearch-Llama-2-7b-hf_Perplexity.png" alt="llama-2-7b-quanto-perplexity" />
+</div>
+</div>
+
+The library is versatile enough to be compatible with most PTQ optimization algorithms. The plan for the future is to integrate the most popular algorithms in the most seamless way possible (AWQ, SmoothQuant).
+
 ## AQLM
 
 
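For illustration, a minimal end-to-end sketch of the flow documented above: quantize the weights to int8 at load time, then run a short generation. It assumes a CUDA device is available; the prompt and generation settings are illustrative, not part of this commit.

```python
# Load a model with int8 weight quantization via QuantoConfig, then generate.
# Assumes a CUDA device; max_new_tokens is an arbitrary choice for the example.
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantization_config = QuantoConfig(weights="int8")
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda:0", quantization_config=quantization_config
)

inputs = tokenizer("Quantization lets us", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
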
src/transformers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1100,7 +1100,7 @@
         "is_vision_available",
         "logging",
     ],
-    "utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
+    "utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig", "QuantoConfig"],
 }
 
 # sentencepiece-backed objects
@@ -5921,7 +5921,7 @@
     )
 
     # bitsandbytes config
-    from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig
+    from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantoConfig
 
     try:
         if not is_sentencepiece_available():

src/transformers/integrations/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,7 @@
         "run_hp_search_wandb",
     ],
     "peft": ["PeftAdapterMixin"],
+    "quanto": ["replace_with_quanto_layers"],
 }
 
 if TYPE_CHECKING:
@@ -150,6 +151,7 @@
         run_hp_search_wandb,
     )
     from .peft import PeftAdapterMixin
+    from .quanto import replace_with_quanto_layers
 else:
     import sys
 

src/transformers/integrations/quanto.py (new file)

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+
+def replace_with_quanto_layers(
+    model,
+    quantization_config=None,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    has_been_replaced=False,
+):
+    """
+    Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
+    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to convert, can be any `torch.nn.Module` instance.
+        quantization_config (`QuantoConfig`, defaults to `None`):
+            The quantization config object that contains the quantization parameters.
+        modules_to_not_convert (`list`, *optional*, defaults to `None`):
+            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
+            converted.
+        current_key_name (`list`, *optional*, defaults to `None`):
+            A list that contains the current key name. This is used for recursion and should not be passed by the user.
+        has_been_replaced (`bool`, *optional*, defaults to `False`):
+            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
+            should not be passed by the user.
+    """
+    from accelerate import init_empty_weights
+    from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
+
+    w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
+    a_mapping = {None: None, "float8": qfloat8, "int8": qint8}
+
+    if modules_to_not_convert is None:
+        modules_to_not_convert = []
+
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+        current_key_name.append(name)
+
+        if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            with init_empty_weights():
+                if isinstance(module, torch.nn.Linear):
+                    model._modules[name] = QLinear(
+                        in_features=module.in_features,
+                        out_features=module.out_features,
+                        bias=module.bias is not None,
+                        dtype=module.weight.dtype,
+                        weights=w_mapping[quantization_config.weights],
+                        activations=a_mapping[quantization_config.activations],
+                    )
+                    model._modules[name].requires_grad_(False)
+                    has_been_replaced = True
+                elif isinstance(module, torch.nn.LayerNorm):
+                    if quantization_config.activations is not None:
+                        model._modules[name] = QLayerNorm(
+                            module.normalized_shape,
+                            module.eps,
+                            module.elementwise_affine,
+                            module.bias is not None,
+                            activations=a_mapping[quantization_config.activations],
+                        )
+                        has_been_replaced = True
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = replace_with_quanto_layers(
+                module,
+                quantization_config=quantization_config,
+                modules_to_not_convert=modules_to_not_convert,
+                current_key_name=current_key_name,
+                has_been_replaced=has_been_replaced,
+            )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
+    return model, has_been_replaced
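
For illustration only, a hedged sketch of calling this helper directly rather than letting `from_pretrained` invoke it through the quantizer. It assumes `quanto` and `accelerate` are installed; note that the converted layers are created under `init_empty_weights()`, so their parameters still live on the meta device and must be (re)loaded before inference.

```python
# Hypothetical direct use of replace_with_quanto_layers; in the integration it is
# called internally by the quantizer during from_pretrained, so treat this as a sketch.
import torch
from transformers import AutoModelForCausalLM, QuantoConfig
from transformers.integrations import replace_with_quanto_layers

config = QuantoConfig(weights="int8")  # activations are left unquantized by default
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float32)

model, replaced = replace_with_quanto_layers(
    model,
    quantization_config=config,
    modules_to_not_convert=["lm_head"],  # keep the output head in full precision
)
print(replaced)  # True if at least one nn.Linear was swapped for a quanto QLinear
```
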

src/transformers/modeling_utils.py

Lines changed: 8 additions & 1 deletion
@@ -802,7 +802,11 @@ def _load_state_dict_into_meta_model(
         elif (
             not is_quantized
             or (not hf_quantizer.requires_parameters_quantization)
-            or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict))
+            or (
+                not hf_quantizer.check_quantized_param(
+                    model, param, param_name, state_dict, param_device=param_device, device_map=device_map
+                )
+            )
         ):
             # For backward compatibility with older versions of `accelerate` and for non-quantized params
             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
@@ -3728,6 +3732,9 @@ def _fix_key(key):
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
+        if hf_quantizer is not None:
+            missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
+
         # retrieve weights on meta device and put them back on CPU.
         # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
         if low_cpu_mem_usage:
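
To make the new gating easier to follow, here is a schematic paraphrase of the condition above; it is not the function actually used in `modeling_utils.py`, only a condensed view of the same logic.

```python
# Schematic only: a parameter is routed through the quantizer when every check passes.
# The new param_device/device_map kwargs let the quantizer see where the tensor will live.
def should_create_quantized_param(
    is_quantized, hf_quantizer, model, param, param_name, state_dict, param_device, device_map
):
    if not is_quantized:
        return False
    if not hf_quantizer.requires_parameters_quantization:
        return False
    return hf_quantizer.check_quantized_param(
        model, param, param_name, state_dict, param_device=param_device, device_map=device_map
    )
```
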

src/transformers/quantizers/auto.py

Lines changed: 4 additions & 0 deletions
@@ -22,12 +22,14 @@
     GPTQConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
+    QuantoConfig,
 )
 from .quantizer_aqlm import AqlmHfQuantizer
 from .quantizer_awq import AwqQuantizer
 from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
 from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
 from .quantizer_gptq import GptqHfQuantizer
+from .quantizer_quanto import QuantoHfQuantizer
 
 
 AUTO_QUANTIZER_MAPPING = {
@@ -36,6 +38,7 @@
     "bitsandbytes_8bit": Bnb8BitHfQuantizer,
     "gptq": GptqHfQuantizer,
     "aqlm": AqlmHfQuantizer,
+    "quanto": QuantoHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -44,6 +47,7 @@
     "bitsandbytes_8bit": BitsAndBytesConfig,
     "gptq": GPTQConfig,
     "aqlm": AqlmConfig,
+    "quanto": QuantoConfig,
 }
 
 
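A small hedged sketch of what these mapping entries enable: `from_pretrained` reads the `quant_method` recorded in the quantization config ("quanto") and resolves the config and quantizer classes, which amounts to a lookup like the one below. The real dispatch goes through `AutoHfQuantizer` rather than this direct access.

```python
# Simplified view of the dispatch made possible by the new "quanto" entries above.
from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING["quanto"]  # QuantoConfig
quantizer_cls = AUTO_QUANTIZER_MAPPING["quanto"]          # QuantoHfQuantizer
print(config_cls.__name__, quantizer_cls.__name__)
```
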
src/transformers/quantizers/base.py

Lines changed: 18 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from ..utils import is_torch_available
 from ..utils.quantization_config import QuantizationConfigMixin
@@ -99,6 +99,16 @@ def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         """
         return torch_dtype
 
+    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+        """
+        Override this method if you want to adjust the `missing_keys`.
+
+        Args:
+            missing_keys (`List[str]`, *optional*):
+                The list of missing keys in the checkpoint compared to the state dict of the model
+        """
+        return missing_keys
+
     def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
         """
         returns dtypes for modules that are not quantized - used for the computation of the device_map in case
@@ -111,6 +121,7 @@ def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[s
             torch_dtype (`torch.dtype`):
                 The dtype passed in `from_pretrained` method.
         """
+
         return {
             name: torch_dtype
             for name, _ in model.named_parameters()
@@ -122,7 +133,12 @@ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str,
         return max_memory
 
     def check_quantized_param(
-        self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
+        self,
+        model: "PreTrainedModel",
+        param_value: "torch.Tensor",
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
    ) -> bool:
         """
         checks if a loaded state_dict component is part of quantized param + some validation; only defined if
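
As a hedged illustration of how a concrete quantizer can use these new hooks, the fragment below overrides the two methods introduced above. The class name, the `.scales` suffix, and the filtering rule are hypothetical and not part of the library; only the method signatures come from this diff.

```python
# Toy fragment (not instantiable on its own: other abstract HfQuantizer methods are omitted).
from typing import Any, Dict, List

from transformers.quantizers.base import HfQuantizer


class ToyQuantizer(HfQuantizer):
    requires_parameters_quantization = True

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        # Keys created by the quantized modules themselves (e.g. hypothetical packed scales)
        # are not expected in the checkpoint, so drop them from the "missing" report.
        return [k for k in missing_keys if not k.endswith(".scales")]

    def check_quantized_param(
        self,
        model,
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        # The loader now forwards param_device and device_map through **kwargs,
        # so a quantizer can take placement into account when deciding.
        param_device = kwargs.get("param_device", None)
        return param_name.endswith(".weight")
```
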

src/transformers/quantizers/quantizer_bnb_4bit.py

Lines changed: 6 additions & 1 deletion
@@ -116,7 +116,12 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
            )
 
     def check_quantized_param(
-        self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
+        self,
+        model: "PreTrainedModel",
+        param_value: "torch.Tensor",
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
     ) -> bool:
         import bitsandbytes as bnb
 
