Skip to content

Commit 941fed3

Browse files
authored
Rename RTNWeightOnlyConfig to RTNConfig (#1551)
* Rename RTNWeightOnlyConfig to RTNConfig Signed-off-by: xin3he <[email protected]>
1 parent c565e96 commit 941fed3

File tree

13 files changed

+89
-97
lines changed

13 files changed

+89
-97
lines changed

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def get_user_model():
230230

231231
# 3.x api
232232
if args.approach == 'weight_only':
233-
from neural_compressor.torch import RTNWeightQuantConfig, GPTQConfig, quantize
233+
from neural_compressor.torch import RTNConfig, GPTQConfig, quantize
234234
from neural_compressor.torch.utils.utility import get_double_quant_config
235235
weight_sym = True if args.woq_scheme == "sym" else False
236236
double_quant_config_dict = get_double_quant_config(args.double_quant_type, weight_sym=weight_sym)
@@ -243,9 +243,9 @@ def get_user_model():
243243
"enable_mse_search": args.woq_enable_mse_search,
244244
}
245245
)
246-
quant_config = RTNWeightQuantConfig.from_dict(double_quant_config_dict)
246+
quant_config = RTNConfig.from_dict(double_quant_config_dict)
247247
else:
248-
quant_config = RTNWeightQuantConfig(
248+
quant_config = RTNConfig(
249249
weight_dtype=args.woq_dtype,
250250
weight_bits=args.woq_bits,
251251
weight_group_size=args.woq_group_size,
@@ -257,7 +257,7 @@ def get_user_model():
257257
double_quant_sym=args.double_quant_sym,
258258
double_quant_group_size=args.double_quant_group_size,
259259
)
260-
quant_config.set_local("lm_head", RTNWeightQuantConfig(weight_dtype="fp32"))
260+
quant_config.set_local("lm_head", RTNConfig(weight_dtype="fp32"))
261261
user_model = quantize(
262262
model=user_model, quant_config=quant_config
263263
)

neural_compressor/common/utility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
# config name
2828
BASE_CONFIG = "base_config"
2929
COMPOSABLE_CONFIG = "composable_config"
30-
RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
30+
RTN = "rtn"
3131
STATIC_QUANT = "static_quant"
3232
GPTQ = "gptq"
3333
FP8_QUANT = "fp8_quant"

neural_compressor/tensorflow/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def register_algo(name):
3535
3636
Usage example:
3737
@register_algo(name=example_algo)
38-
def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
38+
def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
3939
...
4040
Args:
4141
name (str): The name under which the algorithm function will be registered.

neural_compressor/torch/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717

1818
from neural_compressor.torch.quantization import (
1919
quantize,
20-
RTNWeightQuantConfig,
20+
RTNConfig,
2121
get_default_rtn_config,
2222
GPTQConfig,
2323
get_default_gptq_config,
2424
)
2525

2626
from neural_compressor.common.base_tuning import TuningConfig
27-
from neural_compressor.torch.autotune import autotune, get_default_tune_config
27+
from neural_compressor.torch.quantization.autotune import autotune, get_default_tune_config

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,10 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"):
580580
return int_weight
581581

582582

583-
from neural_compressor.torch.quantization.config import RTNWeightQuantConfig
583+
from neural_compressor.torch.quantization.config import RTNConfig
584584

585585

586-
def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
586+
def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
587587
# TODO (Yi) remove it
588588
enable_full_range = quant_config.enable_full_range
589589
enable_mse_search = quant_config.enable_mse_search

neural_compressor/torch/algorithms/weight_only_algos.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@
1818
import torch
1919

2020
from neural_compressor.common.logger import Logger
21-
from neural_compressor.common.utility import GPTQ, RTN_WEIGHT_ONLY_QUANT
22-
from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig
21+
from neural_compressor.common.utility import GPTQ, RTN
22+
from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig
2323
from neural_compressor.torch.utils.utility import fetch_module, register_algo, set_module
2424

2525
logger = Logger().get_logger()
2626

2727

2828
###################### RTN Algo Entry ##################################
29-
@register_algo(name=RTN_WEIGHT_ONLY_QUANT)
29+
@register_algo(name=RTN)
3030
def rtn_quantize_entry(
31-
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNWeightQuantConfig], *args, **kwargs
31+
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNConfig], *args, **kwargs
3232
) -> torch.nn.Module:
3333
"""The main entry to apply rtn quantization."""
3434
from .weight_only.rtn import apply_rtn_on_single_module

neural_compressor/torch/quantization/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from neural_compressor.torch.quantization.quantize import quantize, quantize_dynamic
1616
from neural_compressor.torch.quantization.config import (
17-
RTNWeightQuantConfig,
17+
RTNConfig,
1818
get_default_rtn_config,
1919
GPTQConfig,
2020
get_default_gptq_config,

neural_compressor/torch/autotune.py renamed to neural_compressor/torch/quantization/autotune.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
2121
from neural_compressor.common.logger import Logger
2222
from neural_compressor.torch import quantize
23-
from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig
23+
from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig
2424

2525
logger = Logger().get_logger()
2626

@@ -33,7 +33,7 @@
3333

3434
def get_default_tune_config() -> TuningConfig:
3535
# TODO use the registered default tuning config in the next PR
36-
return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNWeightQuantConfig(weight_bits=[4, 8])])
36+
return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNConfig(weight_bits=[4, 8])])
3737

3838

3939
def autotune(

neural_compressor/torch/quantization/config.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,7 @@
2424
import torch
2525

2626
from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
27-
from neural_compressor.common.utility import (
28-
DEFAULT_WHITE_LIST,
29-
FP8_QUANT,
30-
GPTQ,
31-
OP_NAME_OR_MODULE_TYPE,
32-
RTN_WEIGHT_ONLY_QUANT,
33-
)
27+
from neural_compressor.common.utility import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN
3428
from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
3529
from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger
3630

@@ -60,8 +54,8 @@ class OperatorConfig(NamedTuple):
6054
######################## RTN Config ###############################
6155

6256

63-
@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT, priority=PRIORITY_RTN)
64-
class RTNWeightQuantConfig(BaseConfig):
57+
@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
58+
class RTNConfig(BaseConfig):
6559
"""Config class for round-to-nearest weight-only quantization."""
6660

6761
supported_configs: List[OperatorConfig] = []
@@ -80,7 +74,7 @@ class RTNWeightQuantConfig(BaseConfig):
8074
"double_quant_sym",
8175
"double_quant_group_size",
8276
]
83-
name = RTN_WEIGHT_ONLY_QUANT
77+
name = RTN
8478

8579
def __init__(
8680
self,
@@ -137,12 +131,12 @@ def to_dict(self):
137131

138132
@classmethod
139133
def from_dict(cls, config_dict):
140-
return super(RTNWeightQuantConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)
134+
return super(RTNConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)
141135

142136
@classmethod
143137
def register_supported_configs(cls) -> List[OperatorConfig]:
144138
supported_configs = []
145-
linear_rtn_config = RTNWeightQuantConfig(
139+
linear_rtn_config = RTNConfig(
146140
weight_dtype=["int", "int8", "int4", "nf4", "fp4", "fp4_e2m1_bnb", "fp4_e2m1"],
147141
weight_bits=[4, 1, 2, 3, 5, 6, 7, 8],
148142
weight_group_size=[32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024],
@@ -173,16 +167,16 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
173167

174168

175169
# TODO(Yi) run `register_supported_configs` for all registered config.
176-
RTNWeightQuantConfig.register_supported_configs()
170+
RTNConfig.register_supported_configs()
177171

178172

179-
def get_default_rtn_config() -> RTNWeightQuantConfig:
173+
def get_default_rtn_config() -> RTNConfig:
180174
"""Generate the default rtn config.
181175
182176
Returns:
183177
the default rtn config.
184178
"""
185-
return RTNWeightQuantConfig()
179+
return RTNConfig()
186180

187181

188182
######################## GPTQ Config ###############################

neural_compressor/torch/utils/utility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def register_algo(name):
3333
3434
Usage example:
3535
@register_algo(name=example_algo)
36-
def example_algo(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module:
36+
def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module:
3737
...
3838
3939
Args:

0 commit comments

Comments
 (0)