10 10  import tensorrt as trt
11 11  import torch
12 12  import torch.fx
13     -from torch._ops import OpOverload
14 13  from torch.fx.node import _get_qualified_name
15 14  from torch.fx.passes.shape_prop import TensorMetadata
16 15
17     -from torch_tensorrt.fx import CONVERTERS
   16  +from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
18 17  from torch_tensorrt import Input
19 18  from torch_tensorrt.fx.observer import Observer
20 19  from torch_tensorrt.fx.utils import (
@@ -69,6 +68,7 @@ def __init__(
69 68          self.input_specs = input_specs
70 69          self.input_specs_iter = 0
71 70          self._cur_node_name: Optional[str] = None
   71  +       self._cur_node: Optional[torch.fx.Node] = None
72 72          self._input_names: List[str] = []
73 73          self._output_names: List[str] = []
74 74          self._itensor_to_tensor_meta: Dict[
@@ -82,14 +82,14 @@ def validate_conversion(self):
82 82          missing_converter = set()
83 83
84 84          for node in self.module.graph.nodes:
85     -           if node.op == "call_function" and not CONVERTERS.get(node.target):
   85  +           if node.op == "call_function" and not CONVERTERS.get(node):
86 86                  missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
87     -           elif node.op == "call_method" and not CONVERTERS.get(node.target):
   87  +           elif node.op == "call_method" and not CONVERTERS.get(node):
88 88                  missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
89 89              elif node.op == "call_module":
90 90                  submod = self.fetch_attr(node.target)
91 91                  submod_type = getattr(submod, "_base_class_origin", type(submod))
92     -               if not CONVERTERS.get(submod_type):
   92  +               if not CONVERTERS.get(node):
93 93                      missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
94 94
95 95          return missing_converter
@@ -226,6 +226,7 @@ def run(
226 226
227 227      def run_node(self, n):
228 228          self._cur_node_name = str(n)
    229  +       self._cur_node = n
229 230          # add "_itensor_to_tensor_meta"
230 231          kwargs = dict(n.kwargs)
231 232          kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
@@ -276,7 +277,7 @@ def call_module(self, target, args, kwargs):
276 277          assert isinstance(target, str)
277 278          submod = self.fetch_attr(target)
278 279          submod_type = getattr(submod, "_base_class_origin", type(submod))
279      -      converter = CONVERTERS.get(submod_type)
    280  +      converter = CONVERTERS.get(self._cur_node)
280 281
281 282          if not converter:
282 283              raise RuntimeError(
@@ -287,7 +288,7 @@ def call_module(self, target, args, kwargs):
287 288          return converter(self.network, submod, args, kwargs, self._cur_node_name)
288 289
289 290      def call_function(self, target, args, kwargs):
290      -      converter = CONVERTERS.get(target)
    291  +      converter = CONVERTERS.get(self._cur_node)
291 292          if not converter:
292 293              raise RuntimeError(
293 294              f"Conversion of function {torch.typename(target)} not currently supported!"
@@ -298,7 +299,7 @@ def call_function(self, target, args, kwargs):
298 299
299 300      def call_method(self, target, args, kwargs):
300 301          assert isinstance(target, str)
301      -      converter = CONVERTERS.get(target)
    302  +      converter = CONVERTERS.get(self._cur_node)
302 303
303 304          if not converter:
304 305              raise RuntimeError(
0 commit comments