[Quantization] Quanto quantizer #29023
The PR adds a new integration module; the hunk header `@@ -0,0 +1,94 @@` shows the whole file is new. Reconstructed from the diff:

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available


if is_torch_available():
    import torch


def replace_with_quanto_layers(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
    Returns the converted model and a boolean that indicates if the conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`QuantoConfig`, defaults to `None`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
            converted.
        current_key_name (`list`, *optional*, defaults to `None`):
            A list that contains the current key name. This is used for recursion and should not be passed by the
            user.
        has_been_replaced (`bool`, *optional*, defaults to `False`):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    from accelerate import init_empty_weights
    from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8

    w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
    a_mapping = {None: None, "float8": qfloat8, "int8": qint8}

    if modules_to_not_convert is None:
        modules_to_not_convert = []

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
            with init_empty_weights():
                if isinstance(module, torch.nn.Linear):
                    model._modules[name] = QLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        dtype=module.weight.dtype,
                        weights=w_mapping[quantization_config.weights],
                        activations=a_mapping[quantization_config.activations],
                    )
                    model._modules[name].requires_grad_(False)
                    has_been_replaced = True
                elif isinstance(module, torch.nn.LayerNorm):
                    if quantization_config.activations is not None:
                        model._modules[name] = QLayerNorm(
                            module.normalized_shape,
                            module.eps,
                            module.elementwise_affine,
                            module.bias is not None,
                            activations=a_mapping[quantization_config.activations],
                        )
                        has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_quanto_layers(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
```
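For orientation, here is a minimal usage sketch of the path this helper enables. The checkpoint name and the `int8` weight setting are illustrative assumptions; in practice `replace_with_quanto_layers` is invoked by the Quanto quantizer when a `QuantoConfig` is passed to `from_pretrained`, rather than being called directly.

```python
# Minimal sketch, assuming `quanto` and `accelerate` are installed alongside transformers.
# The checkpoint and the int8 weight choice are illustrative, not taken from this PR.
from transformers import AutoModelForCausalLM, QuantoConfig

quantization_config = QuantoConfig(weights="int8", activations=None)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=quantization_config,
)

# Linear layers (except modules listed in `modules_to_not_convert`, e.g. `lm_head`)
# are now `quanto.QLinear` modules holding int8 weights.
print(model.model.decoder.layers[0].self_attn.q_proj)
```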
In `src/transformers/modeling_utils.py`, `_load_state_dict_into_meta_model` now forwards the parameter's target device and the device map when asking the quantizer whether a parameter needs the quantized loading path:

```diff
@@ -802,7 +802,11 @@ def _load_state_dict_into_meta_model(
         elif (
             not is_quantized
             or (not hf_quantizer.requires_parameters_quantization)
-            or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict))
+            or (
+                not hf_quantizer.check_quantized_param(
+                    model, param, param_name, state_dict, param_device=param_device, device_map=device_map
+                )
+            )
         ):
             # For backward compatibility with older versions of `accelerate` and for non-quantized params
             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
```
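The two extra keyword arguments let the quantizer's decision depend on where the parameter will live. Below is a hedged sketch of the kind of check this makes possible; `should_quantize_param` is a hypothetical helper written for illustration, not the method this PR adds to the Quanto quantizer.

```python
# Sketch only: a placement-aware check in the spirit of `check_quantized_param`.
# The helper name and the "disk" offloading rule are assumptions for illustration.
from quanto import QLinear


def should_quantize_param(model, param_name, param_device, device_map):
    # Resolve the module that owns the parameter, e.g. "model.layers.0.self_attn.q_proj.weight".
    module_name, _, tensor_name = param_name.rpartition(".")
    module = model.get_submodule(module_name) if module_name else model

    # Parameters that end up offloaded to disk are left to the regular loading path.
    if device_map is not None and param_device == "disk":
        return False

    # Only weights of layers that were swapped to QLinear need the quantized loading path.
    return isinstance(module, QLinear) and tensor_name == "weight"
```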
|
|
Later in the same file, after unexpected keys are filtered, the quantizer is given a chance to update the list of missing keys:

```diff
@@ -3728,6 +3732,9 @@ def _fix_key(key):
             for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
+        if hf_quantizer is not None:
+            missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
```
Comment on lines +3735 to +3736

Collaborator: Not sure this is something we want / I understand why this is needed

Member (Author): I can try to sneak that in inside the

Member (Author): The buffers are created at the moment we replace the layers, so modifying the missing keys here

Collaborator: Alright, no worries 🤗
```diff
+
         # retrieve weights on meta device and put them back on CPU.
         # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
         if low_cpu_mem_usage:
```
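As the thread above explains, Quanto layers register extra buffers at replacement time, so their names would otherwise be reported as missing when loading a non-quantized checkpoint. Here is a hedged sketch of the kind of filtering `update_missing_keys` can perform; the rule below (keep `.weight`/`.bias`, drop other keys under replaced modules) is an illustration, not necessarily the exact rule used by the Quanto quantizer.

```python
# Sketch only: drop keys that belong to buffers created by replace_with_quanto_layers,
# so they are not flagged as missing when the checkpoint predates quantization.
from quanto import QLinear


def update_missing_keys(model, missing_keys, prefix):
    keys_to_drop = []
    for name, module in model.named_modules():
        if isinstance(module, QLinear):
            full_name = f"{prefix}.{name}" if prefix else name
            for key in missing_keys:
                # Real parameters (weight/bias) must still be loaded from the checkpoint;
                # anything else under a replaced module is treated as a Quanto-created buffer.
                if (key.startswith(name) or key.startswith(full_name)) and not key.endswith((".weight", ".bias")):
                    keys_to_drop.append(key)
    return [key for key in missing_keys if key not in keys_to_drop]
```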