
Commit 2e1cdc5

Support Autotune FP16 Mix-precision on torch 3.0 new API (#1793)
Signed-off-by: zehao-intel <[email protected]>
1 parent bacc164 commit 2e1cdc5

8 files changed: +316 −1 lines changed
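
In short, autotune on the 3.x torch API can now sweep FP16/BF16 mixed precision as tuning candidates. A minimal usage sketch based on the test added in this commit (the toy model and evaluation function below are placeholders, not part of the commit):

import torch

from neural_compressor.torch.quantization import MixPrecisionConfig, TuningConfig, autotune

# Placeholder model and accuracy function; real code would use its own model and validation loop.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

def eval_acc_fn(model):
    return 1.0  # a real eval_fn returns accuracy so the tuner can compare against the FP32 baseline

# The tuner evaluates the FP32 baseline, then tries fp16, bf16, and fp32 until the accuracy goal is met.
custom_tune_config = TuningConfig(config_set=[MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"])], max_trials=3)
best_model = autotune(model=model, tune_config=custom_tune_config, eval_fn=eval_acc_fn)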

neural_compressor/common/utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@
 TEQ = "teq" # pragma: no cover
 AUTOROUND = "autoround"
 FP8_QUANT = "fp8_quant"
+MIX_PRECISION = "mix_precision"

 # options
 import datetime

neural_compressor/torch/algorithms/mix_precision/__init__.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neural_compressor.torch.algorithms.mix_precision.half_precision_convert import HalfPrecisionConverter
+from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper

neural_compressor/torch/algorithms/mix_precision/half_precision_convert.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Half-precision Convert for Torch Modules."""
+
+from typing import Dict, Tuple
+
+import torch
+
+from neural_compressor.common import logger
+from neural_compressor.torch.algorithms.mix_precision.module_wrappers import HalfPrecisionModuleWrapper
+from neural_compressor.torch.utils import get_device
+
+
+class HalfPrecisionConverter:
+    """Converter Class for FP16 and BF16."""
+
+    dtype_mapping = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+    }
+
+    def __init__(self, configs_mapping: Dict[Tuple[str], object], *args, **kwargs):
+        """Initialize the Half-precision Converter with config.
+
+        Args:
+            configs_mapping (Dict): config class for mix-precision.
+        """
+        self.configs_mapping = configs_mapping
+        self.device = get_device()
+
+    def convert(self, model: torch.nn.Module):
+        """Convert to FP16 or BF16 model.
+
+        Args:
+            model (torch.nn.Module): the input model.
+
+        Returns:
+            mix_precision_model (torch.nn.Module): model with mix-precision.
+        """
+        if len(self.configs_mapping) > 0:
+            logger.info("Convert operators to half-precision")
+
+        if next(model.parameters()).is_cuda:
+            self.device = "cuda"
+        elif next(model.parameters()).is_cpu:
+            self.device = "cpu"
+
+        mix_precision_model = self._wrap_half_precision_model(model)
+        mix_precision_model.to(self.device)
+
+        return mix_precision_model
+
+    def _wrap_half_precision_model(self, model: torch.nn.Module, prefix=""):
+        """Wrap and replace half-precision target modules.
+
+        Args:
+            model (torch.nn.Module): the input module.
+            prefix (str): the name prefix for named children.
+
+        Returns:
+            model (torch.nn.Module): the model whose target modules have been wrapped.
+        """
+        for name, child in model.named_children():
+            op_name = prefix + "." + name if prefix != "" else name
+            for op_info, config in self.configs_mapping.items():
+                if op_name == op_info[0] and config.dtype in ("fp16", "bf16"):
+                    child = HalfPrecisionModuleWrapper(
+                        module=child, device=self.device, dtype=self.dtype_mapping[config.dtype]
+                    )
+            else:
+                self._wrap_half_precision_model(child, op_name)
+            setattr(model, name, child)
+
+        return model
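
As a quick illustration of how this converter is driven, a sketch with a hand-built configs_mapping (the (op_name, op_type) key layout and the SimpleNamespace config are stand-ins for what the quantization entry normally passes in):

import types

import torch

from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter

# Hypothetical stand-in for the (op_name, op_type) -> config mapping built by the framework;
# the converter only inspects op_info[0] (the op name) and config.dtype.
configs_mapping = {("fc", "Linear"): types.SimpleNamespace(dtype="bf16")}

model = torch.nn.Sequential()
model.add_module("fc", torch.nn.Linear(4, 4))

converter = HalfPrecisionConverter(configs_mapping)
mixed_model = converter.convert(model)  # mixed_model.fc is now a HalfPrecisionModuleWrapper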

neural_compressor/torch/algorithms/mix_precision/module_wrappers.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Half-precision Wrapper for Torch Modules."""
+
+import torch
+
+
+class HalfPrecisionModuleWrapper(torch.nn.Module):
+    """FP16 or BF16 Module Wrapper Class."""
+
+    def __init__(self, module, device="cpu", dtype=torch.float16):
+        """Init a HalfPrecisionModuleWrapper object."""
+        super(HalfPrecisionModuleWrapper, self).__init__()
+        self.add_module("module", module)
+        self.device = device
+        self.dtype = dtype
+        self.weight = self.module.weight if hasattr(self.module, "weight") else None
+        self.bias = self.module.bias if hasattr(self.module, "bias") else None
+
+    def forward(self, X):
+        """Convert dtype."""
+        with torch.autocast(device_type=self.device, dtype=self.dtype):
+            X = self.module(X)
+        return X.float()
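
The wrapper runs the original module under torch.autocast and casts the result back to FP32, so wrapped and unwrapped layers can be mixed freely. A small sketch of the effect in isolation (using BF16 on CPU):

import torch

from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionModuleWrapper

layer = torch.nn.Linear(4, 4)
wrapped = HalfPrecisionModuleWrapper(layer, device="cpu", dtype=torch.bfloat16)

out = wrapped(torch.randn(2, 4))  # the matmul runs in bf16 under autocast
print(out.dtype)                  # torch.float32 -- the output is cast back for downstream FP32 ops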

neural_compressor/torch/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@
     FP8Config,
     get_default_fp8_config,
     get_default_fp8_config_set,
+    MixPrecisionConfig,
+    get_default_mix_precision_config,
+    get_default_mix_precision_config_set,
     get_woq_tuning_config,
     DynamicQuantConfig,
     get_default_dynamic_config,

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 16 additions & 0 deletions
@@ -24,6 +24,7 @@
     FP8_QUANT,
     GPTQ,
     HQQ,
+    MIX_PRECISION,
     RTN,
     SMOOTH_QUANT,
     STATIC_QUANT,
@@ -36,6 +37,7 @@
     FP8Config,
     GPTQConfig,
     HQQConfig,
+    MixPrecisionConfig,
     RTNConfig,
     SmoothQuantConfig,
     StaticQuantConfig,
@@ -528,3 +530,17 @@ def fp8_quant_entry(
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     return model
+
+
+###################### Mixed Precision Algo Entry ##################################
+@register_algo(MIX_PRECISION)
+def mix_precision_entry(
+    model: torch.nn.Module, configs_mapping: Dict[Tuple[str], MixPrecisionConfig], *args, **kwargs
+) -> torch.nn.Module:
+    # only support fp16 and bf16 now, more types might be added later
+    from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionConverter
+
+    half_precision_converter = HalfPrecisionConverter(configs_mapping, *args, **kwargs)
+    mix_precision_model = half_precision_converter.convert(model)
+
+    return mix_precision_model
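
During tuning, the framework resolves MixPrecisionConfig into an (op_name, op_type) -> config mapping and dispatches to this entry through register_algo. A rough sketch of calling it directly with a hand-built mapping, purely to illustrate the wiring (not the normal user path):

import torch

from neural_compressor.torch.quantization import MixPrecisionConfig
from neural_compressor.torch.quantization.algorithm_entry import mix_precision_entry

model = torch.nn.Sequential()
model.add_module("fc", torch.nn.Linear(4, 4))

# Hand-built mapping; in practice the tuner derives it from the config and the model info.
configs_mapping = {("fc", "Linear"): MixPrecisionConfig(dtype="fp16")}
mixed_model = mix_precision_entry(model, configs_mapping)  # fc becomes a HalfPrecisionModuleWrapper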

neural_compressor/torch/quantization/config.py

Lines changed: 76 additions & 0 deletions
@@ -36,6 +36,7 @@
     FP8_QUANT,
     GPTQ,
     HQQ,
+    MIX_PRECISION,
     OP_NAME_OR_MODULE_TYPE,
     RTN,
     SMOOTH_QUANT,
@@ -1196,6 +1197,81 @@ def get_default_fp8_config_set() -> FP8Config:
     return FP8Config.get_config_set_for_tuning()


+######################## MixPrecision Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
+class MixPrecisionConfig(BaseConfig):
+    """Config class for mix-precision."""
+
+    name = MIX_PRECISION
+    supported_configs: List[OperatorConfig] = []
+    params_list = [
+        "dtype",
+    ]
+    supported_half_precision_ops = (
+        torch.nn.Linear,
+        torch.nn.Conv1d,
+        torch.nn.Conv2d,
+        torch.nn.Conv3d,
+    )
+
+    def __init__(
+        self,
+        dtype: Union[str, List[str]] = "fp16",
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init MixPrecision config.
+
+        Args:
+        """
+        super().__init__(white_list=white_list)
+        self.dtype = dtype
+        self._post_init()
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        mix_precision_config = MixPrecisionConfig(
+            dtype=["fp16", "bf16", "fp32"],
+        )
+        operators = cls.supported_half_precision_ops
+        supported_configs.append(OperatorConfig(config=mix_precision_config, operators=operators))
+        cls.supported_configs = supported_configs
+
+    @staticmethod
+    def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
+        white_list = tuple(MixPrecisionConfig.supported_half_precision_ops)
+        filter_result = []
+        for op_name, module in model.named_modules():
+            if isinstance(module, white_list):
+                pair = (op_name, type(module).__name__)
+                filter_result.append(pair)
+        logger.debug(f"Get model info: {filter_result}")
+        return filter_result
+
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "MixPrecisionConfig", List["MixPrecisionConfig"]]:
+        # TODO fwk owner needs to update it.
+        return MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"])
+
+
+def get_default_mix_precision_config() -> MixPrecisionConfig:
+    """Generate the default mix-precision config.
+
+    Returns:
+        the default mix-precision config.
+    """
+    return MixPrecisionConfig()
+
+
+def get_default_mix_precision_config_set() -> MixPrecisionConfig:
+    """Generate the default mix-precision config set.
+
+    Returns:
+        the default mix-precision config.
+    """
+    return MixPrecisionConfig.get_config_set_for_tuning()
+
+
 ##################### Algo Configs End ###################################
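
The new config surface in brief: dtype defaults to "fp16", a list of dtypes becomes the tuning sweep, and only Linear/Conv1d/Conv2d/Conv3d modules are considered. A short sketch:

from neural_compressor.torch.quantization import (
    MixPrecisionConfig,
    get_default_mix_precision_config,
    get_default_mix_precision_config_set,
)

# Single-dtype config: every supported op (Linear, Conv1d/2d/3d) is converted to bf16.
bf16_config = MixPrecisionConfig(dtype="bf16")

# Defaults: plain fp16, and a ["fp16", "bf16", "fp32"] sweep for autotune.
default_config = get_default_mix_precision_config()
tuning_config_set = get_default_mix_precision_config_set()
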
test/3x/torch/test_autotune.py

Lines changed: 75 additions & 1 deletion
@@ -7,7 +7,13 @@
 import transformers

 from neural_compressor.common import logger
-from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune, get_all_config_set
+from neural_compressor.torch.quantization import (
+    MixPrecisionConfig,
+    RTNConfig,
+    TuningConfig,
+    autotune,
+    get_all_config_set,
+)
 from neural_compressor.torch.utils import constants

 FAKE_DOUBLE_QUANT_CONFIGS = {
@@ -332,6 +338,74 @@ def eval_acc_fn(model):
         )
         self.assertIsNone(best_model)

+    @reset_tuning_target
+    def test_autotune_mix_precision_default(self):
+        from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionModuleWrapper
+
+        baseline = [1]
+        acc_res_lst = baseline + [0.9, 0.99, 1]
+
+        def eval_acc_fn(model):
+            res = acc_res_lst.pop(0)
+            return res
+
+        custom_tune_config = TuningConfig(config_set=[MixPrecisionConfig(dtype=["fp16", "bf16", "fp32"])], max_trials=3)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+
+        self.assertIsNotNone(best_model)
+        self.assertTrue(isinstance(best_model.fc1, HalfPrecisionModuleWrapper))
+        self.assertTrue(isinstance(best_model.fc2, HalfPrecisionModuleWrapper))
+        self.assertTrue(isinstance(best_model.fc3, HalfPrecisionModuleWrapper))
+
+    @reset_tuning_target
+    def test_autotune_mix_precision_set_op_name(self):
+        from neural_compressor.common.base_config import ComposableConfig, config_registry
+        from neural_compressor.torch.algorithms.mix_precision import HalfPrecisionModuleWrapper
+
+        baseline = [1]
+        acc_res_lst = baseline + [0.9, 1.1]
+
+        def eval_acc_fn(model):
+            res = acc_res_lst.pop(0)
+            return res
+
+        config1 = {
+            "mix_precision": {
+                "global": {
+                    "dtype": "bf16",
+                },
+                "local": {
+                    "fc2": {
+                        "dtype": "fp32",
+                    }
+                },
+            }
+        }
+        config2 = {
+            "mix_precision": {
+                "global": {
+                    "dtype": "fp16",
+                },
+                "local": {
+                    "fc1": {
+                        "dtype": "fp32",
+                    }
+                },
+            }
+        }
+
+        registered_configs = config_registry.get_cls_configs()
+        config1 = ComposableConfig.from_dict(config1, config_registry=registered_configs["torch"])
+        config2 = ComposableConfig.from_dict(config2, config_registry=registered_configs["torch"])
+
+        custom_tune_config = TuningConfig(config_set=[config1, config2], max_trials=2)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+
+        self.assertIsNotNone(best_model)
+        self.assertTrue(isinstance(best_model.fc1, torch.nn.Linear))
+        self.assertTrue(isinstance(best_model.fc2, HalfPrecisionModuleWrapper))
+        self.assertTrue(isinstance(best_model.fc3, HalfPrecisionModuleWrapper))
+

 if __name__ == "__main__":
     unittest.main()
