
Commit a529cf4

Asaf Karnieli authored and ulivne committed
[ALGO-801] Add Fake Quant option in linear and matmul layers
Change-Id: I9888c92ffc33035f75d434044f4ef41b58f51e62
1 parent 09c6312 commit a529cf4

8 files changed: 165 additions & 24 deletions

neural_compressor/torch/algorithms/fp8_quant/_core/common.py

Lines changed: 0 additions & 5 deletions
@@ -68,11 +68,6 @@ def __init__(self, num_inputs, param_names, num_outputs, required_output):
     "softmax": ModuleType(1, [], 1, True),
     "fused_sdpa": ModuleType(3, [], 2, True),
 }
-descale_fcn = lambda x, scale: torch.mul(x, scale)
-scale_fcn = lambda x, scale: torch.div(x, scale)
-cast_fcn = lambda x, dtype: x.to(dtype=dtype)
-cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0]
-cast_from_fp8_fcn = lambda x, dtype, scale=None: torch.ops.hpu.cast_from_fp8(x, scale, dtype)


 class ShapeList:

neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py

Lines changed: 3 additions & 3 deletions
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
 import habana_frameworks.torch.core as htcore
 import habana_frameworks.torch.utils.experimental as htexp
-import torch
-
-from .common import *
+from .common import ModuleConfig
+from .quant_dequant import cast_to_fp8_fcn, cast_fcn, descale_fcn, scale_fcn

 GAUDI2 = htexp.synDeviceType.synDeviceGaudi2
 GAUDI3 = htexp.synDeviceType.synDeviceGaudi3

neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py

Lines changed: 27 additions & 2 deletions
@@ -12,11 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch.nn as nn
+import torch
 from abc import abstractmethod
+import habana_frameworks.torch.core as htcore

-import torch.nn as nn

-from .common import *
+descale_fcn = lambda x, scale: torch.mul(x, scale)
+scale_fcn = lambda x, scale: torch.div(x, scale)
+cast_fcn = lambda x, dtype: x.to(dtype=dtype)
+cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0]
+cast_from_fp8_fcn = lambda x, dtype, scale=None: torch.ops.hpu.cast_from_fp8(x, scale, dtype)


 class QuantDequantBase(nn.Module):
@@ -69,3 +75,22 @@ def forward(self, x):
     def extra_repr(self) -> str:
         repr = super(DequantOutput, self).extra_repr()
         return f"{repr}, scale dtype={self.scale.dtype}"
+
+
+class QuantDequant(QuantDequantBase):
+    def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs):
+        super(QuantDequant, self).__init__(lp_dtype, hp_dtype, *args, **kwargs)
+        self.scale_inv = nn.Parameter(scale_inv)
+        self.scale = nn.Parameter(1 / scale_inv)
+
+    def forward(self, x, *args, **kwargs):
+        y = cast_to_fp8_fcn(x, self.lp_dtype, self.scale_inv)
+        # mark_step is needed so fuser won't remove 2 consecutive casts.
+        # will be removed once SW-196431 is implemented
+        htcore.mark_step()
+        z = cast_from_fp8_fcn(y, self.hp_dtype, self.scale)
+        return z
+
+    def extra_repr(self) -> str:
+        repr = super(QuantDequant, self).extra_repr()
+        return f"{repr}, Quantize, and then dequantize"

neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py

Lines changed: 4 additions & 1 deletion
@@ -97,7 +97,10 @@ def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float):
         apply_hf_hook(mod)
         if name in mod_list:
             mod_extra_config = qconfig[name]
-            quantize_params(mod, mod_extra_config)
+
+            if config.cfg["fake_quant"] == False:
+                quantize_params(mod, mod_extra_config)
+
             patch_module(mod, mod_extra_config, mod_default_dict)
             patched_modules.append(name)
             patched_module_types.add(type(mod))

neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 48 additions & 10 deletions
@@ -16,6 +16,7 @@
 import torch.nn as nn

 from .quant_config import QuantMode, get_hqt_config
+from .._core.quant_dequant import QuantDequant as qdq

 try:  # backwards compatibility for 1.16
     from habana_frameworks.torch.hpex.kernels import fp8_fused_sdpa
@@ -122,6 +123,7 @@ def set_attrs_from_orig_model(cls_instance, mod, mod_extra_config, *func_names):
     cls_instance.class_name_org = mod.__class__.__name__
     cls_instance._mod_extra_config = mod_extra_config
     cls_instance.quantization_mode = config.cfg["mode"]
+    cls_instance.fake_quant = config.cfg["fake_quant"]
     # store original module in order to invoke its functions during measurements.
     # this may be omitted of torch remove the related validation from dynamo. see SW-187731.
     cls_instance.__dict__["orig_mod"] = mod
@@ -160,14 +162,25 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
         super().__init__()
         set_attrs_from_orig_model(self, mod, mod_extra_config)
         if self.quantization_mode == QuantMode.QUANTIZE:
-            self.quant_input_0 = self._mod_extra_config.inputs[0]
-            self.quant_input_1 = self._mod_extra_config.inputs[1]
-            self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0])
-            self.scale_other = nn.Parameter(mod_extra_config.scale.inputs[1])
+            if self.fake_quant == False:
+                self.forward = self.forward_quant
+                self.quant_input_0 = self._mod_extra_config.inputs[0]
+                self.quant_input_1 = self._mod_extra_config.inputs[1]
+                self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0])
+                self.scale_other = nn.Parameter(mod_extra_config.scale.inputs[1])
+            else:
+                self.forward = self.forward_fakequant
+
+                # override quantization to quant-dequant
+                mec = self._mod_extra_config.inputs[0]
+                self.quant_input_0 = qdq(mec.scale_inv, mec.lp_dtype, mec.hp_dtype)
+                mec = self._mod_extra_config.inputs[1]
+                self.quant_input_1 = qdq(mec.scale_inv, mec.lp_dtype, mec.hp_dtype)
+
         elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
             self.forward = self.forward_measure

-    def forward(self, input, other):
+    def forward_quant(self, input, other):
         qinput = self.quant_input_0(input)
         qother = self.quant_input_1(other)
         output = matmul_fp8(
@@ -179,6 +192,12 @@ def forward(self, input, other):
         )
         return output

+    def forward_fakequant(self, input, other):
+        qinput = self.quant_input_0(input)
+        qother = self.quant_input_1(other)
+        output = torch.matmul(qinput, qother)
+        return output
+
     def forward_measure(self, input, other):
         measure_input((input, other), observer=self._mod_extra_config.inputs)
         output = self.orig_mod(input, other)
@@ -198,21 +217,40 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
         super().__init__()
         set_attrs_from_orig_model(self, mod, mod_extra_config)
         if self.quantization_mode == QuantMode.QUANTIZE:
-            # When offloading weights to disk using device_map, the module forward is overridden.
-            # __dict__.update call again overrides the PatchedLinear forward with the forward that device_map planted.
-            # So need to set PatchedLinear forawrd to be the right forward.
-            self.forward = self.forward_quant
-            self.quant_input = self._mod_extra_config.inputs[0]
             self.weight = nn.Parameter(self.weight.t().contiguous())
             self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0])
             if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)):
                 self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"])
             elif isinstance(mod_extra_config.scale.params["weight"], dict):
                 # PCQ weight is calculated with actual weight [0] and ones [1]
                 self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0])
+
+            if self.fake_quant == False:
+                # When offloading weights to disk using device_map, the module forward is overridden.
+                # __dict__.update call again overrides the PatchedLinear forward with the forward that device_map planted.
+                # So need to set PatchedLinear forawrd to be the right forward.
+                self.forward = self.forward_quant
+                self.quant_input = self._mod_extra_config.inputs[0]
+
+            else:
+                self.forward = self.forward_fakequant
+                # override quantization to quant-dequant
+                mec = self._mod_extra_config.inputs[0]
+                self.quant_input = qdq(mec.scale_inv, mec.lp_dtype, mec.hp_dtype)
+                mec = self._mod_extra_config.params['weight']
+                self.quant_weights = qdq(mec.scale_inv, mec.lp_dtype, mec.hp_dtype)
+
+
         elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE):
             self.forward = self.forward_measure

+    def forward_fakequant(self, input):
+        qweight = self.quant_weights(self.weight, )
+        qinput = self.quant_input(input)
+        y = torch.matmul(qinput, qweight)
+        output = y + self.bias if (self.bias is not None) else y
+        return output
+
     def forward_quant(self, input):
         qinput = self.quant_input(input)
         y = matmul_fp8(
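
Both patched paths apply the same per-tensor scales and the same FP8 rounding to their operands; the difference is that forward_quant keeps the operands in FP8 and calls matmul_fp8, while forward_fakequant dequantizes them first and runs a plain torch.matmul in the high-precision dtype. That is why the new unit test below can expect the real and fake outputs to agree within roughly 1% (rtol=0.01). A hypothetical way to check which path a converted layer took, relying only on the QuantDequant submodules registered in the fake-quant branch:

# QuantDequant wrappers are registered as child modules only in fake-quant mode.
for name, mod in model.named_modules():
    if type(mod).__name__ == "QuantDequant":
        print(name, mod.extra_repr())  # ends with "Quantize, and then dequantize"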

neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py

Lines changed: 3 additions & 3 deletions
@@ -36,7 +36,6 @@ class QuantMode(Enum):
     MEASURE = 2
     SHAPE = 3

-
 class MeasureExclude(Flag):
     NONE = auto()
     INPUT = auto()
@@ -68,7 +67,6 @@ class ScaleMethod(Enum):
     MAXABS_HW_OPT_WEIGHT = 12
     MAXABS_POW2_OPT_WEIGHT = 13

-
 class TrueFalse(Enum):
     TRUE = True
     FALSE = False
@@ -82,10 +80,11 @@ class TrueFalse(Enum):
     "scale_method": ScaleMethod,
     "recalc_scales": TrueFalse,
     "ignore_modules_wo_measures": TrueFalse,
+    "fake_quant": TrueFalse
 }


-_configs_that_use_enum_value = ["fp8_config", "hp_dtype", "ignore_modules_wo_measures", "recalc_scales"]
+_configs_that_use_enum_value = ["fp8_config", "hp_dtype", "ignore_modules_wo_measures", "recalc_scales", "fake_quant"]


 def get_hqt_config(mod) -> Fp8cfg:
@@ -121,6 +120,7 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
             "types": (),
         },  # types and names to be quantized. Allowlist by names is not yet implemented
         "mode": QuantMode.QUANTIZE,  # Quantize or Measure
+        "fake_quant": False,  # Fake or Real Quant
        "scale_method": ScaleMethod.UNIT_SCALE,  # Method to quantize with
        "scale_params": {},  # scaling parameters that are different then the default ones
        "observer": "maxabs",  # Supported ['shape', 'maxabs', 'maxabs_per_channel', 'save']

neural_compressor/torch/quantization/config.py

Lines changed: 2 additions & 0 deletions
@@ -1256,6 +1256,7 @@ def __init__(
         observer: str = "maxabs",
         mod_dict: dict = {},
         measure_exclude: str = "OUTPUT",
+        fake_quant: bool = False,
         **kwargs,
     ):
         """Init FP8 config."""
@@ -1271,6 +1272,7 @@ def __init__(
         self.observer = observer
         self.mod_dict = mod_dict
         self._json_file = None
+        self.fake_quant = fake_quant

     @property
     def measure(self):
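
With the new fake_quant field, fake quantization is switched on from the user-facing FP8Config, either as a keyword argument (FP8Config(fake_quant=True, ...)) or through from_dict as in the new unit test below. A condensed usage sketch derived from that test (the model and calibration data are placeholders; the measurement step is abbreviated):

from neural_compressor.torch.quantization import FP8Config, prepare, convert

config = FP8Config.from_dict({
    "mode": "AUTO",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "dump_stats_path": "./inc_output/measure",
    "fake_quant": "True",          # quant-dequant instead of real FP8 compute
})
model = prepare(model, config)     # run calibration batches through the model here
model = convert(model)             # patched Linear/Matmul now use forward_fakequant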
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import typing
+import pytest
+import copy
+import torch
+
+import habana_frameworks.torch.core as htcore
+
+htcore.hpu_set_env()
+
+from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare
+from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import Matmul
+
+torch.manual_seed(1)
+
+class M(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(10, 200, bias=False)
+        self.fc2 = torch.nn.Linear(10, 200, bias=True)
+        self.matmul = Matmul()
+
+    def forward(self, inp):
+        x1 = self.fc1(inp)
+        x2 = self.fc2(inp)
+        x3 = self.matmul(x1, x2.t())
+        return x3
+
+
+def test_fakequant():
+    # Run both real and fake quantization, and compare
+
+    model = M().eval().to("hpu").to(torch.bfloat16)
+    model_fake = copy.deepcopy(model)
+    htcore.hpu_initialize()
+
+    config_dict_fake = {
+        "mode": "AUTO",
+        "observer": "maxabs",
+        "scale_method": "maxabs_hw",
+        "allowlist": {"types": [], "names": []},
+        "blocklist": {"types": [], "names": []},
+        "dump_stats_path": "./inc_output/measure_fake",
+        "fake_quant": "True",
+    }
+
+    config_dict = {
+        "mode": "AUTO",
+        "observer": "maxabs",
+        "scale_method": "maxabs_hw",
+        "allowlist": {"types": [], "names": []},
+        "blocklist": {"types": [], "names": []},
+        "dump_stats_path": "./inc_output/measure",
+        "fake_quant": "False",
+    }
+
+    config = FP8Config.from_dict(config_dict)
+    config_fake = FP8Config.from_dict(config_dict_fake)
+
+    model = prepare(model, config)
+    model_fake = prepare(model_fake, config_fake)
+    inp_calib = torch.arange(0, 100, 0.1, dtype=torch.bfloat16).to("hpu").reshape(-1, 10)
+    inp_test = torch.rand(10000, dtype=torch.bfloat16).reshape(-1, 10).to("hpu") * 100
+
+    # for calibration
+    with torch.no_grad():
+        a = model(inp_calib)
+        b = model_fake(inp_calib)
+
+    model = convert(model)
+    model_fake = convert(model_fake)
+
+    # for benchmark
+    with torch.no_grad():
+        output = model(inp_test).cpu()
+        output_fake = model_fake(inp_test).cpu()
+    assert torch.allclose(output, output_fake, rtol=0.01), f"FakeQuant failed"
+
