diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 0793820ffd..8c072914a8 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -118,7 +118,7 @@ def parallelize_deepseekv3(
         )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 9911ecdfd0..7f6eae2e7d 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -21,7 +21,10 @@
     SequenceParallel,
 )
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.config.job_config import Compile as CompileConfig
+from torchtitan.config.job_config import (
+    ActivationCheckpoint as ACConfig,
+    Compile as CompileConfig,
+)
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
@@ -129,7 +132,7 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
@@ -506,11 +509,19 @@ def apply_moe_ep_tp(
         )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ac_config: ACConfig):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation
     efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
+
+    if ac_config.mode == "selective":
+        logger.warning(
+            "Compile + Selective Activation Checkpointing is not yet supported for MoE models, "
+            "please use Full Activation Checkpointing instead. Turning off Compile."
+        )
+        return
+
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
     torch._dynamo.config.capture_scalar_outputs = True
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 6b8dc3d5a6..de22a1ce70 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -119,7 +119,7 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel