question of PP x aux_loss for MoE #1979

@rakkit

Description

In short: does the PP API support multiple input arguments and multiple output arguments per stage?

——

Hey, we’ve been stuck for a while on how to properly integrate the aux loss for MoE training with PP and torch.compile(fullgraph=True).

For context, both DeepSeek V3 and GLM 4.5 mention that

“We also applied an auxiliary sequence-level balance loss with a 0.0001 weight to avoid extreme imbalance within any single sequence.”

(We could open a PR for the sequence-level balance loss if you’re interested.)
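For reference, the sequence-level variant described in the DeepSeek-V3 paper penalizes α Σᵢ fᵢ Pᵢ per sequence, where fᵢ is the (scaled) fraction of the sequence’s tokens routed to expert i and Pᵢ is the mean router probability for expert i. A minimal sketch, assuming per-sequence router outputs (all names here are ours for illustration, not an existing torchtitan API):

```python
import torch

def seq_balance_loss(router_probs: torch.Tensor,
                     topk_idx: torch.Tensor,
                     n_experts: int,
                     alpha: float = 1e-4) -> torch.Tensor:
    """Sequence-level load-balance loss, following the formulation in the
    DeepSeek-V3 paper. Names and signature are hypothetical.

    router_probs: [T, n_experts] normalized router scores for one sequence
    topk_idx:     [T, k] expert indices each token was routed to
    """
    T, k = topk_idx.shape
    # f_i: fraction of this sequence's token-to-expert assignments that hit
    # expert i, scaled so perfectly uniform routing gives f_i == 1.
    counts = torch.zeros(n_experts, device=router_probs.device)
    counts.scatter_add_(0, topk_idx.reshape(-1),
                        torch.ones(T * k, device=router_probs.device))
    f = counts * n_experts / (k * T)
    # P_i: mean router probability assigned to expert i over the sequence.
    p = router_probs.mean(dim=0)
    return alpha * torch.sum(f * p)
```

For a batch, this would be computed per sequence and averaged, with α = 0.0001 as in the quote above.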

To make this work, we need to compute the extra loss at each block, either by:

  • Caching the per-layer aux_loss on the side (which breaks compile, but not PP), or

  • Passing both activations and aux_loss to the next PP stage (which doesn’t affect compile); see the sketch after this list.
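For the second option, the idea is to keep each block’s forward purely functional in the aux loss so it can be threaded across stage boundaries instead of cached as a side effect. A minimal sketch, assuming a block that returns its own layer-level aux loss (`BlockWithAuxLoss` is a hypothetical wrapper, not a torchtitan module):

```python
import torch
import torch.nn as nn

class BlockWithAuxLoss(nn.Module):
    """Threads a running aux loss through the forward signature."""
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block  # assumed to return (hidden, layer_aux_loss)

    def forward(self, hidden: torch.Tensor, aux_loss: torch.Tensor):
        hidden, layer_aux = self.block(hidden)
        # Accumulate instead of caching: no side effects for
        # torch.compile(fullgraph=True) to choke on, and the scalar
        # simply rides along to the next PP stage.
        return hidden, aux_loss + layer_aux
```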

The second option essentially requires the PP API to support multiple input and output arguments per stage. We tried earlier this year to pass extra arguments explicitly when building PP stages, but it didn’t work. Has anything changed since then, or did we miss something?
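Concretely, what we’d like to write with torch.distributed.pipelining looks roughly like the following; whether a non-first stage can receive and re-emit the extra scalar is exactly the part that failed for us (`stage_mod`, `rank`, `pp_size`, `n_mb`, `device`, and `tokens` are placeholders):

```python
import torch
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

# Assumed: stage_mod is one slice of the model whose forward is
#   forward(hidden, aux_loss) -> (hidden, aux_loss)
# as in the block sketch above.
stage = PipelineStage(stage_mod, stage_index=rank,
                      num_stages=pp_size, device=device)
schedule = ScheduleGPipe(stage, n_microbatches=n_mb)

if rank == 0:
    # First stage: feed the tokens plus a zero-initialized aux loss.
    aux0 = torch.zeros((), device=device)
    schedule.step(tokens, aux0)
else:
    # Desired: the last stage returns both the final activations and
    # the accumulated aux loss, so the total loss can include it.
    out = schedule.step()
```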

Do you have any other suggestions or better solutions? @tianyu-l @H-Huang
CC: @janEbert @garrett361
