1+ import inspect
12import logging
23from copy import deepcopy
34from enum import Enum , auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
4142 return self ._state
4243
4344
class DynamicShapeOutOfRangeException(Exception):
    """Raised when a concrete input dimension falls outside its registered dynamic-shape range."""
4449class MutableTorchTensorRTModule (object ):
4550 """
4651 Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
6570 Union [torch .dtype , dtype ]
6671 ] = _defaults .ENABLED_PRECISIONS ,
6772 engine_capability : EngineCapability = _defaults .ENGINE_CAPABILITY ,
68- immutable_weights : bool = _defaults . IMMUTABLE_WEIGHTS ,
73+ immutable_weights : bool = False ,
6974 debug : bool = _defaults .DEBUG ,
7075 num_avg_timing_iters : int = _defaults .NUM_AVG_TIMING_ITERS ,
7176 workspace_size : int = _defaults .WORKSPACE_SIZE ,
@@ -189,6 +194,9 @@ def __init__(
189194 "hardware_compatible" : hardware_compatible ,
190195 "timing_cache_path" : timing_cache_path ,
191196 }
197+ self .arg_dynamic_shapes : Optional [tuple [Any ]] = None
198+ self .kwarg_dynamic_shapes : Optional [dict [Any , Any ]] = None
199+ self .total_dynamic_shape : Optional [dict [Any , Any ]] = None
192200
193201 self .settings = CompilationSettings (** compilation_options )
194202 self .run_info : Optional [tuple [Any , ...]] = None
@@ -203,6 +211,27 @@ def __init__(
203211 )
204212 self .init_finished = True
205213
214+ def set_dynamic_shape_hint (
215+ self ,
216+ args_dynamic_shape : tuple [dict [Any , Any ]],
217+ kwargs_dynamic_shape : dict [str , Any ],
218+ ) -> None :
219+ assert isinstance (
220+ args_dynamic_shape , tuple
221+ ), "args dynamic shape has to be a tuple"
222+ assert isinstance (
223+ kwargs_dynamic_shape , dict
224+ ), "args dynamic shape has to be a dictionary"
225+ self .kwarg_dynamic_shapes = kwargs_dynamic_shape
226+ self .arg_dynamic_shapes = args_dynamic_shape
227+ self .total_dynamic_shape = self .kwarg_dynamic_shapes .copy ()
228+ signature = list (
229+ inspect .signature (self .original_model .forward ).parameters .keys ()
230+ )
231+ for i , arg in enumerate (self .arg_dynamic_shapes ):
232+ self .total_dynamic_shape [signature [i ]] = arg
233+ self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
234+
206235 def store_state_dict_metadata (self ) -> None :
207236 for k , v in self .original_model .state_dict ().items ():
208237 self .state_dict_metadata [k ] = v .shape
@@ -295,6 +324,7 @@ def compile(self) -> None:
295324 self .original_model ,
296325 self .arg_inputs ,
297326 kwargs = self .kwarg_inputs ,
327+ dynamic_shapes = self .total_dynamic_shape ,
298328 )
299329 self .gm = dynamo_compile (
300330 self .exp_program ,
@@ -306,14 +336,26 @@ def compile(self) -> None:
306336 torch .cuda .empty_cache ()
307337
308338 def _validate_inputs (self , * args : Any , ** kwargs : Any ) -> None :
309- if (
310- not self .arg_inputs
311- or not MutableTorchTensorRTModule .check_inputs_equal (self .arg_inputs , args )
312- or not MutableTorchTensorRTModule .check_inputs_equal (
313- self .kwarg_inputs , kwargs
314- )
315- ):
339+ try :
340+ if (
341+ not self .arg_inputs
342+ or not MutableTorchTensorRTModule .check_inputs_equal (
343+ self .arg_inputs , args , dynamic_shapes = self .arg_dynamic_shapes
344+ )
345+ or not MutableTorchTensorRTModule .check_inputs_equal (
346+ self .kwarg_inputs , kwargs , dynamic_shapes = self .kwarg_dynamic_shapes
347+ )
348+ ):
349+ logger .info ("Input change detected." )
350+ self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
351+ self .store_inputs (args , kwargs )
352+ except DynamicShapeOutOfRangeException as e :
316353 logger .info ("Input change detected." )
354+ logger .warning (e )
355+ logger .warning ("Recompiling the engine with static shape" )
356+ self .arg_dynamic_shapes = None
357+ self .kwarg_dynamic_shapes = None
358+ self .total_dynamic_shape = None
317359 self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
318360 self .store_inputs (args , kwargs )
319361
@@ -436,33 +478,66 @@ def __setattr__(self, name: str, value: Any) -> None:
436478 def check_inputs_equal (
437479 input1 : Any ,
438480 input2 : Any ,
481+ dynamic_shapes : Any = None ,
439482 ) -> bool :
440- # TODO: Add support for dynamic shape
483+
441484 if isinstance (input1 , (tuple , list )):
442485 if len (input1 ) != len (input2 ):
443486 return False
444- for a , b in zip (input1 , input2 ):
487+ for ( i , a ) , b in zip (enumerate ( input1 ) , input2 ):
445488 if type (a ) != type (b ):
446489 return False
447- if isinstance (a , torch .Tensor ) and a .shape != b .shape :
448- return False
449- elif isinstance (a , bool ) and a != b :
490+ if isinstance (a , bool ) and a != b :
450491 return False
492+ elif isinstance (a , torch .Tensor ) and a .shape != b .shape :
493+ if dynamic_shapes is None :
494+ return False
495+ else :
496+ tensor_dynamic_shape = dynamic_shapes [i ]
497+ if not MutableTorchTensorRTModule .check_tensor_shapes_with_dynamic_shapes (
498+ a , b , tensor_dynamic_shape
499+ ):
500+ return False
451501
452502 elif isinstance (input1 , dict ):
453503 if input1 .keys () != input2 .keys ():
454504 return False
455- for a , b in zip (input1 .values (), input2 .values ()):
456- if type (a ) != type (b ):
505+ for ( ka , va ), vb in zip (input1 .items (), input2 .values ()):
506+ if type (va ) != type (vb ):
457507 return False
458- if isinstance (a , torch .Tensor ) and a .shape != b .shape :
459- return False
460- elif isinstance (a , bool ) and a != b :
508+ if isinstance (va , bool ) and va != vb :
461509 return False
510+ elif isinstance (va , torch .Tensor ) and va .shape != vb .shape :
511+ if dynamic_shapes is None :
512+ return False
513+ else :
514+ tensor_dynamic_shape = dynamic_shapes [ka ]
515+ if not MutableTorchTensorRTModule .check_tensor_shapes_with_dynamic_shapes (
516+ va , vb , tensor_dynamic_shape
517+ ):
518+ return False
462519 elif isinstance (
463- a , (list , tuple , dict )
464- ) and not MutableTorchTensorRTModule .check_inputs_equal (a , b ):
520+ va , (list , tuple , dict )
521+ ) and not MutableTorchTensorRTModule .check_inputs_equal (
522+ va , vb , dynamic_shapes [ka ] if dynamic_shapes else None
523+ ):
524+ return False
525+ return True
526+
527+ @staticmethod
528+ def check_tensor_shapes_with_dynamic_shapes (
529+ t1 : torch .tensor , t2 : torch .tensor , dynamic_shape : dict [int , Any ]
530+ ) -> bool :
531+ for (i , axis_0 ), axis_1 in zip (enumerate (t1 .shape ), t2 .shape ):
532+ if axis_0 != axis_1 :
533+ if i not in dynamic_shape :
465534 return False
535+ dyn = dynamic_shape [i ]
536+ if axis_1 > dyn .max or axis_1 < dyn .min :
537+ raise DynamicShapeOutOfRangeException (
538+ f"The input size ({ axis_1 } ) of dimension ({ i } ) is not in dynamic shape range [{ dyn .max } , { dyn .max } ]!"
539+ )
540+
466541 return True
467542
468543 @staticmethod
0 commit comments