
Commit 211b6bc

adding default inductor config settings (#423)
* adding default inductor config settings

  Summary: make the autoquant and quantize APIs call a new recommended_inductor_config_setter util that applies the recommended inductor config settings; also update groupsize -> group_size in generate.py.

  Test Plan: sh benchmarks.sh

  Comparison of different config combinations for matmul precision, mixed_mm and coordinate_descent:

  tok/s=  9.14, mem/s=  60.55 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB  quant: int8dq, mod: Llama-2-7b-chat-hf
  tok/s=147.02, mem/s= 973.53 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB  quant: int8wo, mod: Llama-2-7b-chat-hf
  tok/s=  9.23, mem/s=  61.11 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB  quant: int8dq, mod: Llama-2-7b-chat-hf
  tok/s=139.59, mem/s= 924.33 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB  quant: int8wo, mod: Llama-2-7b-chat-hf
  tok/s=  9.10, mem/s=  60.26 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB  quant: int8dq, mod: Llama-2-7b-chat-hf
  tok/s=146.98, mem/s= 973.23 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB  quant: int8wo, mod: Llama-2-7b-chat-hf
  tok/s=  9.28, mem/s=  61.48 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB  quant: int8dq, mod: Llama-2-7b-chat-hf
  tok/s=146.90, mem/s= 972.73 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB  quant: int8wo, mod: Llama-2-7b-chat-hf
  tok/s=  9.08, mem/s=  60.09 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB  quant: int8dq, mod: Llama-2-7b-chat-hf
  tok/s=137.58, mem/s= 911.00 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB  quant: int8wo, mod: Llama-2-7b-chat-hf
  tok/s=  9.19, mem/s=  60.87 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB  quant: int8dq, mod: Llama-2-7b-chat-hf
  tok/s=166.02, mem/s=1099.30 GB/s, peak_mem= 8.97 GB, model_size= 6.62 GB  quant: int8wo, mod: Llama-2-7b-chat-hf

* fixing tests
* fix weight only failures
* fixing new broken test
* fixing autoquant test
* testing if inductor config is the issue
* are inductor configs somehow being set?
* when is coordinate descent tuning being enabled?
* reset inductor config for tests
* more test fixes
* adding warning
* handling of errors
* option to suppress autoquant errors
1 parent 505edc1 commit 211b6bc
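In practice, the new default behavior described by this commit looks roughly like the sketch below. This is a minimal illustration rather than code from the commit itself; the toy model and the import path for `quantize`/`int8_weight_only` are assumptions.

```python
import torch
from torchao.quantization import quantize, int8_weight_only  # import path assumed

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)

# New default: quantize() (and torchao.autoquant()) also call
# torchao.quantization.utils.recommended_inductor_config_setter() before quantizing.
quantize(model, int8_weight_only())

# Opt out and keep whatever inductor config you have already set yourself:
# quantize(model, int8_weight_only(), set_inductor_config=False)

model = torch.compile(model, mode='max-autotune')
```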

File tree

7 files changed: +92 -36 lines changed


test/integration/test_integration.py

Lines changed: 27 additions & 10 deletions
@@ -98,21 +98,21 @@

 def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_weight_only())
+        quantize(mod, int8_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)

 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_dynamic_activation_int8_weight())
+        quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)

 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4_weight_only())
+        quantize(mod, int4_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)

@@ -124,6 +124,13 @@ def _int4wo_api(mod):
     _int4wo_api,
 ]

+def undo_recommended_configs():
+    torch._inductor.config.coordinate_descent_tuning = False
+    torch._inductor.config.coordinate_descent_check_all_directions = False
+    torch._inductor.config.force_fuse_int_mm_with_mul = False
+    torch._inductor.config.fx_graph_cache = False
+    torch._inductor.config.triton.unique_kernel_names = False
+    torch.set_float32_matmul_precision("highest")

 def combine_parameters(a, b):
     new_tuples = []

@@ -689,6 +696,7 @@ def test_int8_dynamic_quant_subclass(self, device, dtype):

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     def test_int8_weight_only_quant_subclass(self, device, dtype):
+        undo_recommended_configs()
         self._test_lin_weight_subclass_impl(
             Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype
         )

@@ -794,6 +802,7 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
+        undo_recommended_configs()
         self._test_lin_weight_subclass_api_impl(
             _int8wo_api, device, 40, test_dtype=dtype
         )

@@ -879,6 +888,7 @@ def test_weight_only_quant(self):
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_weight_only_quant_force_mixed_mm(self, device, dtype):
+        undo_recommended_configs()
         if device != "cuda":
             self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
         if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):

@@ -907,6 +917,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_weight_only_quant_use_mixed_mm(self, device, dtype):
+        undo_recommended_configs()
         if device != "cuda":
             self.skipTest(f"weight_only_quant_force_mixed_mm can't be constructed on {device}")
         if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):

@@ -1004,6 +1015,7 @@ def test_save_load_dqtensors(self, device, dtype):
     @torch.no_grad()
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_save_load_int8woqtensors(self, device, dtype):
+        undo_recommended_configs()
         self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)

     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -1153,6 +1165,7 @@ class TestAutoQuant(unittest.TestCase):
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_one_input(self, device, dtype, m, k, n):
+        undo_recommended_configs()
         print("(m, k, n): ", (m, k, n))
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")

@@ -1173,7 +1186,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
             torch.nn.ReLU(),
         ).to(device).to(dtype)
         out = model(example_input)
-        torchao.autoquant(model)
+        torchao.autoquant(model, set_inductor_config=False)
         out2 = model(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)

@@ -1186,6 +1199,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
+        undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):

@@ -1202,7 +1216,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
         out = model(example_input)

-        mod = torchao.autoquant(torch.compile(model), manual=True)
+        mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False)
         mod(example_input)
         mod(example_input2)
         mod.finalize_autoquant()

@@ -1214,6 +1228,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_manual(self, device, dtype):
+        undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):

@@ -1229,15 +1244,15 @@ def test_autoquant_manual(self, device, dtype):
         example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
         out = model(example_input)

-        mod = torchao.autoquant(torch.compile(model), manual=True)
+        mod = torchao.autoquant(torch.compile(model), manual=True, set_inductor_config=False)
         mod(example_input)
         mod(example_input2)
         mod.finalize_autoquant()
         out2 = mod(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)

-        mod2 = torchao.autoquant(model, manual=True)
+        mod2 = torchao.autoquant(model, manual=True, set_inductor_config=False)
         mod2(example_input)
         mod2(example_input2)
         mod2.finalize_autoquant()

@@ -1254,6 +1269,7 @@ def test_autoquant_manual(self, device, dtype):
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
+        undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):

@@ -1280,7 +1296,7 @@ def forward(self, x, y):
         }
         out = model(**example_input)

-        mod = torchao.autoquant(torch.compile(model))
+        mod = torchao.autoquant(torch.compile(model), set_inductor_config=False)
         mod(**example_input)

         out2 = mod(**example_input)

@@ -1293,6 +1309,7 @@ def forward(self, x, y):
     ]))
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
     def test_autoquant_double_access(self, device, dtype, m, k, n):
+        undo_recommended_configs()
         if device != "cuda" or not torch.cuda.is_available():
             self.skipTest(f"autoquant currently does not support {device}")
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):

@@ -1316,7 +1333,7 @@ def forward(self, x):
         x_in = torch.randn(m, k, device=device, dtype=dtype)
         model = DoubleAccess().to(device).to(dtype)
         model(x_in)
-        torchao.autoquant(model)
+        torchao.autoquant(model, set_inductor_config=False)
         assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight)
         model(x_in)

@@ -1443,7 +1460,7 @@ def test_get_model_size_autoquant(self, device, dtype):
         qtensor_class_list = (
             AQWeightOnlyQuantizedLinearWeight2,
         )
-        mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list)
+        mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list, set_inductor_config=False)
         mod(example_input)
         size2 = torchao.utils.get_model_size_in_bytes(mod)
         self.assertTrue(size2 < size)
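The `undo_recommended_configs` helper added above resets the flags that the recommended setter touches, so these tests run with a clean global inductor state. Read in reverse, it also gives a rough picture of what `recommended_inductor_config_setter` presumably enables; the sketch below is inferred from that helper, not taken from the util's source, and the exact values are assumptions.

```python
import torch

# Inferred forward direction of the flags undo_recommended_configs() resets above;
# the values actually used by recommended_inductor_config_setter are assumptions here.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.triton.unique_kernel_names = True
torch.set_float32_matmul_precision("high")  # the tests reset this to "highest"
```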

torchao/_models/llama/eval.py

Lines changed: 3 additions & 3 deletions
@@ -23,9 +23,6 @@
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
 from torchao._models.llama.model import prepare_inputs_for_model

-torch._inductor.config.fx_graph_cache = True
-torch._inductor.config.force_fuse_int_mm_with_mul = True
-
 def run_evaluation(
     checkpoint_path: Path,
     tasks: List[str],

@@ -41,6 +38,9 @@ def run_evaluation(
     pad_calibration_inputs: Optional[bool] = False,
 ):
     """Runs the evaluation of a model using LM Eval."""
+
+    torchao.quantization.utils.recommended_inductor_config_setter()
+
     assert checkpoint_path.is_file(), checkpoint_path
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)

torchao/_models/llama/generate.py

Lines changed: 6 additions & 10 deletions
@@ -22,13 +22,6 @@ def device_sync(device):
     else:
         print(f"device={device} is not yet suppported")

-
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
-torch._inductor.config.force_fuse_int_mm_with_mul = True
-# torch._inductor.config.use_mixed_mm = True
-
 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

 # support running without installing as a package

@@ -163,6 +156,9 @@ def main(
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
+
+    torchao.quantization.utils.recommended_inductor_config_setter()
+
     assert checkpoint_path.is_file(), checkpoint_path
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)

@@ -203,7 +199,7 @@ def main(
     if "int4wo" in quantization:
         groupsize=int(quantization.split("-")[-1])
         assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-        quantize(model, int4_weight_only(groupsize=groupsize))
+        quantize(model, int4_weight_only(group_size=groupsize))
     if "autoquant" == quantization:
         model = autoquant(model, manual=True)

@@ -339,8 +335,8 @@ def callback(x):
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
-    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument("--quantization", type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
+    parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
+    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
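For reference, the `int4wo-<groupsize>` value accepted by `-q`/`--quantization` is parsed exactly as in the hunk above; below is a standalone sketch of that parsing (the example string is illustrative, and the quantize call is commented out because it needs a loaded model):

```python
quantization = "int4wo-128"  # example value passed via -q / --quantization
groupsize = int(quantization.split("-")[-1])
assert groupsize in [32, 64, 128, 256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
# quantize(model, int4_weight_only(group_size=groupsize))  # note the keyword is group_size, not groupsize
```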

torchao/quantization/README.md

Lines changed: 3 additions & 7 deletions
@@ -30,10 +30,6 @@ of the activations that the different linear layers see, it then benchmarks thes
 import torch
 import torchao

-# inductor settings which improve torch.compile performance for quantized modules
-torch._inductor.config.force_fuse_int_mm_with_mul = True
-torch._inductor.config.use_mixed_mm = True
-
 # Plug in your model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

@@ -107,9 +103,6 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 group_size = 32
 m = quantize(m, int4_weight_only(group_size=group_size))

-torch._inductor.config.force_fuse_int_mm_with_mul = True
-torch._inductor.config.use_mixed_mm = True
-
 # temporary workaround for tensor subclass + torch.compile
 from torchao.quantization.utils import unwrap_tensor_subclass
 m = unwrap_tensor_subclass(m)

@@ -163,6 +156,9 @@ m = torch.export.export(m_unwrapped, example_inputs).module()
 torch._export.aot_compile(m_unwrapped, example_inputs)
 ```

+### Automatic Inductor Configuration
+The `quantize` and `autoquant` APIs now automatically apply our recommended inductor configuration settings. You can replicate the same settings in your own experiments by calling `torchao.quantization.utils.recommended_inductor_config_setter`. If you would rather keep your own configuration, pass the keyword argument `set_inductor_config=False` to `quantize` or `autoquant` to prevent those settings from being assigned. You can also overwrite the settings after they are assigned, as long as you do so before passing any inputs to the torch.compiled model. Older flows that manually set a collection of inductor configurations are therefore outdated, though continuing to set those same configurations by hand is unlikely to cause any issues.
+
 ### Other Available Quantization Techniques
 #### A8W8 Dynamic Quantization

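To mirror the README addition above outside of `quantize`/`autoquant`, the same recommended settings can be applied directly before compiling a model. This is a minimal sketch; it assumes `import torchao` exposes `torchao.quantization.utils`, as the eval and generate scripts above rely on.

```python
import torch
import torchao

# Apply the same recommended inductor settings that quantize()/autoquant() now set by default.
torchao.quantization.utils.recommended_inductor_config_setter()

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
model = torch.compile(model, mode='max-autotune')
model(torch.randn(32, 32, dtype=torch.bfloat16, device='cuda'))
```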
torchao/quantization/autoquant.py

Lines changed: 30 additions & 5 deletions
@@ -1,4 +1,5 @@
 import torch
+import torchao
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,

@@ -90,7 +91,11 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
         with torch.no_grad():
             act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
             bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device)
-            res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
+            try:
+                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
+            except Exception as e:
+                print(f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}")
+                res = torch.inf
             update_cache(q_cls, shapes_and_dtype, res)

     @torch.no_grad()

@@ -407,16 +412,21 @@ def _change_linears_to_autoquantizable(model, **kwargs):
         filter_fn if filter_fn is not None else _is_linear,
     )

-def _change_autoquantizable_to_quantized(model, **kwargs):
+def _change_autoquantizable_to_quantized(model, supress_autoquant_errors=True, **kwargs):
     """
     Converts AutoQuantizableLinearWeight tensor subclasses
     to various quantized/non-quantized tensor subclasses depending
     on benchmark results. Expectation is that these modules are
     torch.compiled afterwards.
     """
-    hold = torch._dynamo.config.automatic_dynamic_shapes
+    hold_automatic_dynamic_shapes = torch._dynamo.config.automatic_dynamic_shapes
     torch._dynamo.config.automatic_dynamic_shapes = False

+    if supress_autoquant_errors:
+        hold_supress_errors = torch._dynamo.config.suppress_errors
+        torch._dynamo.config.suppress_errors = True
+        import logging
+        torch._logging.set_logs(inductor=logging.CRITICAL, dynamo=logging.CRITICAL)
     filter_fn = kwargs.pop(
         "filter_fn",
         lambda mod, *args:

@@ -432,7 +442,13 @@ def _change_autoquantizable_to_quantized(model, **kwargs):
         ),
         filter_fn,
     )
-    torch._dynamo.config.automatic_dynamic_shapes = hold
+    # undo dynamic shape change
+    torch._dynamo.config.automatic_dynamic_shapes = hold_automatic_dynamic_shapes
+
+    # undo error supression
+    if supress_autoquant_errors:
+        torch._dynamo.config.suppress_errors = hold_supress_errors
+        torch._logging.set_logs()
     torch._dynamo.reset()

 # TODO: example_input seems weird to include in the API

@@ -443,8 +459,11 @@ def autoquant(
     model,
     example_input=None,
     qtensor_class_list=DEFAULT_CLASS_LIST,
-    filter_fn=None, mode=["interpolate", .85],
+    filter_fn=None,
+    mode=["interpolate", .85],
     manual=False,
+    set_inductor_config=True,
+    supress_autoquant_errors=True,
     **aq_kwargs
 ):
     """

@@ -477,6 +496,8 @@ def autoquant(
             and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
         manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for
             the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged.
+        set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
+        supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True)
         **aq_kwargs: Additional keyword arguments for the autoquantization process.

     Returns:

@@ -493,6 +514,9 @@ def autoquant(
         model(*example_input2)
         model.finalize_autoquant()
     """
+    if set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+

     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight

@@ -539,6 +563,7 @@ def autoquant_prehook(module, args, kwargs):
     def finalize_autoquant():
         _change_autoquantizable_to_quantized(
             real_model,
+            supress_autoquant_errors,
             **aq_kwargs,
         )
         if hasattr(real_model, "old_forward"):
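A minimal usage sketch of the two new `autoquant` keyword arguments added in this file; the toy model and input are placeholders, and the defaults mentioned in the comments are the ones documented in the docstring above.

```python
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# Keep your own inductor config (default True) and let benchmarking errors surface
# instead of being suppressed (default True). Note the kwarg is spelled "supress"
# in this commit.
model = torchao.autoquant(
    torch.compile(model),
    set_inductor_config=False,
    supress_autoquant_errors=False,
)
model(example_input)  # first call triggers shape calibration and quantization (manual=False default)
```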
