Skip to content

Commit 257ab4a

Browse files
committed
Fixed issues in comments
1 parent b1cc096 commit 257ab4a

File tree

2 files changed

+25
-29
lines changed

2 files changed

+25
-29
lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def refit_module_weights(
207207
new_weight_module: ExportedProgram,
208208
inputs: Tuple[Any, ...],
209209
verify_output: bool = False,
210-
fast_refit: bool = True,
210+
use_weight_map_cache: bool = True,
211211
in_place: bool = False,
212212
) -> torch.fx.GraphModule:
213213
"""
@@ -231,11 +231,11 @@ def refit_module_weights(
231231
inline_module = True
232232

233233
if not in_place:
234-
if inline_module:
235-
logger.warning(
236-
"Inplace has no effect on exported program. Please use the returned module as the updated module."
237-
)
238234
compiled_module = copy.deepcopy(compiled_module)
235+
elif inline_module:
236+
raise AssertionError(
237+
"Exported program does not support modifying in place. Please set inplace to false and use the returned graph module."
238+
)
239239

240240
# Get the settings and check the setting to be uniform
241241
settings: CompilationSettings = None
@@ -253,7 +253,7 @@ def refit_module_weights(
253253
]
254254
assert (
255255
encoded_metadata != ""
256-
), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True."
256+
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
257257
settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
258258
# Handle torch modules
259259
compiled_submodules_map = dict(compiled_submodules)
@@ -362,36 +362,32 @@ def refit_module_weights(
362362
engine = get_engine_from_encoded_engine(
363363
engine_info[ENGINE_IDX], runtime
364364
)
365-
if fast_refit:
365+
if use_weight_map_cache:
366366
encoded_metadata = compiled_submodule.__getstate__()[0][
367367
SERIALIZED_METADATA_IDX
368368
]
369-
assert (
370-
encoded_metadata != ""
371-
), "Metadata are not saved in the engine. Please recompile the engine with make_refitable=True."
372369
weight_name_map = TorchTensorRTModule.decode_metadata(
373370
encoded_metadata
374371
)["weight_name_map"]
375372
if not weight_name_map:
376-
fast_refit = False
373+
use_weight_map_cache = False
377374
logger.warning(
378375
"Fast refitting is not supported in this module. Use regular refitting."
379376
)
380377
else:
381378
compiled_submodule = getattr(compiled_module, name)
382379
weight_name_map = None
383-
if fast_refit:
380+
if use_weight_map_cache:
384381
try:
385382
weight_name_map = compiled_submodule.weight_name_map
386383
except AttributeError:
387-
fast_refit = False
388384
logger.warning(
389-
"You are using a old version of Torch-TensorRT. Please re-compile the engine to avoid failures."
385+
"The module was compiled wit an old version of Torch-TensorRT. Rebuilding the weight map."
390386
)
391387
if not weight_name_map:
392-
fast_refit = False
388+
use_weight_map_cache = False
393389
logger.warning(
394-
"Fast refitting is not supported in this module. Use regular refitting."
390+
"This engine does not have a weight map cache. Rebuilding the weight map"
395391
)
396392
if isinstance(compiled_submodule, PythonTorchTensorRTModule):
397393
engine = compiled_submodule.engine
@@ -440,7 +436,7 @@ def refit_module_weights(
440436
except AssertionError as e:
441437
# If fast_refit is used and failed, we fall back to regular refit
442438
logger.warning(e)
443-
if fast_refit and weight_name_map:
439+
if use_weight_map_cache and weight_name_map:
444440
_refit_single_trt_engine_with_gm(
445441
new_gm=new_submodule,
446442
old_engine=engine,

tests/py/dynamo/models/test_model_refit.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_fast_refit_one_engine():
109109
compiled_module=trt_gm,
110110
new_weight_module=exp_program2,
111111
inputs=inputs,
112-
fast_refit=True,
112+
use_weight_map_cache=True,
113113
)
114114

115115
# Check the output
@@ -156,7 +156,7 @@ def test_fast_refit_one_engine_no_map():
156156
compiled_module=trt_gm,
157157
new_weight_module=exp_program2,
158158
inputs=inputs,
159-
fast_refit=True,
159+
use_weight_map_cache=True,
160160
)
161161

162162
# Check the output
@@ -207,7 +207,7 @@ def test_fast_refit_one_engine_wrong_map():
207207
compiled_module=trt_gm,
208208
new_weight_module=exp_program2,
209209
inputs=inputs,
210-
fast_refit=True,
210+
use_weight_map_cache=True,
211211
)
212212

213213
# Check the output
@@ -254,7 +254,7 @@ def test_fast_refit_one_engine_bert():
254254
compiled_module=trt_gm,
255255
new_weight_module=exp_program2,
256256
inputs=inputs,
257-
fast_refit=True,
257+
use_weight_map_cache=True,
258258
)
259259

260260
# Check the output
@@ -304,7 +304,7 @@ def test_fast_refit_one_engine_inline_runtime():
304304
compiled_module=trt_gm,
305305
new_weight_module=exp_program2,
306306
inputs=inputs,
307-
fast_refit=True,
307+
use_weight_map_cache=True,
308308
)
309309

310310
# Check the output
@@ -349,7 +349,7 @@ def test_fast_refit_one_engine_python_runtime():
349349
compiled_module=trt_gm,
350350
new_weight_module=exp_program2,
351351
inputs=inputs,
352-
fast_refit=True,
352+
use_weight_map_cache=True,
353353
)
354354

355355
# Check the output
@@ -416,7 +416,7 @@ def forward(self, x):
416416
compiled_module=trt_gm,
417417
new_weight_module=exp_program2,
418418
inputs=inputs,
419-
fast_refit=True,
419+
use_weight_map_cache=True,
420420
)
421421

422422
# Check the output
@@ -461,7 +461,7 @@ def test_refit_one_engine():
461461
compiled_module=trt_gm,
462462
new_weight_module=exp_program2,
463463
inputs=inputs,
464-
fast_refit=False,
464+
use_weight_map_cache=False,
465465
)
466466

467467
# Check the output
@@ -508,7 +508,7 @@ def test_refit_one_engine_bert():
508508
compiled_module=trt_gm,
509509
new_weight_module=exp_program2,
510510
inputs=inputs,
511-
fast_refit=False,
511+
use_weight_map_cache=False,
512512
)
513513

514514
# Check the output
@@ -558,7 +558,7 @@ def test_refit_one_engine_inline_runtime():
558558
compiled_module=trt_gm,
559559
new_weight_module=exp_program2,
560560
inputs=inputs,
561-
fast_refit=False,
561+
use_weight_map_cache=False,
562562
)
563563

564564
# Check the output
@@ -603,7 +603,7 @@ def test_refit_one_engine_python_runtime():
603603
compiled_module=trt_gm,
604604
new_weight_module=exp_program2,
605605
inputs=inputs,
606-
fast_refit=False,
606+
use_weight_map_cache=False,
607607
)
608608

609609
# Check the output
@@ -670,7 +670,7 @@ def forward(self, x):
670670
compiled_module=trt_gm,
671671
new_weight_module=exp_program2,
672672
inputs=inputs,
673-
fast_refit=False,
673+
use_weight_map_cache=False,
674674
)
675675

676676
# Check the output

0 commit comments

Comments
 (0)