11import logging
2- from typing import Any , Union
2+ from typing import Any , Optional , Union
33
44import torch
55import torch_tensorrt
@@ -68,6 +68,7 @@ def __init__(self, compiled_module: torch.nn.Module) -> None:
6868 global _PY_RT_CUDAGRAPHS
6969 self .old_mode = _PY_RT_CUDAGRAPHS
7070 self .compiled_module = compiled_module
71+ self .cudagraphs_module : Optional [CudaGraphsTorchTensorRTModule ] = None
7172
7273 def __enter__ (self ) -> torch .nn .Module :
7374 global _PY_RT_CUDAGRAPHS
@@ -98,7 +99,8 @@ def __enter__(self) -> torch.nn.Module:
9899 logger .debug (
99100 "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
100101 )
101- return CudaGraphsTorchTensorRTModule (self .compiled_module )
102+ self .cudagraphs_module = CudaGraphsTorchTensorRTModule (self .compiled_module )
103+ return self .cudagraphs_module
102104 else :
103105 if num_trt_module > 0 :
104106 logger .debug ("No graph breaks detected, using runtime cudagraphs mode" )
@@ -113,6 +115,9 @@ def __enter__(self) -> torch.nn.Module:
def __exit__(self, *args: Any) -> None:
    """Exit the cudagraphs context: restore the prior mode and free captured graphs.

    Args:
        *args: Exception triple (type, value, traceback) supplied by the
            context-manager protocol; intentionally unused — exceptions
            are not suppressed.
    """
    # Set cudagraphs back to the mode recorded on entry.
    set_cudagraphs_mode(self.old_mode)
    # __del__ is not entirely predictable, so we reset the captured
    # cudagraph explicitly here. Use an explicit None check: the attribute
    # is Optional, and relying on truthiness of a module object is fragile.
    if self.cudagraphs_module is not None:
        self.cudagraphs_module._reset_captured_graph()
        # Drop the reference so the wrapper module is not kept alive
        # past the context's lifetime.
        self.cudagraphs_module = None
116121
117122
118123def enable_cudagraphs (
0 commit comments