Skip to content

Commit b51b0f4

Browse files
committed
Fixed issues in comments
1 parent 227deee commit b51b0f4

File tree

2 files changed

+25
-29
lines changed

2 files changed

+25
-29
lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def refit_module_weights(
207207
arg_inputs: Optional[Tuple[Any, ...]] = None,
208208
kwarg_inputs: Optional[dict[str, Any]] = None,
209209
verify_output: bool = False,
210-
fast_refit: bool = True,
210+
use_weight_map_cache: bool = True,
211211
in_place: bool = False,
212212
) -> torch.fx.GraphModule:
213213
"""
@@ -232,11 +232,11 @@ def refit_module_weights(
232232
inline_module = True
233233

234234
if not in_place:
235-
if inline_module:
236-
logger.warning(
237-
"Inplace has no effect on exported program. Please use the returned module as the updated module."
238-
)
239235
compiled_module = copy.deepcopy(compiled_module)
236+
elif inline_module:
237+
raise AssertionError(
238+
"Exported program does not support modifying in place. Please set inplace to false and use the returned graph module."
239+
)
240240

241241
# Get the settings and check the setting to be uniform
242242
settings: CompilationSettings = None
@@ -254,7 +254,7 @@ def refit_module_weights(
254254
]
255255
assert (
256256
encoded_metadata != ""
257-
), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True."
257+
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
258258
settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
259259
# Handle torch modules
260260
compiled_submodules_map = dict(compiled_submodules)
@@ -365,36 +365,32 @@ def refit_module_weights(
365365
engine = get_engine_from_encoded_engine(
366366
engine_info[ENGINE_IDX], runtime
367367
)
368-
if fast_refit:
368+
if use_weight_map_cache:
369369
encoded_metadata = compiled_submodule.__getstate__()[0][
370370
SERIALIZED_METADATA_IDX
371371
]
372-
assert (
373-
encoded_metadata != ""
374-
), "Metadata are not saved in the engine. Please recompile the engine with make_refitable=True."
375372
weight_name_map = TorchTensorRTModule.decode_metadata(
376373
encoded_metadata
377374
)["weight_name_map"]
378375
if not weight_name_map:
379-
fast_refit = False
376+
use_weight_map_cache = False
380377
logger.warning(
381378
"Fast refitting is not supported in this module. Use regular refitting."
382379
)
383380
else:
384381
compiled_submodule = getattr(compiled_module, name)
385382
weight_name_map = None
386-
if fast_refit:
383+
if use_weight_map_cache:
387384
try:
388385
weight_name_map = compiled_submodule.weight_name_map
389386
except AttributeError:
390-
fast_refit = False
391387
logger.warning(
392-
"You are using a old version of Torch-TensorRT. Please re-compile the engine to avoid failures."
388+
"The module was compiled wit an old version of Torch-TensorRT. Rebuilding the weight map."
393389
)
394390
if not weight_name_map:
395-
fast_refit = False
391+
use_weight_map_cache = False
396392
logger.warning(
397-
"Fast refitting is not supported in this module. Use regular refitting."
393+
"This engine does not have a weight map cache. Rebuilding the weight map"
398394
)
399395
if isinstance(compiled_submodule, PythonTorchTensorRTModule):
400396
engine = compiled_submodule.engine
@@ -443,7 +439,7 @@ def refit_module_weights(
443439
except AssertionError as e:
444440
# If fast_refit is used and failed, we fall back to regular refit
445441
logger.warning(e)
446-
if fast_refit and weight_name_map:
442+
if use_weight_map_cache and weight_name_map:
447443
_refit_single_trt_engine_with_gm(
448444
new_gm=new_submodule,
449445
old_engine=engine,

tests/py/dynamo/models/test_model_refit.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_fast_refit_one_engine():
109109
compiled_module=trt_gm,
110110
new_weight_module=exp_program2,
111111
inputs=inputs,
112-
fast_refit=True,
112+
use_weight_map_cache=True,
113113
)
114114

115115
# Check the output
@@ -156,7 +156,7 @@ def test_fast_refit_one_engine_no_map():
156156
compiled_module=trt_gm,
157157
new_weight_module=exp_program2,
158158
inputs=inputs,
159-
fast_refit=True,
159+
use_weight_map_cache=True,
160160
)
161161

162162
# Check the output
@@ -207,7 +207,7 @@ def test_fast_refit_one_engine_wrong_map():
207207
compiled_module=trt_gm,
208208
new_weight_module=exp_program2,
209209
inputs=inputs,
210-
fast_refit=True,
210+
use_weight_map_cache=True,
211211
)
212212

213213
# Check the output
@@ -254,7 +254,7 @@ def test_fast_refit_one_engine_bert():
254254
compiled_module=trt_gm,
255255
new_weight_module=exp_program2,
256256
inputs=inputs,
257-
fast_refit=True,
257+
use_weight_map_cache=True,
258258
)
259259

260260
# Check the output
@@ -304,7 +304,7 @@ def test_fast_refit_one_engine_inline_runtime():
304304
compiled_module=trt_gm,
305305
new_weight_module=exp_program2,
306306
inputs=inputs,
307-
fast_refit=True,
307+
use_weight_map_cache=True,
308308
)
309309

310310
# Check the output
@@ -349,7 +349,7 @@ def test_fast_refit_one_engine_python_runtime():
349349
compiled_module=trt_gm,
350350
new_weight_module=exp_program2,
351351
inputs=inputs,
352-
fast_refit=True,
352+
use_weight_map_cache=True,
353353
)
354354

355355
# Check the output
@@ -416,7 +416,7 @@ def forward(self, x):
416416
compiled_module=trt_gm,
417417
new_weight_module=exp_program2,
418418
inputs=inputs,
419-
fast_refit=True,
419+
use_weight_map_cache=True,
420420
)
421421

422422
# Check the output
@@ -461,7 +461,7 @@ def test_refit_one_engine():
461461
compiled_module=trt_gm,
462462
new_weight_module=exp_program2,
463463
inputs=inputs,
464-
fast_refit=False,
464+
use_weight_map_cache=False,
465465
)
466466

467467
# Check the output
@@ -508,7 +508,7 @@ def test_refit_one_engine_bert():
508508
compiled_module=trt_gm,
509509
new_weight_module=exp_program2,
510510
inputs=inputs,
511-
fast_refit=False,
511+
use_weight_map_cache=False,
512512
)
513513

514514
# Check the output
@@ -558,7 +558,7 @@ def test_refit_one_engine_inline_runtime():
558558
compiled_module=trt_gm,
559559
new_weight_module=exp_program2,
560560
inputs=inputs,
561-
fast_refit=False,
561+
use_weight_map_cache=False,
562562
)
563563

564564
# Check the output
@@ -603,7 +603,7 @@ def test_refit_one_engine_python_runtime():
603603
compiled_module=trt_gm,
604604
new_weight_module=exp_program2,
605605
inputs=inputs,
606-
fast_refit=False,
606+
use_weight_map_cache=False,
607607
)
608608

609609
# Check the output
@@ -670,7 +670,7 @@ def forward(self, x):
670670
compiled_module=trt_gm,
671671
new_weight_module=exp_program2,
672672
inputs=inputs,
673-
fast_refit=False,
673+
use_weight_map_cache=False,
674674
)
675675

676676
# Check the output

0 commit comments

Comments
 (0)