Description
In short: does PP (pipeline parallelism) support stages with multiple input arguments and multiple outputs?
——
Hey, we’ve been stuck for a while on how to properly integrate an auxiliary (aux) loss for MoE training with PP and torch.compile(fullgraph=True).
For context, both DeepSeek V3 and GLM 4.5 mention that
“We also applied an auxiliary sequence-level balance loss with a 0.0001 weight to avoid extreme imbalance within any single sequence.”
(We could open a PR for the sequence-level balance loss if you’re interested.)
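For reference, a minimal sketch of a sequence-level balance loss, assuming the DeepSeek-V3 formulation: f_i is expert i's share of the top-K routing slots in a sequence, scaled by N/(K·T); P_i is its mean normalized affinity over that sequence; the loss is alpha · Σ_i f_i P_i. The function name `sequence_balance_loss`, the [B, T, N] layout and averaging over sequences are our assumptions, not an existing API.

```python
import torch


def sequence_balance_loss(
    scores: torch.Tensor, top_k: int, alpha: float = 1e-4
) -> torch.Tensor:
    """Hypothetical sequence-level balance loss (DeepSeek-V3-style sketch).

    scores: [B, T, N] non-negative router affinities (e.g. sigmoid outputs)
            for B sequences of T tokens over N routed experts.
    top_k:  number of experts selected per token (K).
    alpha:  loss weight (0.0001 in the quoted reports).
    """
    _, T, n_experts = scores.shape
    # Experts selected for each token: [B, T, K].
    topk_idx = scores.topk(top_k, dim=-1).indices
    # 0/1 selection mask built from indices, so it carries no gradient.
    selected = torch.zeros_like(scores).scatter_(-1, topk_idx, 1.0)
    # f_i: per-sequence routing share, scaled so a uniform load gives f_i = 1.
    f = selected.sum(dim=1) * n_experts / (top_k * T)  # [B, N]
    # P_i: mean normalized affinity per expert within each sequence.
    p = (scores / scores.sum(dim=-1, keepdim=True)).mean(dim=1)  # [B, N]
    # Sum over experts per sequence, then average over sequences (our choice).
    return alpha * (f * p).sum(dim=-1).mean()
```

With perfectly uniform routing the loss equals alpha; if every token in a sequence goes to a single expert (top_k=1), it grows to alpha · N. Gradients flow only through P_i, since f_i comes from discrete top-K selections.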
To make this work, we need to compute the extra loss in each block, either by:
- Caching the per-layer aux_loss (which breaks compile, but not PP), or
- Passing both the activations and aux_loss to the next PP stage (which doesn’t affect compile).
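A minimal sketch of the second option, assuming each block takes and returns an (activations, aux_loss) pair so that the model chunk owned by a stage maps a tuple to a tuple. `MoEBlockWithAux`, `StageChunk` and the placeholder aux term are hypothetical. The two chunks are simply chained on one device rather than wrapped in torch.distributed.pipelining, so this only shows the data flow a multi-input/multi-output stage would need, not a working PP setup.

```python
import torch
from torch import nn


class MoEBlockWithAux(nn.Module):
    """Hypothetical stand-in for an MoE block that threads aux_loss through."""

    def __init__(self, dim: int, n_experts: int):
        super().__init__()
        self.router = nn.Linear(dim, n_experts, bias=False)
        self.ffn = nn.Linear(dim, dim)

    def forward(self, h: torch.Tensor, aux_loss: torch.Tensor):
        scores = torch.sigmoid(self.router(h))  # [B, T, N]
        # Placeholder per-sequence aux term (variance of mean expert load);
        # a real model would use e.g. a sequence-level balance loss here.
        block_aux = scores.mean(dim=1).var(dim=-1)  # [B]
        h = h + self.ffn(h)
        return h, aux_loss + block_aux


class StageChunk(nn.Module):
    """Hypothetical model chunk owned by one PP stage: (h, aux) -> (h, aux)."""

    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, h: torch.Tensor, aux_loss: torch.Tensor):
        for block in self.blocks:
            h, aux_loss = block(h, aux_loss)
        return h, aux_loss


# Simulate two pipeline stages on one device: stage 0's output tuple
# becomes stage 1's positional arguments.
torch.manual_seed(0)
dim, n_experts, batch, seq = 16, 4, 2, 8
stage0 = StageChunk([MoEBlockWithAux(dim, n_experts) for _ in range(2)])
stage1 = StageChunk([MoEBlockWithAux(dim, n_experts) for _ in range(2)])

h0 = torch.randn(batch, seq, dim)
aux0 = torch.zeros(batch)  # per-sequence accumulator, shape [B]
h, aux = stage1(*stage0(h0, aux0))

# The last stage adds the accumulated aux loss to the (dummy) task loss.
loss = h.pow(2).mean() + 1e-4 * aux.mean()
loss.backward()
```

aux_loss is kept as a per-sequence tensor of shape [B] rather than a scalar, on the untested assumption that a tensor with a leading batch dimension is easier for a pipeline schedule to split into microbatches and send between stages. Nothing is stored on the modules between calls, so this avoids the caching pattern that the first option relies on.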
The second option essentially requires the PP API to support stages with multiple input arguments and multiple outputs. Earlier this year we tried explicitly passing these arguments when building the PP stages, but it didn’t work. I’m wondering if there have been any updates since then, or if we might have missed something.
Do you have any other suggestions or better solutions? @tianyu-l @H-Huang
CC: @janEbert @garrett361