Skip to content

Commit e976595

Browse files
authored
Support woq Autotune (#1921)
Signed-off-by: Kaihui-intel <[email protected]>
1 parent d56075c commit e976595

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

neural_compressor/torch/algorithms/weight_only/utility.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,11 @@ def __iter__(self):
11051105
if not args:
11061106
yield kwargs
11071107
elif not kwargs:
1108-
yield args
1108+
# case: tensor
1109+
if len(args) == 1:
1110+
yield args[0]
1111+
else:
1112+
yield args
11091113
else:
11101114
yield args, kwargs
11111115

neural_compressor/torch/quantization/config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ def __init__(
740740
minmax_lr: float = None,
741741
low_gpu_mem_usage: bool = True,
742742
iters: int = 200,
743-
seqlen: int = 2048,
743+
seqlen: int = 512,
744744
n_samples: int = 512,
745745
sampler: str = "rand",
746746
seed: int = 42,
@@ -1507,8 +1507,7 @@ def get_woq_tuning_config() -> list:
15071507
the list of WOQ quant config.
15081508
"""
15091509
RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32)
1510+
AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32)
15101511
GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32)
1511-
GPTQ_G32ASYM_DISABLE_LAST_LINEAR = GPTQConfig(use_sym=False).set_local("*.lm_head", GPTQConfig(dtype="fp32"))
1512-
GPTQ_G128ASYM = GPTQConfig(group_size=128, use_sym=False)
15131512
AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32)
1514-
return [RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_LINEAR, GPTQ_G128ASYM, AWQ_G32ASYM]
1513+
return [RTN_G32ASYM, AUTO_ROUND_CONFIG, GPTQ_G32ASYM, AWQ_G32ASYM]

test/3x/torch/quantization/weight_only/test_woq_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,16 @@ def test_captured_dataloader_iteration(self):
169169

170170
result = list(dataloader)
171171

172-
assert result == [(1,), (2,), (3,)]
172+
assert result == [1, 2, 3]
173+
174+
# Test case when kwargs is empty
175+
args_list = [(1, 2), (2, 3), (3, 4)]
176+
kwargs_list = [{}, {}, {}]
177+
dataloader = CapturedDataloader(args_list, kwargs_list)
178+
179+
result = list(dataloader)
180+
181+
assert result == [(1, 2), (2, 3), (3, 4)]
173182

174183
# Test case when both args and kwargs are present
175184
args_list = [(1,), (2,), (3,)]

0 commit comments

Comments
 (0)