Warn that SAC + Compile for MoE models is not yet supported #2052
base: main
Conversation
stack-info: PR: #2052, branch: xmfan/stack/4
    if ac_config.mode == "selective":
        logger.warning(
            "Selective Activation Checkpointing is not yet supported for MoE models, "
This is a little confusing; SAC does work with eager for MoE models.
wwwjn left a comment:
LGTM! Thanks for making this!
Sorry, I didn't follow -- what's the issue with combining compile + SAC + MoE?
CheckpointWrapper is being applied to all submodules in SAC, but only at the block-level for Full AC. That breaks the logic of apply_compile ever since #1895.
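To make the wrapping-granularity point above concrete, here is a minimal plain-Python sketch (no torch dependency; `TransformerBlock`, `MoE`, `CheckpointWrapper`, and the `find_moe_submodules` pass are simplified stand-ins for the real torchtitan/PyTorch classes, not their actual implementations). A compile pass that locates the MoE submodule by type keeps working under block-level Full AC, but fails once SAC has hidden each submodule inside a wrapper:

```python
class Module:
    def __init__(self):
        self.children = {}

class TransformerBlock(Module):
    pass

class MoE(Module):
    pass

class CheckpointWrapper(Module):
    """Wraps a module; the wrapped module's original type is hidden."""
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

def full_ac(block):
    # Full AC: wrap the whole TransformerBlock once, at block level.
    # Submodules keep their original types.
    return CheckpointWrapper(block)

def selective_ac(block):
    # SAC: wrap each submodule of the block separately.
    for name, child in list(block.children.items()):
        block.children[name] = CheckpointWrapper(child)
    return block

def find_moe_submodules(block):
    # A compile pass in the spirit of apply_compile after #1895: locate
    # the MoE submodule by type so only it gets compiled.
    return [name for name, child in block.children.items()
            if isinstance(child, MoE)]

def make_block():
    block = TransformerBlock()
    block.children["moe"] = MoE()
    return block

print(find_moe_submodules(full_ac(make_block()).inner))  # ['moe']
print(find_moe_submodules(selective_ac(make_block())))   # []
```

Under Full AC the wrapper sits outside the block, so the type-based lookup still finds the MoE submodule; under SAC the lookup comes up empty.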
What's the problem with Full AC at the block level? Is it because we have full AC (compile)?
Also could you help make a central list on the composability issues among AC, compile, MoE?
I realized that pytorch/pytorch#167844 fixes SAC around the torch.compile region.
| "Compile + Selective Activation Checkpointing is not yet supported for MoE models, " | ||
| "please use Full Activation Checkpointing instead. Turning off Compile." | ||
| ) | ||
| return |
can we just error out?
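The suggestion above (failing fast instead of warning and silently turning off compile) could look roughly like the following sketch. The function name and boolean parameters are illustrative, not the actual torchtitan config API; only the warning message text is taken from the diff:

```python
def check_ac_compile_compat(is_moe: bool, compile_enabled: bool, ac_mode: str):
    # Hypothetical validation helper: reject the unsupported combination
    # outright instead of downgrading behavior with a warning.
    if is_moe and compile_enabled and ac_mode == "selective":
        raise ValueError(
            "Compile + Selective Activation Checkpointing is not yet supported "
            "for MoE models, please use Full Activation Checkpointing instead."
        )

# Supported combination passes silently; the unsupported one raises.
check_ac_compile_compat(is_moe=True, compile_enabled=True, ac_mode="full")
try:
    check_ac_compile_compat(is_moe=True, compile_enabled=True, ac_mode="selective")
except ValueError as e:
    print(f"rejected: {e}")
```

Erroring out makes the incompatibility visible at config-validation time rather than changing runtime behavior behind the user's back.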
SAC will wrap each submodule of TransformerBlock separately (_apply_op_sac_to_transformer_block_with_flex), which makes each submodule of TransformerBlock an instance of CheckpointWrapper. So #1895 only works with Full AC, not SAC: AC(compile(moe)) works, but SAC(compile(moe)) doesn't.
So everything should be fixed now; we just need to remove the hack in _apply_op_sac_to_transformer_block_with_flex and test.
So there are two cases here, depending on whether you care that compiling makes your graph opaque. The fix there primarily addresses one of the cases.
To check my understanding: if only FlexAttn is compiled (not each transformer layer or submodule of a transformer layer), SAC works. If we compile each transformer layer, do you mean we can only save / recompute all the ops within the transformer layer, and cannot specify which ops to save in the SAC region? Is this full AC behavior, or do you mean something else? It seems I was aware of this behavior before.
@wwwjn @tianyu-l yeah, I think your understanding is correct: either save all activations needed for backward that are computed within the compiled region, or recompute all ops, just like full AC.
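A toy illustration of the point above (plain Python, op names illustrative): an SAC policy decides per op whether to save or recompute, but a compiled region surfaces as a single opaque higher-order op ("HOP"), so the policy can only make one all-or-nothing decision for everything inside it, which is exactly full-AC behavior for that region:

```python
SAVE, RECOMPUTE = "save", "recompute"

def sac_policy(op_name):
    # Example SAC policy: save expensive matmuls, recompute cheap
    # pointwise ops (a common heuristic, shown here in miniature).
    return SAVE if op_name == "matmul" else RECOMPUTE

# Eager: the policy sees every op and can pick per-op.
eager_ops = ["matmul", "relu", "matmul", "add"]
print([sac_policy(op) for op in eager_ops])

# Compiled: the same four ops collapse into one opaque node, so the
# policy makes a single decision covering the whole region.
compiled_ops = ["inductor_compiled_region"]  # stand-in for the HOP
print([sac_policy(op) for op in compiled_ops])
```

This is why the existing policy function has to be taught about the inductor HOP before SAC composes usefully with compiled regions.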
Yes, but the existing policy needs to be updated to handle the inductor HOP.
Stacked PRs:
Warn that SAC + Compile for MoE models is not yet supported. Behavior should be identical for MoE blocks; dense blocks are no longer compiled.
This also fixes another issue: CheckpointWrapper is being applied to all submodules in SAC, but only at the block level for Full AC. That breaks the logic of apply_compile ever since #1895.