Merged

Changes from all commits (58 commits)
ba4c2b9
start integration
SunMarc Feb 14, 2024
dc88d4f
fix
SunMarc Feb 14, 2024
ee1ee85
add and debug tests
SunMarc Feb 15, 2024
4c50c4d
update tests
SunMarc Feb 20, 2024
97951ab
make pytorch serialization works
SunMarc Feb 20, 2024
c8436ca
compatible with device_map and offload
SunMarc Feb 21, 2024
9d15653
fix tests
SunMarc Feb 22, 2024
194a58a
Merge remote-tracking branch 'upstream/main' into quanto_integration
SunMarc Feb 22, 2024
d1ccb23
make style
SunMarc Feb 22, 2024
b0e8adb
add ref
SunMarc Feb 22, 2024
6193289
Merge remote-tracking branch 'upstream/main' into quanto_integration
SunMarc Feb 28, 2024
b550157
guard against safetensors
SunMarc Feb 28, 2024
29eee50
add float8 and style
SunMarc Feb 28, 2024
26fe440
fix is_serializable
SunMarc Feb 28, 2024
6d4ab4c
Fix shard_checkpoint compatibility with quanto
SunMarc Feb 28, 2024
daaeb91
more tests
SunMarc Feb 28, 2024
565e699
docs
SunMarc Feb 28, 2024
56ba706
adjust memory
SunMarc Feb 29, 2024
9329a07
better
SunMarc Feb 29, 2024
9da4d0b
style
SunMarc Mar 1, 2024
c13a4ef
pass tests
SunMarc Mar 1, 2024
de9c79a
Update src/transformers/modeling_utils.py
SunMarc Mar 4, 2024
1a7721a
add is_safe_serialization instead
SunMarc Mar 4, 2024
849448d
Merge branch 'quanto_integration' of https://github.com/SunMarc/trans…
SunMarc Mar 4, 2024
c980409
Update src/transformers/quantizers/quantizer_quanto.py
SunMarc Mar 4, 2024
80a5c29
add QbitsTensor tests
SunMarc Mar 4, 2024
c6f66f0
fix tests
SunMarc Mar 4, 2024
7deb644
simplify activation list
SunMarc Mar 5, 2024
693e593
Update docs/source/en/quantization.md
SunMarc Mar 5, 2024
528916b
better comment
SunMarc Mar 5, 2024
7a95507
Update tests/quantization/quanto_integration/test_quanto.py
SunMarc Mar 5, 2024
d60c797
Update tests/quantization/quanto_integration/test_quanto.py
SunMarc Mar 5, 2024
b73c5ee
Merge branch 'quanto_integration' of https://github.com/SunMarc/trans…
SunMarc Mar 5, 2024
1489a1b
Merge branch 'main' into quanto_integration
SunMarc Mar 5, 2024
5e98443
find and fix edge case
SunMarc Mar 5, 2024
850f5e4
Update docs/source/en/quantization.md
SunMarc Mar 6, 2024
5fc659c
pass weights_only_kwarg instead
SunMarc Mar 7, 2024
15f7a2a
fix shard_checkpoint loading
SunMarc Mar 7, 2024
bf5f7e6
simplify update_missing_keys
SunMarc Mar 7, 2024
c52b6c1
Merge remote-tracking branch 'upstream/main' into quanto_integration
SunMarc Mar 7, 2024
ad012e0
Update tests/quantization/quanto_integration/test_quanto.py
SunMarc Mar 8, 2024
3419a3c
recursion to get all tensors
SunMarc Mar 8, 2024
bb7c226
Merge branch 'quanto_integration' of https://github.com/SunMarc/trans…
SunMarc Mar 8, 2024
a1b3c18
block serialization
SunMarc Mar 14, 2024
0030d0a
skip serialization tests
SunMarc Mar 14, 2024
6d1bce3
fix
SunMarc Mar 14, 2024
e677a53
change by cuda:0 for now
SunMarc Mar 14, 2024
e005baf
fix regression
SunMarc Mar 14, 2024
dc8547d
Merge remote-tracking branch 'upstream/main' into quanto_integration
SunMarc Mar 14, 2024
229e439
update device_map
SunMarc Mar 14, 2024
8f5c9f7
fix doc
SunMarc Mar 14, 2024
d4cc911
add noteboon
younesbelkada Mar 14, 2024
95f05a4
update torch_dtype
SunMarc Mar 14, 2024
058937c
Merge branch 'quanto_integration' of https://github.com/SunMarc/trans…
SunMarc Mar 14, 2024
5bfa654
update doc
SunMarc Mar 14, 2024
b0b79f0
typo
SunMarc Mar 14, 2024
e389cd9
typo
SunMarc Mar 14, 2024
46aae3f
remove comm
SunMarc Mar 15, 2024
3 changes: 3 additions & 0 deletions docker/transformers-quantization-latest-gpu/Dockerfile
@@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl

# Add quanto for quantization testing
RUN python3 -m pip install --no-cache-dir quanto

# When installing in editable mode, `transformers` is not recognized as a package.
# this line must be added in order for python to be aware of transformers.
RUN cd transformers && python3 setup.py develop
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
@@ -26,6 +26,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

</Tip>

## QuantoConfig

[[autodoc]] QuantoConfig

## AqlmConfig

[[autodoc]] AqlmConfig
53 changes: 53 additions & 0 deletions docs/source/en/quantization.md
@@ -26,6 +26,59 @@ Interested in adding a new quantization method to Transformers? Read the [HfQuan

</Tip>

## Quanto

<Tip>

Try Quanto + transformers with this [notebook](https://colab.research.google.com/drive/16CXfVmtdQvciSh9BopZUDYcmXCDpvgrT?usp=sharing)!

</Tip>


[🤗 Quanto](https://github.com/huggingface/quanto) is a versatile PyTorch quantization toolkit. The quantization method used is linear quantization. Quanto provides several unique features such as:

- weights quantization (`float8`, `int8`, `int4`, `int2`)
- activation quantization (`float8`, `int8`)
- modality agnostic (e.g. CV, LLM)
- device agnostic (e.g. CUDA, MPS, CPU)
- compatibility with `torch.compile`
- easy to add a custom kernel for a specific device
- supports quantization-aware training
<!-- Add link to the blogpost -->

Before you begin, make sure the following libraries are installed:

```bash
pip install quanto
pip install git+https://github.com/huggingface/accelerate.git
pip install git+https://github.com/huggingface/transformers.git
```

Now you can quantize a model by passing a [`QuantoConfig`] object to the [`~PreTrainedModel.from_pretrained`] method. This works for any model in any modality, as long as it contains `torch.nn.Linear` layers.

The integration with Transformers only supports weight quantization. For more complex use cases such as activation quantization, calibration, and quantization-aware training, you should use the [quanto](https://github.com/huggingface/quanto) library instead.

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantization_config = QuantoConfig(weights="int8")
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", quantization_config=quantization_config)
```
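
You can then sanity-check the quantized model by running a short generation with the standard Transformers API; the prompt below is just an example:

```py
# Run a quick generation with the quantized model (prompt chosen only for illustration).
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(quantized_model.device)
outputs = quantized_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```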

Note that serialization is not yet supported with Transformers, but it is coming soon! If you want to save the model, you can use the quanto library instead, as shown in the sketch below.
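
As a stopgap, you can quantize and save a model with quanto directly. This is only a rough sketch that assumes quanto's `quantize`/`freeze` workflow and a plain `torch.save` of the state dict; check the quanto repository for the exact API.

```py
# Sketch only: quantize with the quanto library directly, then save the state dict.
import torch
from quanto import freeze, qint8, quantize  # assumed quanto API, see the quanto docs
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
quantize(model, weights=qint8)  # replace Linear layers with quantized versions
freeze(model)                   # materialize the quantized weights
torch.save(model.state_dict(), "opt-125m-quantized.pt")
```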

The quanto library uses a linear quantization algorithm. Even though this is a basic quantization technique, it achieves very good results! Have a look at the following benchmark (llama-2-7b, perplexity metric). You can find more benchmarks [here](https://github.com/huggingface/quanto/tree/main/bench/generation)

<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/NousResearch-Llama-2-7b-hf_Perplexity.png" alt="llama-2-7b-quanto-perplexity" />
</div>
</div>

The library is versatile enough to be compatible with most PTQ optimization algorithms. The plan for the future is to integrate the most popular algorithms as seamlessly as possible (AWQ, SmoothQuant).

## AQLM


4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1098,7 +1098,7 @@
"is_vision_available",
"logging",
],
"utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
"utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig", "QuantoConfig"],
}

# sentencepiece-backed objects
@@ -5906,7 +5906,7 @@
)

# bitsandbytes config
from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig
from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantoConfig

try:
if not is_sentencepiece_available():
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -82,6 +82,7 @@
"run_hp_search_wandb",
],
"peft": ["PeftAdapterMixin"],
"quanto": ["replace_with_quanto_layers"],
}

if TYPE_CHECKING:
@@ -150,6 +151,7 @@
run_hp_search_wandb,
)
from .peft import PeftAdapterMixin
from .quanto import replace_with_quanto_layers
else:
import sys

94 changes: 94 additions & 0 deletions src/transformers/integrations/quanto.py
@@ -0,0 +1,94 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_torch_available


if is_torch_available():
    import torch


def replace_with_quanto_layers(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
    Returns the converted model and a boolean that indicates if the conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`QuantoConfig`, defaults to `None`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
            converted.
        current_key_name (`list`, *optional*, defaults to `None`):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*, defaults to `False`):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    from accelerate import init_empty_weights
    from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8

    w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
    a_mapping = {None: None, "float8": qfloat8, "int8": qint8}

    if modules_to_not_convert is None:
        modules_to_not_convert = []

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
            with init_empty_weights():
                if isinstance(module, torch.nn.Linear):
                    model._modules[name] = QLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        dtype=module.weight.dtype,
                        weights=w_mapping[quantization_config.weights],
                        activations=a_mapping[quantization_config.activations],
                    )
                    model._modules[name].requires_grad_(False)
                    has_been_replaced = True
                elif isinstance(module, torch.nn.LayerNorm):
                    if quantization_config.activations is not None:
                        model._modules[name] = QLayerNorm(
                            module.normalized_shape,
                            module.eps,
                            module.elementwise_affine,
                            module.bias is not None,
                            activations=a_mapping[quantization_config.activations],
                        )
                        has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_quanto_layers(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
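
For context, here is a rough sketch of how this helper could be called on its own. The model id, the `modules_to_not_convert` value, and the config values are illustrative only; within Transformers the call is made by the quanto quantizer during `from_pretrained`.

```py
# Hypothetical usage sketch (not part of this PR's diff).
from transformers import AutoModelForCausalLM, QuantoConfig
from transformers.integrations import replace_with_quanto_layers

config = QuantoConfig(weights="int8")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# Replaced modules are created under init_empty_weights (i.e. on the meta device),
# so their quantized weights still need to be loaded/assigned afterwards.
model, has_been_replaced = replace_with_quanto_layers(
    model, quantization_config=config, modules_to_not_convert=["lm_head"]
)
```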
9 changes: 8 additions & 1 deletion src/transformers/modeling_utils.py
@@ -802,7 +802,11 @@ def _load_state_dict_into_meta_model(
elif (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict))
or (
not hf_quantizer.check_quantized_param(
model, param, param_name, state_dict, param_device=param_device, device_map=device_map
)
)
):
# For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
@@ -3728,6 +3732,9 @@ def _fix_key(key):
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
Comment on lines +3735 to +3736
Collaborator: Not sure this is something we want / I understand why this is needed

Member Author: I can try to sneak that in inside hf_quantizer.create_quantized_param, just like how bnb deals with unexpected_keys.

Member Author: The buffers are created at the moment we replace the layers, so modifying the missing keys in hf_quantizer.create_quantized_param doesn't make sense + it does not cover some edge cases. If possible, I prefer leaving it like this. LMK if this is possible or not.

Collaborator: Alright no worries 🤗

# retrieve weights on meta device and put them back on CPU.
# This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
if low_cpu_mem_usage:
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -22,12 +22,14 @@
GPTQConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer


AUTO_QUANTIZER_MAPPING = {
@@ -36,6 +38,7 @@
"bitsandbytes_8bit": Bnb8BitHfQuantizer,
"gptq": GptqHfQuantizer,
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -44,6 +47,7 @@
"bitsandbytes_8bit": BitsAndBytesConfig,
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
}
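
As a rough illustration (this is a simplified sketch, not the actual `AutoHfQuantizer` code), the two registries above let a serialized config be resolved to its quantizer by `quant_method`:

```py
# Simplified sketch of how the registries could be consumed.
def resolve_quantizer(quantization_config_dict):
    method = quantization_config_dict["quant_method"]              # e.g. "quanto"
    config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[method]          # -> QuantoConfig
    quantizer_cls = AUTO_QUANTIZER_MAPPING[method]                 # -> QuantoHfQuantizer
    config = config_cls.from_dict(quantization_config_dict)
    return quantizer_cls(config)
```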


20 changes: 18 additions & 2 deletions src/transformers/quantizers/base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import is_torch_available
from ..utils.quantization_config import QuantizationConfigMixin
@@ -99,6 +99,16 @@ def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
"""
return torch_dtype

def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
    """
    Override this method if you want to adjust the `missing_keys`.

    Args:
        missing_keys (`List[str]`, *optional*):
            The list of missing keys in the checkpoint compared to the state dict of the model
    """
    return missing_keys
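
A quantizer such as `QuantoHfQuantizer` can override this hook to drop keys that only appear missing because buffers are created when the layers are replaced. The sketch below is illustrative and assumes quanto exposes a `QModuleMixin` base class; it is not necessarily the exact implementation from this PR.

```py
# Illustrative override: ignore missing keys that belong to quanto-specific buffers,
# which are created at layer-replacement time and never stored in the checkpoint.
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
    import quanto

    not_missing_keys = []
    for name, module in model.named_modules():
        if isinstance(module, quanto.QModuleMixin):
            for missing in missing_keys:
                if (
                    (name in missing or name in f"{prefix}.{missing}")
                    and not missing.endswith(".weight")
                    and not missing.endswith(".bias")
                ):
                    not_missing_keys.append(missing)
    return [k for k in missing_keys if k not in not_missing_keys]
```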

def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
"""
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
@@ -111,6 +121,7 @@ def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[s
torch_dtype (`torch.dtype`):
The dtype passed in `from_pretrained` method.
"""

return {
name: torch_dtype
for name, _ in model.named_parameters()
@@ -122,7 +133,12 @@ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str,
return max_memory

def check_quantized_param(
self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
"""
checks if a loaded state_dict component is part of quantized param + some validation; only defined if
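
The extra `**kwargs` let a quantizer take parameter placement into account when deciding whether a state-dict entry needs quantization. A hypothetical override (names and logic are illustrative only, not the quanto quantizer's actual code) might look like:

```py
# Hypothetical sketch: only treat a parameter as "to quantize" when it will not be offloaded.
def check_quantized_param(self, model, param_value, param_name, state_dict, **kwargs) -> bool:
    param_device = kwargs.get("param_device", None)
    device_map = kwargs.get("device_map", None)
    if device_map is not None and param_device in ("cpu", "disk"):
        # Offloaded parameters are left unquantized in this sketch.
        return False
    return param_name.endswith(".weight")
```
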
7 changes: 6 additions & 1 deletion src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -116,7 +116,12 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
)

def check_quantized_param(
self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
import bitsandbytes as bnb

7 changes: 6 additions & 1 deletion src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -134,7 +134,12 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
return torch.int8

def check_quantized_param(
self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
import bitsandbytes as bnb
