11from __future__ import annotations
22
3+ import functools
34import logging
45from dataclasses import dataclass , field
56from enum import Enum , auto
1718 cast ,
1819)
1920
21+ import tensorrt as trt
22+ import torch
23+ from torch import SymBool , SymFloat , SymInt
2024from torch ._ops import OpOverloadPacket
2125from torch .fx .node import Argument , Node , Target , _get_qualified_name
2226from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
2327from torch_tensorrt .fx .converter_registry import CONVERTERS as FX_CONVERTERS
2428
25- import tensorrt as trt
26-
2729logger = logging .getLogger (__name__ )
2830
2931LegacyConverterImplSignature = Callable [
@@ -76,22 +78,119 @@ class ConverterSupport:
7678 capability_validator: Function which takes in a Node and returns a bool indicating
7779 whether that node can be supported by its companion converter. Note that
7880 this function must not modify the node or its graph
81+ supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs.
7982 """
8083
8184 converter_implementation : ConverterImplSignature
8285 capability_validator : Callable [[Node ], bool ] = field (default = lambda node : True )
86+ supports_dynamic_shapes : bool = False
8387
8488
8589# Dictionary representing Dynamo aten-only converters
8690# Each converter maps to a sequence of at least one ConverterSupport object(s)
8791DYNAMO_ATEN_CONVERTERS : Dict [Target , Sequence [ConverterSupport ]] = {}
8892
8993
94+ def has_static_shapes (node : torch .fx .Node ) -> bool :
95+ """Returns True if a node has static args, kwargs, or outputs"""
96+ return not _has_dynamic_shapes (node = node )
97+
98+
99+ def has_dynamic_shapes (node : torch .fx .Node ) -> bool :
100+ """Returns True if a node has dynamic args, kwargs, or outputs"""
101+ return _has_dynamic_shapes (node = node )
102+
103+
104+ def has_dynamic_shapes_in_args (
105+ arg_positions_to_check : Optional [List [int ]] = None ,
106+ ) -> Callable [[torch .fx .Node ], bool ]:
107+ """Returns True if a node has dynamic inputs in node.args at specified positions"""
108+ return functools .partial (
109+ _has_dynamic_shapes , arg_positions_to_check = arg_positions_to_check
110+ )
111+
112+
113+ def has_static_shapes_in_args (
114+ arg_positions_to_check : Optional [List [int ]] = None ,
115+ ) -> Callable [[torch .fx .Node ], bool ]:
116+ """Returns True if a node has static inputs in node.args at specified positions"""
117+ _has_static_shapes = lambda node , arg_positions_to_check : not _has_dynamic_shapes (
118+ node , arg_positions_to_check
119+ )
120+ return functools .partial (
121+ _has_static_shapes , arg_positions_to_check = arg_positions_to_check
122+ )
123+
124+
125+ def _has_dynamic_shapes (
126+ node : torch .fx .Node , arg_positions_to_check : Optional [List [int ]] = None
127+ ) -> bool :
128+ # Validate that none of the inputs to the node have Dynamic shapes
129+ assert isinstance (
130+ node , torch .fx .Node
131+ ), "Inputs to validator functions must be FX Nodes"
132+
133+ def _is_subnode_dynamic (subnode : torch .fx .Node ) -> bool :
134+ """Checks if a node itself has Dynamic properties"""
135+ _has_symbolic_sizes_strides , is_shape_dynamic = False , False
136+ if "val" in subnode .meta :
137+ _has_symbolic_sizes_strides = getattr (
138+ subnode .meta ["val" ], "_has_symbolic_sizes_strides" , False
139+ )
140+ meta_val = subnode .meta ["val" ]
141+ if isinstance (meta_val , (list , tuple )):
142+ for val in meta_val :
143+ shape = val .size ()
144+ if any (
145+ isinstance (dim , (SymFloat , SymInt , SymBool )) for dim in shape
146+ ):
147+ is_shape_dynamic = True
148+ break
149+ elif isinstance (meta_val , (SymFloat , SymInt , SymBool )):
150+ is_shape_dynamic = True
151+ else :
152+ shape = subnode .meta ["val" ].size ()
153+ is_shape_dynamic = any (
154+ isinstance (dim , (SymFloat , SymInt , SymBool )) for dim in shape
155+ )
156+
157+ return _has_symbolic_sizes_strides or is_shape_dynamic
158+
159+ # Check node value itself
160+ if arg_positions_to_check is None and _is_subnode_dynamic (node ):
161+ return True
162+
163+ # Check node arguments individually
164+ if arg_positions_to_check is None and any (
165+ _is_subnode_dynamic (arg ) for arg in node .args if isinstance (arg , torch .fx .Node )
166+ ):
167+ return True
168+ # Check specific arg positions if the caller has specified positions to check
169+ elif arg_positions_to_check is not None and any (
170+ _is_subnode_dynamic (node .args [i ])
171+ for i in arg_positions_to_check
172+ if isinstance (node .args [i ], torch .fx .Node )
173+ ):
174+ return True
175+
176+ # Check node keyword arguments individually
177+ if arg_positions_to_check is None and any (
178+ _is_subnode_dynamic (kwarg )
179+ for kwarg in node .kwargs .values ()
180+ if isinstance (kwarg , torch .fx .Node )
181+ ):
182+ return True
183+
184+ return False
185+
186+
90187def dynamo_tensorrt_converter (
91188 key : Target ,
189+ * ,
92190 enabled : bool = True ,
93191 capability_validator : Optional [Callable [[Node ], bool ]] = None ,
94192 priority : ConverterPriority = ConverterPriority .STANDARD ,
193+ supports_dynamic_shapes : bool = False ,
95194) -> Callable [[ConverterImplSignature ], ConverterImplSignature ]:
96195 """Decorator for Dynamo TensorRT Converter
97196
@@ -117,14 +216,18 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
117216
118217 # If no capability_validator function is specified, use the default function - always return true
119218 if capability_validator is None :
120- converter_support = ConverterSupport (converter_implementation = converter )
219+ converter_support = ConverterSupport (
220+ converter_implementation = converter ,
221+ supports_dynamic_shapes = supports_dynamic_shapes ,
222+ )
121223 else :
122224 assert callable (
123225 capability_validator
124226 ), "Argument checking function must be callable"
125227 converter_support = ConverterSupport (
126228 converter_implementation = converter ,
127229 capability_validator = capability_validator ,
230+ supports_dynamic_shapes = supports_dynamic_shapes ,
128231 )
129232
130233 # OpOverloadPackets are only valid if they have a single overload, or
@@ -194,6 +297,7 @@ def __init__(
194297 ],
195298 registry_names : Optional [Sequence [str ]] = None ,
196299 registry_calling_conventions : Optional [Sequence [CallingConvention ]] = None ,
300+ assume_dynamic_shape_support : bool = False ,
197301 ):
198302 # Copy reference to each dictionary object into attribute list
199303 self .registries = list (registries )
@@ -215,9 +319,12 @@ def __init__(
215319 ]
216320
217321 self .disallowed_targets : Collection [Target ] = set ()
218-
322+ self . assume_dynamic_shape_support = assume_dynamic_shape_support
219323 self .validate_invariants ()
220324
325+ def set_dynamic_shape_support (self , assume_dynamic_shape_support : bool ) -> None :
326+ self .assume_dynamic_shape_support = assume_dynamic_shape_support
327+
221328 def set_disallowed_targets (self , torch_executed_ops : Collection [Target ]) -> None :
222329 self .disallowed_targets = torch_executed_ops
223330
@@ -324,13 +431,24 @@ def __getitem__(
324431
325432 if isinstance (converters , (list , tuple )):
326433 for candidate in converters :
327- if candidate .capability_validator (node ):
434+ # We enable the converter under 4 conditions
435+ # 1) capability validator is True
436+ # 2) Assume dynamic_shape support is True
437+ # 3) Node only has static shaped inputs
438+ # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
439+ if candidate .capability_validator (node ) and (
440+ self .assume_dynamic_shape_support
441+ or not has_dynamic_shapes (node )
442+ or candidate .supports_dynamic_shapes
443+ ):
328444 return (
329445 candidate .converter_implementation ,
330446 calling_convention ,
331447 )
332448 else :
333- return converters , calling_convention
449+ # Assuming FX converters don't have dynamic shapes supported
450+ if not has_dynamic_shapes (node ):
451+ return converters , calling_convention
334452
335453 raise KeyError (
336454 f"None of the converter registries have a validated entry for { key } , with node { node } "
0 commit comments