@@ -7,27 +7,15 @@
 import torch
 import torch._dynamo as td
 import torch.utils._pytree as pytree
-import torch_tensorrt
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
 from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 
-from packaging import version
-
-# Modify import location of utilities based on Torch version
-if version.parse(torch_tensorrt.sanitized_torch_version()) < version.parse("2.1.1"):
-    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
-else:
-    from torch._inductor.constant_folding import (
-        ConstantFolder,
-        replace_node_with_constant,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -87,7 +75,7 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_for_compile(
+            gm = aot_export_for_compile(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -97,10 +85,10 @@ def _pretraced_backend(
 
             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-            constant_fold(graph_module)
+            gm = apply_lowering_passes(gm)
 
             trt_compiled = compile_module(
-                graph_module,
+                gm,
                 sample_inputs,
                 settings=settings,
             )
@@ -124,35 +112,6 @@ def _pretraced_backend(
             raise
 
 
-@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> Any:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
-
-    Folds constants in the graph module, not skipping constructors
-
-    Modifies the graph in-place and replaces node with constants
-    """
-    cf = ConstantFolder(gm, skip_constructors=False)
-    cf.run()
-
-    for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
-
-    erased_params = []
-    for node in gm.graph.nodes:
-        if node.op == "get_attr" and len(node.users) == 0:
-            delattr(gm, node.target)
-            erased_params.append(node)
-
-    for node in erased_params:
-        gm.graph.erase_node(node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-
-
 def aot_export_for_compile(
     func: torch.fx.GraphModule,
     args: Sequence[torch.Tensor],
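
For context on the code path this commit touches: _pretraced_backend runs whenever this module serves as a torch.compile backend, and after this change it routes the AOT-exported graph through apply_lowering_passes (which subsumes the removed local constant_fold helper) before compile_module builds the TensorRT engine. A minimal usage sketch follows; it assumes a CUDA device and that importing torch_tensorrt registers the "torch_tensorrt" backend as this file does, and the model and shapes are illustrative only:

    import torch
    import torch_tensorrt  # noqa: F401  # import registers the "torch_tensorrt" dynamo backend

    model = torch.nn.Linear(64, 32).eval().cuda()
    inputs = torch.randn(8, 64).cuda()

    # torch.compile hands each traced FX graph to the backend's
    # _pretraced_backend, which now calls apply_lowering_passes
    # (constant folding included) instead of the removed constant_fold.
    compiled = torch.compile(model, backend="torch_tensorrt")
    out = compiled(inputs)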