[torch.compile] CUDAGraph Inductor partition integration #24281
@@ -326,6 +326,40 @@ def call_module(self, target: torch.fx.node.Target,
                 i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
             ]
             global compilation_start_time
+
+            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                    and self.compilation_config.use_inductor_graph_partition):
+                # If we're using Inductor-based graph partitioning, we
+                # currently have the whole `fx.Graph` before Inductor
+                # lowering, and the piecewise splitting happens after all
+                # graph passes and fusions. Here, we add a custom hook for
+                # Inductor to wrap each partition with our static graph
+                # wrapper class to maintain more control over static graph
+                # capture and replay.
+
+                from torch._inductor.utils import CUDAGraphWrapperMetadata
+
+                from .cuda_graph import CUDAGraphOptions
+
+                static_graph_wrapper_class = resolve_obj_by_qualname(
+                    current_platform.get_static_graph_wrapper_cls())
+
+                def customized_cudagraph_wrapper(
+                        f, metadata: CUDAGraphWrapperMetadata):
+                    partition_id = metadata.partition_index
+                    num_partitions = metadata.num_partitions
+                    return static_graph_wrapper_class(
+                        runnable=f,
+                        vllm_config=self.vllm_config,
+                        runtime_mode=CUDAGraphMode.PIECEWISE,
+                        cudagraph_options=CUDAGraphOptions(
+                            debug_log_enable=partition_id == 0,
+                            gc_disable=partition_id != 0,
+                            weak_ref_output=partition_id == num_partitions - 1,
+                        ))
+
+                torch._inductor.utils.set_customized_partition_wrappers(
+                    customized_cudagraph_wrapper)
+
             compiled_graph_for_dynamic_shape = self.vllm_backend.\
                 compiler_manager.compile(
                     submod,
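The hook registered above is called by Inductor once per generated partition, and the callable it returns replaces that partition at runtime. The toy sketch below illustrates the same contract outside vLLM: only `CUDAGraphWrapperMetadata` and `set_customized_partition_wrappers` are taken from this diff, the wrapper just counts partitions and forwards the call instead of capturing a CUDA graph, and the `torch._inductor.config.graph_partition` flag mentioned in the trailing comment is an assumption about how partitioning is switched on, not something confirmed by this PR.

```python
# Minimal sketch of the Inductor partition-wrapper hook contract used above.
# Only CUDAGraphWrapperMetadata and set_customized_partition_wrappers come
# from the diff; everything else here is illustrative.
import torch._inductor.utils as inductor_utils
from torch._inductor.utils import CUDAGraphWrapperMetadata


def toy_partition_wrapper(f, metadata: CUDAGraphWrapperMetadata):
    # Invoked once per Inductor-generated partition. The metadata carries
    # the partition's position, which vLLM uses to pick per-partition
    # options (debug logging, GC, weak output refs).
    idx = metadata.partition_index
    total = metadata.num_partitions

    def run(*args, **kwargs):
        # A real wrapper (vLLM's static graph wrapper) captures the
        # partition into a CUDA graph on first call and replays it later;
        # this toy version only forwards the call.
        print(f"running partition {idx + 1}/{total}")
        return f(*args, **kwargs)

    return run


# Register the hook; Inductor applies it to every partition it emits.
inductor_utils.set_customized_partition_wrappers(toy_partition_wrapper)

# Enabling Inductor graph partitioning itself (e.g. via
# torch._inductor.config.graph_partition = True) is assumed here and may
# differ across PyTorch versions.
```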
@@ -336,15 +370,20 @@ def call_module(self, target: torch.fx.node.Target,
                     num_graphs=len(self.compile_submod_names),
                     runtime_shape=None)
             # Lazy import here to avoid circular import
-            from .cuda_graph import CUDAGraphOptions
             from .cuda_piecewise_backend import PiecewiseBackend
 
             piecewise_backend = PiecewiseBackend(
                 submod, self.vllm_config, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_dynamic_shape, self.vllm_backend)
 
-            if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                    and
+                    not self.compilation_config.use_inductor_graph_partition):
+                # We're using Dynamo-based piecewise splitting, so we wrap
+                # the whole subgraph with a static graph wrapper.
+                from .cuda_graph import CUDAGraphOptions
+
                 # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
                 # class) as platform dependent.
                 static_graph_wrapper_class = resolve_obj_by_qualname(
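Both branches hand the compiled callable to a static graph wrapper: one wrapper per Inductor partition in the first hunk, and one wrapper around the whole subgraph in this Dynamo-piecewise branch. The wrapper's job is the capture-once / replay-many pattern, sketched below with the public `torch.cuda.CUDAGraph` API. This is not vLLM's `CUDAGraphWrapper`; the module, shapes, and warm-up scheme are illustrative assumptions.

```python
# Sketch of the capture/replay pattern that a static graph wrapper
# encapsulates, using only the public torch.cuda.CUDAGraph API.
import torch

model = torch.nn.Linear(1024, 1024).cuda()
static_in = torch.zeros(8, 1024, device="cuda")

# Warm up on a side stream so one-time initialization work (cuBLAS handles,
# autotuning, etc.) is not recorded into the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture: kernels launched inside this context are recorded, not executed.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: copy fresh data into the static input buffer, re-run the recorded
# kernels, and read results from the same static output buffer.
static_in.copy_(torch.randn(8, 1024, device="cuda"))
graph.replay()
print(static_out.norm().item())
```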