1+ import inspect
12import logging
23from copy import deepcopy
34from enum import Enum , auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
4142 return self ._state
4243
4344
class DynamicShapeOutOfRangeException(Exception):
    """Raised when an input dimension falls outside its declared dynamic-shape range."""
47+
48+
4449class MutableTorchTensorRTModule (object ):
4550 """
4651 Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
6570 Union [torch .dtype , dtype ]
6671 ] = _defaults .ENABLED_PRECISIONS ,
6772 engine_capability : EngineCapability = _defaults .ENGINE_CAPABILITY ,
68- immutable_weights : bool = _defaults . IMMUTABLE_WEIGHTS ,
73+ immutable_weights : bool = False ,
6974 debug : bool = _defaults .DEBUG ,
7075 num_avg_timing_iters : int = _defaults .NUM_AVG_TIMING_ITERS ,
7176 workspace_size : int = _defaults .WORKSPACE_SIZE ,
@@ -189,6 +194,9 @@ def __init__(
189194 "hardware_compatible" : hardware_compatible ,
190195 "timing_cache_path" : timing_cache_path ,
191196 }
197+ self .arg_dynamic_shapes : Optional [tuple [Any ]] = None
198+ self .kwarg_dynamic_shapes : Optional [dict [Any , Any ]] = None
199+ self .total_dynamic_shape : Optional [dict [Any , Any ]] = None
192200
193201 self .settings = CompilationSettings (** compilation_options )
194202 self .run_info : Optional [tuple [Any , ...]] = None
@@ -203,6 +211,26 @@ def __init__(
203211 )
204212 self .init_finished = True
205213
214+ def set_dynamic_shape_hint (
215+ self ,
216+ args_dynamic_shape : tuple [dict [Any , Any ]],
217+ kwargs_dynamic_shape : dict [str , Any ],
218+ ) -> None :
219+ assert isinstance (
220+ args_dynamic_shape , tuple
221+ ), "args dynamic shape has to be a tuple"
222+ assert isinstance (
223+ kwargs_dynamic_shape , dict
224+ ), "args dynamic shape has to be a dictionary"
225+ self .kwarg_dynamic_shapes = kwargs_dynamic_shape
226+ self .arg_dynamic_shapes = args_dynamic_shape
227+ self .total_dynamic_shape = self .kwarg_dynamic_shapes .copy ()
228+ signature = list (
229+ inspect .signature (self .original_model .forward ).parameters .keys ()
230+ )
231+ for i , arg in enumerate (self .arg_dynamic_shapes ):
232+ self .total_dynamic_shape [signature [i ]] = arg
233+
206234 def store_state_dict_metadata (self ) -> None :
207235 for k , v in self .original_model .state_dict ().items ():
208236 self .state_dict_metadata [k ] = v .shape
@@ -295,6 +323,7 @@ def compile(self) -> None:
295323 self .original_model ,
296324 self .arg_inputs ,
297325 kwargs = self .kwarg_inputs ,
326+ dynamic_shapes = self .total_dynamic_shape ,
298327 )
299328 self .gm = dynamo_compile (
300329 self .exp_program ,
@@ -306,14 +335,26 @@ def compile(self) -> None:
306335 torch .cuda .empty_cache ()
307336
308337 def _validate_inputs (self , * args : Any , ** kwargs : Any ) -> None :
309- if (
310- not self .arg_inputs
311- or not MutableTorchTensorRTModule .check_inputs_equal (self .arg_inputs , args )
312- or not MutableTorchTensorRTModule .check_inputs_equal (
313- self .kwarg_inputs , kwargs
314- )
315- ):
338+ try :
339+ if (
340+ not self .arg_inputs
341+ or not MutableTorchTensorRTModule .check_inputs_equal (
342+ self .arg_inputs , args , dynamic_shapes = self .arg_dynamic_shapes
343+ )
344+ or not MutableTorchTensorRTModule .check_inputs_equal (
345+ self .kwarg_inputs , kwargs , dynamic_shapes = self .kwarg_dynamic_shapes
346+ )
347+ ):
348+ logger .info ("Input change detected." )
349+ self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
350+ self .store_inputs (args , kwargs )
351+ except DynamicShapeOutOfRangeException as e :
316352 logger .info ("Input change detected." )
353+ logger .warning (e )
354+ logger .warning ("Recompiling the engine with static shape" )
355+ self .arg_dynamic_shapes = None
356+ self .kwarg_dynamic_shapes = None
357+ self .total_dynamic_shape = None
317358 self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
318359 self .store_inputs (args , kwargs )
319360
@@ -436,33 +477,66 @@ def __setattr__(self, name: str, value: Any) -> None:
436477 def check_inputs_equal (
437478 input1 : Any ,
438479 input2 : Any ,
480+ dynamic_shapes : Any = None ,
439481 ) -> bool :
440- # TODO: Add support for dynamic shape
482+
441483 if isinstance (input1 , (tuple , list )):
442484 if len (input1 ) != len (input2 ):
443485 return False
444- for a , b in zip (input1 , input2 ):
486+ for ( i , a ) , b in zip (enumerate ( input1 ) , input2 ):
445487 if type (a ) != type (b ):
446488 return False
447- if isinstance (a , torch .Tensor ) and a .shape != b .shape :
448- return False
449- elif isinstance (a , bool ) and a != b :
489+ if isinstance (a , bool ) and a != b :
450490 return False
491+ elif isinstance (a , torch .Tensor ) and a .shape != b .shape :
492+ if dynamic_shapes is None :
493+ return False
494+ else :
495+ tensor_dynamic_shape = dynamic_shapes [i ]
496+ if not MutableTorchTensorRTModule .check_tensor_shapes_with_dynamic_shapes (
497+ a , b , tensor_dynamic_shape
498+ ):
499+ return False
451500
452501 elif isinstance (input1 , dict ):
453502 if input1 .keys () != input2 .keys ():
454503 return False
455- for a , b in zip (input1 .values (), input2 .values ()):
456- if type (a ) != type (b ):
504+ for ( ka , va ), vb in zip (input1 .items (), input2 .values ()):
505+ if type (va ) != type (vb ):
457506 return False
458- if isinstance (a , torch .Tensor ) and a .shape != b .shape :
459- return False
460- elif isinstance (a , bool ) and a != b :
507+ if isinstance (va , bool ) and va != vb :
461508 return False
509+ elif isinstance (va , torch .Tensor ) and va .shape != vb .shape :
510+ if dynamic_shapes is None :
511+ return False
512+ else :
513+ tensor_dynamic_shape = dynamic_shapes [ka ]
514+ if not MutableTorchTensorRTModule .check_tensor_shapes_with_dynamic_shapes (
515+ va , vb , tensor_dynamic_shape
516+ ):
517+ return False
462518 elif isinstance (
463- a , (list , tuple , dict )
464- ) and not MutableTorchTensorRTModule .check_inputs_equal (a , b ):
519+ va , (list , tuple , dict )
520+ ) and not MutableTorchTensorRTModule .check_inputs_equal (
521+ va , vb , dynamic_shapes [ka ] if dynamic_shapes else None
522+ ):
523+ return False
524+ return True
525+
526+ @staticmethod
527+ def check_tensor_shapes_with_dynamic_shapes (
528+ t1 : torch .tensor , t2 : torch .tensor , dynamic_shape : dict [int , Any ]
529+ ) -> bool :
530+ for (i , axis_0 ), axis_1 in zip (enumerate (t1 .shape ), t2 .shape ):
531+ if axis_0 != axis_1 :
532+ if i not in dynamic_shape :
465533 return False
534+ dyn = dynamic_shape [i ]
535+ if axis_1 > dyn .max or axis_1 < dyn .min :
536+ raise DynamicShapeOutOfRangeException (
537+ f"The input size ({ axis_1 } ) of dimension ({ i } ) is not in dynamic shape range [{ dyn .max } , { dyn .max } ]!"
538+ )
539+
466540 return True
467541
468542 @staticmethod
0 commit comments