Skip to content

Commit 177b050

Browse files
authored
config compile backend (#1768)
Stacked PRs: #1768 — config compile backend
1 parent fa21894 commit 177b050

File tree

10 files changed

+23
-14
lines changed

10 files changed

+23
-14
lines changed

torchtitan/components/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def build_cross_entropy_loss(job_config: JobConfig):
2727
loss_fn = cross_entropy_loss
2828
if job_config.compile.enable and "loss" in job_config.compile.components:
2929
logger.info("Compiling the loss function with torch.compile")
30-
loss_fn = torch.compile(loss_fn)
30+
loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
3131
return loss_fn
3232

3333

torchtitan/config/job_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ class Compile:
566566
default_factory=lambda: ["model", "loss"]
567567
)
568568
"""Which components to compile"""
569+
backend: str = "inductor"
569570

570571

571572
@dataclass

torchtitan/experiments/flux/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ def build_mse_loss(job_config: JobConfig):
2323
loss_fn = mse_loss
2424
if job_config.compile.enable and "loss" in job_config.compile.components:
2525
logger.info("Compiling the loss function with torch.compile")
26-
loss_fn = torch.compile(loss_fn)
26+
loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
2727
return loss_fn

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
SequenceParallel,
2020
)
2121
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
22+
from torchtitan.config.job_config import Compile as CompileConfig
2223
from torchtitan.distributed import NoParallel, ParallelDims
2324
from torchtitan.distributed.activation_checkpoint import apply_ac
2425
from torchtitan.distributed.expert_parallel import (
@@ -123,7 +124,7 @@ def parallelize_llama(
123124

124125
# turn on per-TransformerBlock compile after AC wrapping and before FSDP
125126
if model_compile_enabled:
126-
apply_compile(model)
127+
apply_compile(model, job_config.compile)
127128

128129
dp_mesh: DeviceMesh | None = None
129130
if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
@@ -502,7 +503,7 @@ def apply_moe_ep_tp(
502503
)
503504

504505

505-
def apply_compile(model: nn.Module):
506+
def apply_compile(model: nn.Module, compile_config: CompileConfig):
506507
"""
507508
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
508509
repeated structure. Alternatively one can compile the whole model (after applying DP).
@@ -515,7 +516,11 @@ def apply_compile(model: nn.Module):
515516
fullgraph = True
516517
if transformer_block.moe_enabled:
517518
fullgraph = False
518-
transformer_block = torch.compile(transformer_block, fullgraph=fullgraph)
519+
transformer_block = torch.compile(
520+
transformer_block,
521+
backend=compile_config.backend,
522+
fullgraph=fullgraph,
523+
)
519524
model.layers.register_module(layer_id, transformer_block)
520525

521526
logger.info("Compiling each TransformerBlock with torch.compile")

torchtitan/experiments/qwen3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def parallelize_qwen3(
118118

119119
# turn on per-TransformerBlock compile after AC wrapping and before FSDP
120120
if model_compile_enabled:
121-
apply_compile(model)
121+
apply_compile(model, job_config.compile)
122122

123123
if parallel_dims.fsdp_enabled:
124124
# apply FSDP or HSDP, potentially with Context Parallel

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,6 @@ def parallelize_deepseekv3(
153153
if job_config.compile.enable:
154154
torch._inductor.config.reorder_for_peak_memory = False
155155
torch._dynamo.config.capture_scalar_outputs = True
156-
model = torch.compile(model, fullgraph=True)
156+
model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
157157

158158
return model

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,6 @@ def parallelize_llama(
122122

123123
if job_config.compile.enable and "model" in job_config.compile.components:
124124
torch._inductor.config.reorder_for_peak_memory = False
125-
model = torch.compile(model, fullgraph=True)
125+
model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
126126

127127
return model

torchtitan/experiments/vlm/infra/parallelize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def parallelize_vlm(
7070

7171
# turn on per-TransformerBlock compile after AC wrapping and before FSDP
7272
if job_config.compile.enable:
73-
apply_compile(model)
74-
apply_compile(model.encoder)
73+
apply_compile(model, job_config.compile)
74+
apply_compile(model.encoder, job_config.compile)
7575

7676
if parallel_dims.fsdp_enabled:
7777
# apply FSDP or HSDP, potentially with Context Parallel

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def parallelize_deepseekv3(
116116
)
117117

118118
if model_compile_enabled:
119-
apply_compile(model)
119+
apply_compile(model, job_config.compile)
120120

121121
dp_mesh: DeviceMesh | None = None
122122
if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424

2525
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
26+
from torchtitan.config.job_config import Compile as CompileConfig
2627
from torchtitan.distributed import ParallelDims
2728
from torchtitan.distributed.activation_checkpoint import apply_ac
2829
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
@@ -105,7 +106,7 @@ def parallelize_llama(
105106

106107
# turn on per-TransformerBlock compile after AC wrapping and before FSDP
107108
if model_compile_enabled:
108-
apply_compile(model)
109+
apply_compile(model, job_config.compile)
109110

110111
if parallel_dims.fsdp_enabled:
111112
# apply FSDP or HSDP, potentially with Context Parallel
@@ -234,13 +235,15 @@ def apply_tp(
234235
)
235236

236237

237-
def apply_compile(model: nn.Module):
238+
def apply_compile(model: nn.Module, compile_config: CompileConfig):
238239
"""
239240
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
240241
repeated structure. Alternatively one can compile the whole model (after applying DP).
241242
"""
242243
for layer_id, transformer_block in model.layers.named_children():
243-
transformer_block = torch.compile(transformer_block, fullgraph=True)
244+
transformer_block = torch.compile(
245+
transformer_block, backend=compile_config.backend, fullgraph=True
246+
)
244247
model.layers.register_module(layer_id, transformer_block)
245248

246249
logger.info("Compiling each TransformerBlock with torch.compile")

0 commit comments

Comments (0)