
Commit 05272c4

add per_channel_minmax (#1990)
Signed-off-by: yiliu30 <[email protected]>
1 parent 82d8c06 commit 05272c4

2 files changed (+17, -5 lines)


neural_compressor/torch/algorithms/pt2e_quant/utility.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -17,7 +17,12 @@
 
 import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver, PlaceholderObserver
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 
@@ -48,19 +53,23 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=Fals
         "placeholder": PlaceholderObserver,
         "minmax": MinMaxObserver,
         "kl": HistogramObserver,
+        "per_channel_minmax": PerChannelMinMaxObserver,
     }
     # Force to use placeholder observer for dynamic quantization
     if is_dynamic:
         algo = "placeholder"
-    # algo
-    observer_or_fake_quant_ctr = observer_mapping[algo]
+    if f"{granularity}_{algo}" in observer_mapping:
+        observer_or_fake_quant_ctr = observer_mapping[f"{granularity}_{algo}"]
+    else:
+        observer_or_fake_quant_ctr = observer_mapping[algo]
     # qscheme
     qscheme = qscheme_mapping[granularity][sym]
     quantization_spec = QuantizationSpec(
         dtype=select_dtype,
         quant_min=min_max_mapping[select_dtype][0],
         quant_max=min_max_mapping[select_dtype][1],
         observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        ch_axis=0,
         qscheme=qscheme,
         is_dynamic=is_dynamic,
     )
```
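The resolution logic above first tries a key composed as f"{granularity}_{algo}" and only then falls back to the plain algorithm name, so per_channel + minmax now selects PerChannelMinMaxObserver while every existing combination behaves as before; the added ch_axis=0 points per-channel schemes at dimension 0, the output-channel axis of a Linear/Conv weight. A minimal sketch of that lookup (resolve_observer is an illustrative helper, not part of neural_compressor; observer_mapping is copied from the diff):

```python
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)

# Same mapping as in create_quant_spec_from_config after this commit.
observer_mapping = {
    "placeholder": PlaceholderObserver,
    "minmax": MinMaxObserver,
    "kl": HistogramObserver,
    "per_channel_minmax": PerChannelMinMaxObserver,
}


def resolve_observer(granularity: str, algo: str):
    # Try the granularity-specific key first ("per_channel_minmax"),
    # then fall back to the plain algorithm name ("minmax", "kl", ...).
    return observer_mapping.get(f"{granularity}_{algo}", observer_mapping[algo])


assert resolve_observer("per_channel", "minmax") is PerChannelMinMaxObserver
assert resolve_observer("per_tensor", "minmax") is MinMaxObserver
```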

test/3x/torch/quantization/test_pt2e_quant.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -98,7 +98,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return exported_model, example_inputs
 
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
-    def test_quantize_simple_model(self, force_not_import_ipex):
+    @pytest.mark.parametrize("granularity", ["per_tensor", "per_channel"])
+    def test_quantize_simple_model(self, granularity, force_not_import_ipex):
+        from neural_compressor.torch.quantization import StaticQuantConfig
+
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
         float_model_output = model(*example_inputs)
         quant_config = None
@@ -107,7 +110,7 @@ def calib_fn(model):
             for i in range(4):
                 model(*example_inputs)
 
-        quant_config = get_default_static_config()
+        quant_config = StaticQuantConfig(w_granularity=granularity)
         q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
         from torch._inductor import config
```
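From the user side, the parametrized test drives both granularities through StaticQuantConfig(w_granularity=...). A hedged end-to-end sketch of the same flow; the toy module, the export step, and the calibration loop are illustrative stand-ins for the test's build_simple_torch_model_and_example_inputs helper:

```python
# Illustrative sketch, not the commit's test code: capture a toy model for
# pt2e quantization and request per-channel weight observation.
import torch
from neural_compressor.torch.quantization import StaticQuantConfig, quantize


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


example_inputs = (torch.randn(2, 10),)
# pt2e quantization operates on an exported graph; the capture API varies by
# torch version (the test builds its model through a repo helper instead).
model = torch.export.export(ToyModel(), example_inputs).module()


def calib_fn(model):
    # Run a few batches so the observers can record min/max statistics.
    for _ in range(4):
        model(*example_inputs)


# "per_channel" routes weight observation to PerChannelMinMaxObserver.
quant_config = StaticQuantConfig(w_granularity="per_channel")
q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
```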