Commit 34f0a9f

Authored by yiliu30 and chensuyue
Add save/load support for HQQ (#1913)
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: chen, suyue <[email protected]>
1 parent d320460 commit 34f0a9f

File tree

6 files changed: +188 −3 lines

neural_compressor/torch/algorithms/weight_only/hqq/core.py

Lines changed: 59 additions & 1 deletion

@@ -19,7 +19,7 @@
 # NOTICE: the original `Quantizer` has been modified to `HQQTensorHandle`
 # and `QTensor` to decouple the data structure and the quantization logic.
 
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Mapping, Tuple
 
 import torch
 
@@ -278,3 +278,61 @@ def from_float(
         # !!! Delete the float explicitly to save memory
         del float_module
         return new_mod
+
+    def state_dict(self, *args, **kwargs):  # nn.Module override compatible
+        state_dict = self.q_weight.to_state_dict()
+        if self.bias is not None:
+            state_dict["bias"] = self.bias
+        if "destination" in kwargs and "prefix" in kwargs:
+            for key, value in state_dict.items():
+                kwargs["destination"][kwargs["prefix"] + key] = value
+        return state_dict
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        all_expected_keys = ["val", "scale_quantized", "zero_quantized", "meta_info"]
+        if self.bias is not None:
+            all_expected_keys.append("bias")
+
+        for key in all_expected_keys:
+            if prefix + key not in state_dict:
+                missing_keys.append(key)
+        if missing_keys:
+            return  # Can't load weights if either weight or meta is missing
+
+        cur_state_dict = {}
+        for key in all_expected_keys:
+            cur_state_dict[key] = state_dict.pop(prefix + key)
+
+        unexpected_keys += state_dict.keys()
+        self._assign_state_dict(cur_state_dict, strict)
+
+    def _assign_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
+        _scale_quantized = state_dict["scale_quantized"]
+        _zero_quantized = state_dict["zero_quantized"]
+        scale_state = state_dict["meta_info"]["scale"]
+        zero_state = state_dict["meta_info"]["zero"]
+        if _scale_quantized:
+            scale = HQQTensorHandle._create_q_tensor(scale_state["val"], scale_state["meta_info"])
+        else:
+            scale = state_dict["meta_info"]["scale"]
+        if _zero_quantized:
+            zero = HQQTensorHandle._create_q_tensor(zero_state["val"], zero_state["meta_info"])
+        else:
+            zero = state_dict["meta_info"]["zero"]
+        meta = state_dict["meta_info"]
+        meta["scale"] = scale
+        meta["zero"] = zero
+        self.q_weight = HQQTensorHandle._create_q_tensor(state_dict["val"], meta)
+        if self.bias is not None:
+            self.bias = state_dict["bias"]
+        self.quantized = True
+        return self
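
Usage note: with these overrides, an HQQLinear round-trips through the standard torch.save / load_state_dict flow. Below is a minimal sketch in the spirit of the test_hqq_linear_save_and_load test added in this commit; the save path is illustrative, and passing HQQModuleConfig() with no arguments assumes its default weight/scale/zero settings.

import torch

from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, hqq_global_option
from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear

hqq_global_option.use_half = False  # stay in fp32 on CPU, as the tests do

# Quantize a float Linear and serialize its custom state dict (val / meta_info / ... keys).
hqq_linear = HQQLinear.from_float(
    torch.nn.Linear(in_features=64, out_features=128),
    quant_config=HQQModuleConfig(),  # assumed: default quantization settings
)
torch.save(hqq_linear.state_dict(), "hqq_linear.pth")  # illustrative path

# Rebuild a module of the same shape and restore the quantized weights.
new_hqq_linear = HQQLinear.from_float(
    torch.nn.Linear(in_features=64, out_features=128),
    quant_config=HQQModuleConfig(),
)
new_hqq_linear.load_state_dict(torch.load("hqq_linear.pth"))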

neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py

Lines changed: 16 additions & 0 deletions

@@ -115,3 +115,19 @@ def half(self):
         if self.zero is not None:
             self.zero = self.zero.half()
         return self
+
+    def to_state_dict(self):
+        state = {}
+        state["val"] = self.val
+        state["meta_info"] = self.meta_info.to_dict()
+        state["scale_quantized"] = self.is_scale_quantized()
+        state["zero_quantized"] = self.is_zero_quantized()
+        if self.is_scale_quantized():
+            state["meta_info"]["scale"] = self.scale.to_state_dict()
+        else:
+            state["meta_info"]["scale"] = self.scale
+        if self.is_zero_quantized():
+            state["meta_info"]["zero"] = self.zero.to_state_dict()
+        else:
+            state["meta_info"]["zero"] = self.zero
+        return state
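
For reference, to_state_dict returns a plain nested dict rather than a flat tensor map; when the scale (or zero) is itself a quantized QTensor, its entry is recursively another state dict of the same shape. A sketch of the resulting layout (field values are illustrative, not exhaustive):

# Layout produced by QTensor.to_state_dict():
# {
#     "val": <packed weight tensor>,
#     "scale_quantized": True,           # from is_scale_quantized()
#     "zero_quantized": False,           # from is_zero_quantized()
#     "meta_info": {
#         ...                            # fields from meta_info.to_dict()
#         "scale": {"val": ..., "meta_info": {...}, ...},  # nested dict when the scale is quantized
#         "zero": <plain tensor>,        # raw tensor when the zero point is not quantized
#     },
# }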

neural_compressor/torch/algorithms/weight_only/save_load.py

Lines changed: 28 additions & 1 deletion

@@ -124,7 +124,6 @@ def load_inc_format_woq_model(self, qmodel_weight_file_path, qconfig_file_path):
 
         with open(qconfig_file_path, "r") as file:
             self.quantization_config = json.load(file)
-
         model = self._build_woq_model()
         model.load_state_dict(qweights, assign=True)
         model.eval()
@@ -157,8 +156,19 @@ def load_hf_format_woq_model(self):
 
         return model
 
+    def _is_hqq_model(self):
+        for name, module in self.original_model.named_modules():
+            pattern = rf"(\(.*{re.escape(name)}.*{re.escape(type(module).__name__)}.*\))"
+            for q_config_key, q_config_value in self.quantization_config.items():
+                if re.search(pattern, q_config_key):
+                    if isinstance(q_config_value, dict) and [algo for algo in q_config_value.keys()][0] == "hqq":
+                        return True
+
     def _build_woq_model(self):
         """Build weight-only quantization model."""
+        if self._is_hqq_model():
+            return self._build_hqq_model()
+
         from neural_compressor.torch.utils import set_module
 
         from .modules import MulLinear
@@ -228,6 +238,23 @@ def _build_woq_model(self):
         woq_model = self.original_model
         return woq_model
 
+    def _build_hqq_model(self):
+        """Replace quantized Linear with HQQLinear."""
+        from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
+        from neural_compressor.torch.utils import set_module
+
+        for name, module in self.original_model.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                loaded_state_dict_keys_set = set(self.loaded_state_dict_keys)
+                if name + ".val" not in loaded_state_dict_keys_set:
+                    continue
+                new_module = HQQLinear(
+                    in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None
+                )
+                set_module(self.original_model, name, new_module)
+        woq_model = self.original_model
+        return woq_model
+
     def _get_model_class_and_config(self):
         from transformers import AutoConfig, AutoModelForCausalLM
         from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
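
Taken together, these helpers are what neural_compressor.torch.quantization.load dispatches to for an HQQ checkpoint: _is_hqq_model spots the "hqq" entry in the saved qconfig, _build_hqq_model swaps the matching torch.nn.Linear modules for HQQLinear, and the state dict is then restored onto the rebuilt model. A short end-to-end sketch mirroring the new test_hqq_load_save test; the save path is illustrative, and prepare/convert/get_default_hqq_config are assumed to be exported from neural_compressor.torch.quantization as used elsewhere in the test file:

import copy

from transformers import AutoModelForCausalLM

from neural_compressor.torch.quantization import convert, get_default_hqq_config, load, prepare

fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM")

# Quantize with HQQ; the entry point binds a .save() method onto the returned model.
qmodel = convert(prepare(copy.deepcopy(fp32_model), get_default_hqq_config()))
qmodel.save("saved_hqq_model.pth")  # illustrative path; the test writes under options.workspace

# Reload: the fp32 architecture is rebuilt with HQQLinear modules and the quantized state dict.
loaded_model = load("saved_hqq_model.pth", copy.deepcopy(fp32_model))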

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 3 additions & 0 deletions

@@ -517,11 +517,14 @@ def hqq_entry(
     **kwargs,
 ) -> torch.nn.Module:
     from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
+    from neural_compressor.torch.algorithms.weight_only.save_load import save
 
     logger.info("Quantize model with the HQQ algorithm.")
 
     quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
     model = quantizer.execute(model, mode=mode)
+    model.qconfig = configs_mapping
+    model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
     dump_model_op_stats(mode, configs_mapping)
 
neural_compressor/torch/quantization/load_entry.py

Lines changed: 4 additions & 1 deletion

@@ -22,6 +22,7 @@
     AWQConfig,
     FP8Config,
     GPTQConfig,
+    HQQConfig,
     RTNConfig,
     TEQConfig,
 )
@@ -89,7 +90,9 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
     # select load function
     config_object = config_mapping[next(iter(config_mapping))]
 
-    if isinstance(config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig)):  # WOQ
+    if isinstance(
+        config_object, (RTNConfig, GPTQConfig, AWQConfig, TEQConfig, AutoRoundConfig, HQQConfig)
+    ):  # WOQ
         from neural_compressor.torch.algorithms import weight_only
 
         return weight_only.load(model_name_or_path, original_model, format=LoadFormat.DEFAULT)

test/3x/torch/quantization/weight_only/test_hqq.py

Lines changed: 78 additions & 0 deletions

@@ -1,11 +1,14 @@
+import copy
 import os
+import time
 from copy import deepcopy
 
 import pytest
 import torch
 import transformers
 from transformers import AutoModelForCausalLM
 
+from neural_compressor.common import options
 from neural_compressor.common.utils import logger
 from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
 from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
@@ -93,6 +96,27 @@ def test_hqq_quant(self, force_use_cpu, force_not_half):
             q_label_1.eq(q_label_2)
         ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."
 
+    def test_hqq_load_save(self, force_use_cpu, force_not_half):
+
+        hqq_global_option.use_half = False
+        fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM")
+        example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu")
+        # test_default_config
+        quant_config = get_default_hqq_config()
+
+        # prepare + convert API
+        model = prepare(deepcopy(fp32_model), quant_config)
+        qmodel = convert(model)
+        qmodel_out_ref = model(example_inputs)[0]
+        save_path = options.workspace + f"/_hqq_model_{time.time()}.pth"
+        qmodel.save(save_path)
+        from neural_compressor.torch.quantization import load
+
+        # loading compressed model
+        loaded_model = load(save_path, copy.deepcopy(fp32_model))
+        loaded_model_out = loaded_model(example_inputs)[0]
+        assert torch.allclose(qmodel_out_ref, loaded_model_out), "Unexpected result. Please double check."
+
     def test_hqq_fallback(self, force_use_cpu, force_not_half):
 
         class ToyModel(torch.nn.Module):
@@ -181,3 +205,57 @@ def test_hqq_module(
             scale_quant_group_size=scale_quant_group_size,
             device=torch.device(device_name),
         )
+
+    @pytest.mark.parametrize(
+        "nbits, group_size, quant_zero, quant_scale, scale_quant_group_size",
+        [
+            (4, 64, True, False, 128),
+            (4, 64, False, False, 128),
+            (4, 64, True, True, 128),
+            (4, 64, False, True, 128),
+            (8, 64, True, False, 128),
+        ],
+    )
+    def test_hqq_linear_save_and_load(
+        self,
+        nbits,
+        group_size,
+        quant_zero,
+        quant_scale,
+        scale_quant_group_size,
+    ):
+        hqq_global_option.use_half = False
+        # Parse config
+        weight_qconfig = QTensorConfig(
+            nbits=nbits,
+            channel_wise=True,
+            group_size=group_size,
+            optimize=True,
+            round_zero=True if nbits == 4 else False,
+        )
+        zero_qconfig = None
+        if quant_zero:
+            zero_qconfig = QTensorConfig(nbits=8, channel_wise=False, group_size=None, optimize=False)
+        scale_qconfig = None
+        if quant_scale:
+            scale_qconfig = QTensorConfig(nbits=8, channel_wise=True, group_size=scale_quant_group_size, optimize=False)
+        hqq_quant_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig)
+        # Create HQQ Linear
+        bs = 4
+        in_features = 64
+        out_features = 128
+        float_linear = torch.nn.Linear(in_features=in_features, out_features=out_features)
+        float_linear.to(device)
+        float_linear_copy = deepcopy(float_linear)
+        input = torch.randn(bs, in_features, device=device)
+        hqq_linear = HQQLinear.from_float(float_linear_copy, quant_config=hqq_quant_config)
+        out_ref = hqq_linear(input)
+        state_dict = hqq_linear.state_dict()
+        hqq_module_path = options.workspace + f"/_hqq_linear_{time.time()}.pth"
+        torch.save(state_dict, hqq_module_path)
+        reload_state_dict = torch.load(hqq_module_path)
+        new_float = torch.nn.Linear(in_features=in_features, out_features=out_features)
+        new_hqq_linear = HQQLinear.from_float(new_float, quant_config=hqq_quant_config)
+        new_hqq_linear.load_state_dict(reload_state_dict)
+        out = new_hqq_linear(input)
+        assert torch.equal(out_ref, out), f"out_ref: {out_ref}, out: {out}"
