1919 SequenceParallel ,
2020)
2121from torchtitan .config import JobConfig , TORCH_DTYPE_MAP
22+ from torchtitan .config .job_config import Compile as CompileConfig
2223from torchtitan .distributed import NoParallel , ParallelDims
2324from torchtitan .distributed .activation_checkpoint import apply_ac
2425from torchtitan .distributed .expert_parallel import (
@@ -123,7 +124,7 @@ def parallelize_llama(
123124
124125 # turn on per-TransformerBlock compile after AC wrapping and before FSDP
125126 if model_compile_enabled :
126- apply_compile (model )
127+ apply_compile (model , job_config . compile )
127128
128129 dp_mesh : DeviceMesh | None = None
129130 if parallel_dims .fsdp_enabled or parallel_dims .ep_enabled :
@@ -502,7 +503,7 @@ def apply_moe_ep_tp(
502503 )
503504
504505
505- def apply_compile (model : nn .Module ):
506+ def apply_compile (model : nn .Module , compile_config : CompileConfig ):
506507 """
507508 Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
508509 repeated structure. Alternatively one can compile the whole model (after applying DP).
@@ -515,7 +516,11 @@ def apply_compile(model: nn.Module):
515516 fullgraph = True
516517 if transformer_block .moe_enabled :
517518 fullgraph = False
518- transformer_block = torch .compile (transformer_block , fullgraph = fullgraph )
519+ transformer_block = torch .compile (
520+ transformer_block ,
521+ backend = compile_config .backend ,
522+ fullgraph = fullgraph ,
523+ )
519524 model .layers .register_module (layer_id , transformer_block )
520525
521526 logger .info ("Compiling each TransformerBlock with torch.compile" )
0 commit comments