Skip to content

Conversation

@xmfan
Copy link
Member

@xmfan xmfan commented Nov 18, 2025

Stacked PRs:


Warn that SAC + Compile for MoE models is not yet supported. Behavior should be identical for moe blocks, dense blocks are no longer compiled.

image

This also fixes another issue: CheckpointWrapper is being applied to all submodules in SAC, but only at the block-level for Full AC. That breaks the logic of apply_compile ever since #1895.

xmfan added a commit that referenced this pull request Nov 18, 2025
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Nov 18, 2025

if ac_config.mode == "selective":
logger.warning(
"Selective Activation Checkpointing is not yet supported for MoE models, "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a little bit confusing, SAC works with eager for MoE models

Copy link
Contributor

@wwwjn wwwjn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for making this!

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, didn't follow -- what's the issue between compile + SAC + MoE?

CheckpointWrapper is being applied to all submodules in SAC, but only at the block-level for Full AC. That breaks the logic of apply_compile ever since #1895.

What's the problem with full AC at block level? is it because we have full AC (compile)?

Also could you help make a central list on the composability issues among AC, compile, MoE?
I realized that

pytorch/pytorch#167844 fixes SAC around torch.compile region

"Compile + Selective Activation Checkpointing is not yet supported for MoE models, "
"please use Full Activation Checkpointing instead. Turning off Compile."
)
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just error out?

@wwwjn
Copy link
Contributor

wwwjn commented Nov 18, 2025

what's the issue between compile + SAC + MoE?

SAC will wrap each submodule of TransformerBlock separately (_apply_op_sac_to_transformer_block_with_flex), which will make each submodule of TransformerBlock an instance of CheckpointWrapper.

This will make the isinstance() check fail and fall back to else branch, causing a compile error.

So #1895 only works with Full AC, not SAC. AC(compile(moe)) works, but SAC(compile(moe)) doesn't work.

@tianyu-l
Copy link
Contributor

@wwwjn
According to @ezyang

pytorch/pytorch#167844 fixes SAC around torch.compile region

So everything should be fixed now, we just need to remove the hack in _apply_op_sac_to_transformer_block_with_flex and test

@soulitzer
Copy link
Contributor

fixes SAC around torch.compile region

So there are two cases here, depending on whether you care that compiling makes your graph opaque. The fix there primarily addresses one of the cases.
If you're only compiling a single op like FlexAttention, it is fine to not be able to see into the graph.
But for larger graphs, SAC(compile(fn will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.

@wwwjn
Copy link
Contributor

wwwjn commented Nov 18, 2025

To check my understanding:

If you're only compiling a single op like FlexAttention, it is fine to not be able to see into the graph.

So if only FlexAttn is compiled (not each transformer layers / or submodule of transformer layers), SAC works.

But for larger graphs, SAC(compile(fn will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.

Say if we compile each transformer layers, do you mean we can only save / recompute all the ops within the transformer layer, can not specify which ops to save in SAC region?

@tianyu-l
Copy link
Contributor

@soulitzer

But for larger graphs, SAC(compile(fn will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.

Is this full AC behavior? Or do you mean something else? Seems I was aware of this behavior before.

@soulitzer
Copy link
Contributor

@wwwjn @tianyu-l yeah I think your understanding is correct - either save all activations need for backward computed within the compiled region or recompute all ops, just like full AC.

So if only FlexAttn is compiled (not each transformer layers / or submodule of transformer layers), SAC works.

Yes, but existing policy needs to be updated to handle the inductor HOP.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants