99import torch .utils ._pytree as pytree
1010from torch ._dynamo .utils import detect_fake_mode
1111from torch ._functorch .aot_autograd import _aot_export_function
12- from torch ._inductor .constant_folding import ConstantFolder , replace_node_with_constant
1312from torch ._ops import OpOverload
1413from torch_tensorrt .dynamo import CompilationSettings
1514from torch_tensorrt .dynamo .compile import compile_module
16- from torch_tensorrt .dynamo .lowering . _decompositions import get_decompositions
15+ from torch_tensorrt .dynamo .lowering import apply_lowering_passes , get_decompositions
1716from torch_tensorrt .dynamo .lowering ._pre_aot_lowering import pre_aot_substitutions
1817from torch_tensorrt .dynamo .utils import parse_dynamo_kwargs
1918
@@ -75,7 +74,7 @@ def _pretraced_backend(
7574 fake_mode , "allow_non_fake_inputs" , True
7675 ), fake_mode :
7776 # Invoke AOTAutograd to translate operators to aten
78- graph_module = aot_export_for_compile (
77+ gm = aot_export_for_compile (
7978 gm ,
8079 sample_inputs ,
8180 decompositions = get_decompositions (
@@ -85,10 +84,10 @@ def _pretraced_backend(
8584
8685 logger .debug ("Post-AOT Autograd graph:\n " + str (gm .graph ))
8786
88- constant_fold ( graph_module )
87+ gm = apply_lowering_passes ( gm )
8988
9089 trt_compiled = compile_module (
91- graph_module ,
90+ gm ,
9291 sample_inputs ,
9392 settings = settings ,
9493 )
@@ -112,35 +111,6 @@ def _pretraced_backend(
112111 raise
113112
114113
115- @torch .utils ._python_dispatch ._disable_current_modes () # type: ignore
116- def constant_fold (gm : torch .fx .GraphModule ) -> Any :
117- """Adapted from:
118- https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
119-
120- Folds constants in the graph module, not skipping constructors
121-
122- Modifies the graph in-place and replaces node with constants
123- """
124- cf = ConstantFolder (gm , skip_constructors = False )
125- cf .run ()
126-
127- for node , constant in cf .node_replacements .items ():
128- replace_node_with_constant (gm , node , constant )
129-
130- erased_params = []
131- for node in gm .graph .nodes :
132- if node .op == "get_attr" and len (node .users ) == 0 :
133- delattr (gm , node .target )
134- erased_params .append (node )
135-
136- for node in erased_params :
137- gm .graph .erase_node (node )
138-
139- gm .graph .eliminate_dead_code ()
140- gm .graph .lint ()
141- gm .recompile ()
142-
143-
144114def aot_export_for_compile (
145115 func : torch .fx .GraphModule ,
146116 args : Sequence [torch .Tensor ],
0 commit comments