@@ -196,12 +196,11 @@ def __init__(
         }
         self.arg_dynamic_shapes: Optional[tuple[Any]] = None
         self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
-        self.total_dynamic_shape: Optional[dict[Any, Any]] = None
 
         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
         self.state_dict_metadata: dict[str, torch.Size] = {}
-        self.store_state_dict_metadata()
+        self._store_state_dict_metadata()
 
         cls = self.__class__
         self.__class__ = type(
@@ -211,11 +210,31 @@ def __init__(
         )
         self.init_finished = True
 
-    def set_dynamic_shape_hint(
+    def set_expected_dynamic_shape_range(
         self,
         args_dynamic_shape: tuple[dict[Any, Any]],
         kwargs_dynamic_shape: dict[str, Any],
     ) -> None:
+        """
+        Set the dynamic shape range. The shape hint should EXACTLY follow the arg_inputs and kwarg_inputs passed to the forward function
+        and must not omit any entries. If a dynamic shape is not required for an input, an empty dictionary should be given
+        as the shape hint for that input.
+
+        Example:
+            def forward(a, b, c=0, d=0):
+                pass
+
+            seq_len = torch.export.Dim("seq_len", min=1, max=10)
+            args_dynamic_shape = ({0: seq_len}, {})  # b does not have a dynamic shape
+            kwargs_dynamic_shape = {'c': {0: seq_len}, 'd': {}}  # d does not have a dynamic shape
+            # Later, when you call the function
+            forward(*(a, b), **{'c': ..., 'd': ...})
+
+
+        Arguments:
+            args_dynamic_shape (tuple[dict[Any, Any]]): Dynamic shape hint for the arg_inputs
+            kwargs_dynamic_shape (dict[str, Any]): Dynamic shape hint for the kwarg_inputs
+        """
         assert isinstance(
             args_dynamic_shape, tuple
         ), "args dynamic shape has to be a tuple"
@@ -224,19 +243,31 @@ def set_dynamic_shape_hint(
224243 ), "args dynamic shape has to be a dictionary"
225244 self .kwarg_dynamic_shapes = kwargs_dynamic_shape
226245 self .arg_dynamic_shapes = args_dynamic_shape
227- self .total_dynamic_shape = self .kwarg_dynamic_shapes .copy ()
228- signature = list (
229- inspect .signature (self .original_model .forward ).parameters .keys ()
230- )
231- for i , arg in enumerate (self .arg_dynamic_shapes ):
232- self .total_dynamic_shape [signature [i ]] = arg
233- self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
234246
235247 # Clear cached inputs
236248 self .arg_inputs = tuple ()
237249 self .kwarg_inputs = {}
238250
239- def store_state_dict_metadata (self ) -> None :
251+ self .refit_state .set_state (RefitFlag .NEEDS_RECOMPILE )
252+
253+ def _get_total_dynamic_shapes (self ) -> dict [str , Any ] | None :
254+ if not self .arg_dynamic_shapes and not self .kwarg_dynamic_shapes :
255+ return None
256+ total_dynamic_shape = {}
257+ if self .arg_dynamic_shapes :
258+ signature = list (
259+ inspect .signature (self .original_model .forward ).parameters .keys ()
260+ )
261+ for i , arg in enumerate (self .arg_dynamic_shapes ):
262+ total_dynamic_shape [signature [i ]] = arg
263+
264+ if self .kwarg_dynamic_shapes :
265+ for kwargs , kwargs_dynamic_shape in self .kwarg_dynamic_shapes .items ():
266+ total_dynamic_shape [kwargs ] = kwargs_dynamic_shape
267+
268+ return total_dynamic_shape
269+
270+ def _store_state_dict_metadata (self ) -> None :
240271 for k , v in self .original_model .state_dict ().items ():
241272 self .state_dict_metadata [k ] = v .shape
242273
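The merge that `_get_total_dynamic_shapes` performs is easy to see in isolation: positional hints are keyed by the `forward()` parameter names so that arg and kwarg hints end up in one dict. A minimal standalone sketch (`merge_dynamic_shapes` is a hypothetical name, not part of this diff):

```python
# Standalone sketch of the merge above.
import inspect

def merge_dynamic_shapes(forward_fn, arg_hints, kwarg_hints):
    if not arg_hints and not kwarg_hints:
        return None  # no hints -> export statically
    total = {}
    params = list(inspect.signature(forward_fn).parameters.keys())
    for i, hint in enumerate(arg_hints or ()):
        total[params[i]] = hint      # i-th positional hint -> i-th parameter name
    total.update(kwarg_hints or {})  # keyword hints are already keyed by name
    return total

def forward(a, b, c=None):
    return a

print(merge_dynamic_shapes(forward, ({0: "seq_len"}, {}), {"c": {}}))
# -> {'a': {0: 'seq_len'}, 'b': {}, 'c': {}}
```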
@@ -328,7 +359,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
-            dynamic_shapes=self.total_dynamic_shape,
+            dynamic_shapes=self._get_total_dynamic_shapes(),
         )
         self.gm = dynamo_compile(
             self.exp_program,
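For reference, the `dynamic_shapes` value handed to `torch.export.export` here is exactly the per-parameter dict built by `_get_total_dynamic_shapes`. A hedged sketch with a toy model (`Toy` is illustrative; `torch.export.export` and `torch.export.Dim` are the real APIs):

```python
# What the export call above receives, sketched standalone.
import torch

class Toy(torch.nn.Module):
    def forward(self, a, b=None):
        return a + b

seq_len = torch.export.Dim("seq_len", min=1, max=10)
exp_program = torch.export.export(
    Toy(),
    (torch.randn(4, 8),),
    kwargs={"b": torch.randn(4, 8)},
    # Keyed by forward() parameter name, as _get_total_dynamic_shapes builds it.
    dynamic_shapes={"a": {0: seq_len}, "b": {0: seq_len}},
)
print(exp_program.graph_signature)
```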
@@ -340,40 +371,75 @@ def compile(self) -> None:
         torch.cuda.empty_cache()
 
     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
+
+        if not self.arg_inputs:
+            logger.info("First time compilation initiated. This may take some time.")
+            self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+            self._store_inputs(args, kwargs)
+            if self.arg_dynamic_shapes or self.kwarg_dynamic_shapes:
+                if not self._validates_dynamic_hints():
+                    logger.warning(
+                        "Invalid dynamic shape hint. Compiling module for the provided input shapes (static)"
+                    )
+                    self.arg_dynamic_shapes = None
+                    self.kwarg_dynamic_shapes = None
+            return
+
+        # If the input shape changes or falls outside the dynamic shape range, recompile the engine
         try:
-            if (
-                not self.arg_inputs
-                or not MutableTorchTensorRTModule.check_inputs_equal(
-                    self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
-                )
-                or not MutableTorchTensorRTModule.check_inputs_equal(
-                    self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
-                )
+            if not MutableTorchTensorRTModule._check_inputs_shape(
+                self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+            ) or not MutableTorchTensorRTModule._check_inputs_shape(
+                self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
             ):
                 logger.info("Input change detected.")
                 self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-                self.store_inputs(args, kwargs)
+                self._store_inputs(args, kwargs)
         except DynamicShapeOutOfRangeException as e:
             logger.info("Input change detected.")
             logger.warning(e)
-            logger.warning("Recompiling the engine with static shape")
+            logger.warning(
+                "Provided inputs are outside the expected shape range; recompiling the module for the provided input shapes (static)"
+            )
             self.arg_dynamic_shapes = None
             self.kwarg_dynamic_shapes = None
-            self.total_dynamic_shape = None
             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-            self.store_inputs(args, kwargs)
+            self._store_inputs(args, kwargs)
 
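The net effect at call time, sketched with illustrative shapes. Assume a mutable module whose single input was hinted with `set_expected_dynamic_shape_range(({0: seq_len},), {})` and `seq_len = Dim("seq_len", min=1, max=10)`:

```python
# Call-time behavior sketch for the validation path above (illustrative shapes).
mutable_module(torch.randn(2, 8).cuda())   # first call: initial compilation
mutable_module(torch.randn(6, 8).cuda())   # dim 0 inside [1, 10]: engine is reused
mutable_module(torch.randn(12, 8).cuda())  # dim 0 outside the range: hints dropped,
                                           # module recompiled statically for (12, 8)
```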
-    def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
+    def _validates_dynamic_hints(self) -> bool:
+        if self.arg_dynamic_shapes is None:
+            if self.arg_inputs:
+                logger.warning("arg_dynamic_shape is not provided!")
+        else:
+            if len(self.arg_dynamic_shapes) != len(self.arg_inputs):
+                logger.warning(
+                    f"The length of arg_inputs is {len(self.arg_inputs)} but the length of arg_dynamic_shape is {len(self.arg_dynamic_shapes)}!"
+                )
+                return False
+
+        if self.kwarg_dynamic_shapes is None:
+            if self.kwarg_inputs:
+                logger.warning("kwarg_dynamic_shape is not provided!")
+        else:
+            if self.kwarg_dynamic_shapes.keys() != self.kwarg_inputs.keys():
+                logger.warning(
+                    f"kwarg_inputs has {list(self.kwarg_inputs.keys())} but kwarg_dynamic_shape has {list(self.kwarg_dynamic_shapes.keys())}!"
+                )
+                return False
+
+        return True
+
+    def _store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
         self.arg_inputs = arg_inputs
         self.kwarg_inputs = kwarg_inputs
 
     @staticmethod
-    def process_kwarg_inputs(inputs: Any) -> Any:
+    def _process_kwarg_inputs(inputs: Any) -> Any:
         # Process kwarg inputs to be acceptable for Torch-TensorRT
         if isinstance(inputs, dict):
             # None must be excluded. AOT compile does not allow dynamic control flow, so bool is excluded as well.
             return {
-                k: MutableTorchTensorRTModule.process_kwarg_inputs(v)
+                k: MutableTorchTensorRTModule._process_kwarg_inputs(v)
                 for k, v in inputs.items()
                 if (v is not None and not isinstance(v, bool))
             }
@@ -384,7 +450,10 @@ def process_kwarg_inputs(inputs: Any) -> Any:
         elif isinstance(inputs, (list, tuple)):
             if None not in inputs:
                 return type(inputs)(
-                    [MutableTorchTensorRTModule.process_kwarg_inputs(v) for v in inputs]
+                    [
+                        MutableTorchTensorRTModule._process_kwarg_inputs(v)
+                        for v in inputs
+                    ]
                 )
 
         raise ValueError(
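Standalone, the filtering `_process_kwarg_inputs` applies looks like the sketch below. It is simplified (`process` is a hypothetical free function): the real method also type-checks leaf values and raises the `ValueError` shown above for unsupported types.

```python
# Simplified sketch of the kwarg filtering: None values and bools are dropped,
# since AOT compile cannot trace the dynamic control flow they usually drive;
# containers are cleaned recursively.
def process(inputs):
    if isinstance(inputs, dict):
        return {
            k: process(v)
            for k, v in inputs.items()
            if v is not None and not isinstance(v, bool)
        }
    if isinstance(inputs, (list, tuple)) and None not in inputs:
        return type(inputs)([process(v) for v in inputs])
    return inputs  # the real method validates leaves and may raise ValueError

print(process({"mask": None, "use_cache": True, "ids": [1, 2]}))  # {'ids': [1, 2]}
```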
@@ -394,7 +463,7 @@ def process_kwarg_inputs(inputs: Any) -> Any:
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         # Step 1: Check whether the input shape has changed
-        kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs)
+        kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs)
         self._validate_inputs(*args, **kwargs)
 
         # Step 2: If the flag is unknown, it could be a recompile or refit.
@@ -406,7 +475,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
             logger.info("(Re)Compiling the engine...")
             self.compile()
-            self.store_state_dict_metadata()
+            self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)
 
         elif self.refit_state.get_state() == RefitFlag.NEEDS_REFIT:
@@ -417,7 +486,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                 logger.error(e)
                 logger.error("Model refit failed. Recompiling the graph module.")
                 self.compile()
-                self.store_state_dict_metadata()
+                self._store_state_dict_metadata()
                 self.refit_state.set_state(RefitFlag.LIVE)
 
         result = self.gm(*args, **kwargs)
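For contrast with the recompile path above, the refit path is what makes the module "mutable": weight changes on the wrapped model flag `NEEDS_REFIT`, and the next forward call refits the existing engine in place, falling back to a recompile only if the refit fails. A hedged usage sketch (the weights file path is illustrative):

```python
# Refit-path sketch: update weights, then call forward; the module refits
# the engine before running instead of rebuilding it from scratch.
mutable_module.load_state_dict(torch.load("new_weights.pth"))  # mutate weights
out = mutable_module(torch.randn(2, 8).cuda())                 # refit, then run
```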
@@ -427,7 +496,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
 
     def to(self, device: str) -> None:
         logger.warning("Original PyTorch model is moved. CPU offload may fail.")
-        self.orignial_model.to(device)
+        self.original_model.to(device)
 
     def __deepcopy__(self, memo: Any) -> Any:
         cls = self.__class__
@@ -479,7 +548,7 @@ def __setattr__(self, name: str, value: Any) -> None:
         object.__setattr__(self, name, value)
 
     @staticmethod
-    def check_inputs_equal(
+    def _check_inputs_shape(
         input1: Any,
         input2: Any,
         dynamic_shapes: Any = None,
@@ -495,10 +564,13 @@ def check_inputs_equal(
                 return False
             elif isinstance(a, torch.Tensor) and a.shape != b.shape:
                 if dynamic_shapes is None:
+                    logger.warning(
+                        "Dynamic shape is not properly set but the input shape has changed!"
+                    )
                     return False
                 else:
                     tensor_dynamic_shape = dynamic_shapes[i]
-                    if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                    if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
                         a, b, tensor_dynamic_shape
                     ):
                         return False
@@ -513,28 +585,34 @@ def check_inputs_equal(
                 return False
             elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
                 if dynamic_shapes is None:
+                    logger.warning(
+                        "Dynamic shape is not properly set but the input shape has changed!"
+                    )
                     return False
                 else:
                     tensor_dynamic_shape = dynamic_shapes[ka]
-                    if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                    if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
                         va, vb, tensor_dynamic_shape
                     ):
                         return False
             elif isinstance(
                 va, (list, tuple, dict)
-            ) and not MutableTorchTensorRTModule.check_inputs_equal(
+            ) and not MutableTorchTensorRTModule._check_inputs_shape(
                 va, vb, dynamic_shapes[ka] if dynamic_shapes else None
             ):
                 return False
         return True
 
     @staticmethod
-    def check_tensor_shapes_with_dynamic_shapes(
+    def _check_tensor_shapes_with_dynamic_shapes(
         t1: torch.Tensor, t2: torch.Tensor, dynamic_shape: dict[int, Any]
     ) -> bool:
         for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
             if axis_0 != axis_1:
                 if i not in dynamic_shape:
+                    logger.warning(
+                        "Dynamic shape does not include the axis on which the input changes!"
+                    )
                     return False
                 dyn = dynamic_shape[i]
                 if axis_1 > dyn.max or axis_1 < dyn.min:
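The per-axis check above can be summarized standalone. A minimal sketch (`shapes_compatible` is a hypothetical name; the real method raises `DynamicShapeOutOfRangeException`, caught in `_validate_inputs`, rather than returning `False` when a hinted axis leaves its range):

```python
# Standalone sketch of the dynamic-axis range check.
import torch

def shapes_compatible(old_shape, new_shape, dynamic_shape):
    for i, (axis_0, axis_1) in enumerate(zip(old_shape, new_shape)):
        if axis_0 != axis_1:
            if i not in dynamic_shape:              # changed on an axis with no hint
                return False
            dim = dynamic_shape[i]
            if not (dim.min <= axis_1 <= dim.max):  # outside the declared range
                return False
    return True

seq_len = torch.export.Dim("seq_len", min=1, max=10)
print(shapes_compatible((4, 8), (6, 8), {0: seq_len}))   # True: 6 is in [1, 10]
print(shapes_compatible((4, 8), (12, 8), {0: seq_len}))  # False: 12 > 10
```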