18 changes: 18 additions & 0 deletions torchtitan/components/loss.py
@@ -78,3 +78,21 @@ def build_mse_loss(job_config: JobConfig, **kwargs):
logger.info("Compiling the loss function with torch.compile")
loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
return loss_fn


def moe_loss(
pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
labels: torch.Tensor,
loss_fn: LossFunction,
) -> torch.Tensor:
"""Sequence-wise auxiliary load balance loss function for MoE
model training.
"""
if isinstance(pred, tuple):
pred, load_balance_loss = pred
loss = loss_fn(pred, labels)
# Straight-through-style trick: the returned value stays equal to the main
# loss, but gradients still flow through the load-balance loss.
loss = loss + (load_balance_loss - load_balance_loss.detach())
else:
loss = loss_fn(pred, labels)
return loss
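
For intuition, here is a minimal pure-PyTorch sketch (illustration only, not part of the PR; all names are toy stand-ins) of why the detach trick keeps the returned value equal to the main loss while gradients still include the auxiliary term:

import torch

# Toy example: main and auxiliary losses on a single parameter.
w = torch.tensor(2.0, requires_grad=True)
main_loss = 3.0 * w          # d(main)/dw = 3
aux_loss = 5.0 * w           # d(aux)/dw  = 5

combined = main_loss + (aux_loss - aux_loss.detach())
print(combined.item())       # 6.0 -- identical to main_loss
combined.backward()
print(w.grad.item())         # 8.0 -- gradient includes the aux-loss contribution
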
15 changes: 15 additions & 0 deletions torchtitan/config/job_config.py
@@ -97,6 +97,18 @@ class Metrics:
"""Whether to log metrics to Weights & Biases"""


@dataclass
class ExtraLosses:
load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
"""Type of load balance loss to use"""

load_balance_loss_weight: float = 0
"""Weight of load balance loss"""

load_balance_coeff: float | None = 1e-3
"""Coefficient of bias update for aux-loss-free load balancing"""


@dataclass
class Model:
name: str = "llama3"
@@ -130,6 +142,9 @@ class Model:
converters have been applied.
"""

extra_losses: ExtraLosses = field(default_factory=ExtraLosses)
"""Extra losses to use"""


@dataclass
class Optimizer:
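
As a quick illustration (a sketch only, assuming torchtitan with this change applied is importable), the new fields nest under Model and can be overridden like any other dataclass field; the values below are hypothetical:

from torchtitan.config.job_config import ExtraLosses, Model

model_cfg = Model(
    name="deepseek_v3",
    extra_losses=ExtraLosses(
        load_balance_loss_type="sequence_wise",
        load_balance_loss_weight=1e-2,
        load_balance_coeff=1e-3,
    ),
)
print(model_cfg.extra_losses)
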
5 changes: 5 additions & 0 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -95,6 +95,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)
self.max_seq_len = seq_len

losses_config = job_config.model.extra_losses
self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
self.moe_args.load_balance_coeff = losses_config.load_balance_coeff

if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
logger.warning(
"Failed to use grouped mm, which is only supported on SM90 or later",
25 changes: 20 additions & 5 deletions torchtitan/models/deepseek_v3/model/model.py
@@ -309,6 +309,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
accumulated_load_balance_loss: torch.Tensor,
attention_masks: AttentionMasksType | None,
):
"""
@@ -323,10 +324,15 @@
"""
x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
if self.moe_enabled:
x = x + self.moe(self.ffn_norm(x))
ffn_moe_output, load_balance_loss = self.moe(self.ffn_norm(x))
accumulated_load_balance_loss = (
accumulated_load_balance_loss + load_balance_loss
)
else:
x = x + self.feed_forward(self.ffn_norm(x))
return x
ffn_moe_output = self.feed_forward(self.ffn_norm(x))

x = x + ffn_moe_output
return x, accumulated_load_balance_loss

def init_weights(self, buffer_device: torch.device):
for norm in (self.attention_norm, self.ffn_norm):
@@ -410,6 +416,7 @@ def get_attention_masks(
def forward(
self,
tokens: torch.Tensor,
accumulated_load_balance_loss: torch.Tensor | None = None,
attention_masks: AttentionMasksType | None = None,
):
"""
@@ -427,8 +434,16 @@

h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

accumulated_load_balance_loss = (
torch.zeros((), device=h.device, dtype=torch.float32)
if accumulated_load_balance_loss is None
else accumulated_load_balance_loss
)

for layer in self.layers.values():
h = layer(h, self.freqs_cis, attention_masks)
h, accumulated_load_balance_loss = layer(
h, self.freqs_cis, accumulated_load_balance_loss, attention_masks
)
h = self.norm(h) if self.norm is not None else h
output = self.output(h) if self.output is not None else h
return output
return output, accumulated_load_balance_loss
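
The signature change means every block now passes a running scalar along with the hidden states, and the model returns it next to the output. A self-contained toy of the same accumulation pattern (illustrative stand-ins only, not the real DeepSeekV3 modules):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Stand-in for a transformer block that also emits an auxiliary scalar."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, acc_aux: torch.Tensor):
        aux = x.float().pow(2).mean() * 1e-3  # placeholder for a per-layer load-balance loss
        return x + self.lin(x), acc_aux + aux

blocks = nn.ModuleList([ToyBlock() for _ in range(3)])
h = torch.randn(2, 4, 8)
acc = torch.zeros((), dtype=torch.float32)  # same initialization as in the model forward
for blk in blocks:
    h, acc = blk(h, acc)
print(h.shape, acc.item())  # final hidden states plus the summed auxiliary loss
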
124 changes: 121 additions & 3 deletions torchtitan/models/moe/moe.py
@@ -30,7 +30,8 @@ class MoEArgs:
top_k: int = 1
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3

load_balance_loss_weight: float = 0
load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
_debug_force_load_balance: bool = False
# if True, force each expert to receive the same number of tokens via round-robin

@@ -287,7 +288,7 @@ def forward(
max=self.num_experts,
)

return top_scores, selected_experts_indices, num_tokens_per_expert
return top_scores, scores, selected_experts_indices, num_tokens_per_expert

def init_weights(self, init_std: float):
nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
@@ -359,6 +360,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
super().__init__()

num_experts = moe_args.num_experts
self.topk = moe_args.top_k
self.num_experts = num_experts
self.experts = GroupedExperts(
dim=dim,
hidden_dim=hidden_dim,
@@ -386,6 +389,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
# NOTE: tokens_per_expert is accumulated in the model forward pass.
# expert_bias is updated outside the model in an optimizer step pre hook
# to work with gradient accumulation.
self.load_balance_loss_weight = moe_args.load_balance_loss_weight
self.load_balance_loss_type = moe_args.load_balance_loss_type
self.load_balance_coeff = moe_args.load_balance_coeff
if self.load_balance_coeff is not None:
assert self.load_balance_coeff > 0.0
@@ -418,6 +423,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# num_tokens_per_expert shape (num_experts,)
(
top_scores,
scores,
selected_experts_indices,
num_tokens_per_expert,
) = self.router(x, self.expert_bias)
@@ -430,6 +436,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
self.tokens_per_expert.add_(num_tokens_per_expert)

if self.training:
if self.load_balance_loss_type == "sequence_wise":
load_balance_loss = MoE.sequence_wise_aux_loss(
scores,
selected_experts_indices.long(),
bs,
slen,
self.topk,
self.load_balance_loss_weight,
)
elif self.load_balance_loss_type == "batch_wise":
load_balance_loss = MoE.batch_wise_aux_loss(
scores,
num_tokens_per_expert,
self.topk,
self.load_balance_loss_weight,
)
else:
# `out` is not defined yet at this point in forward, so use the input `x`
load_balance_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

# top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
# num_tokens_per_expert shape (num_experts,)
# NOTE: the reason we need to compute num_tokens_per_expert again is:
Expand Down Expand Up @@ -479,7 +505,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dim=0, index=token_indices_experts_sorted, src=routed_output
)
out = out.reshape(bs, slen, dim)
return out

return out, load_balance_loss

def init_weights(
self,
@@ -499,3 +526,94 @@ def init_weights(
self.expert_bias = torch.zeros(
self.experts.num_experts, dtype=torch.float32
)

@staticmethod
@torch.compile(fullgraph=True)
def sequence_wise_aux_loss(
scores: torch.Tensor, # Shape: (B*S, N) - Raw Sigmoid Affinities (s_{i,t})
indices: torch.Tensor, # Shape: (B*S, K) - Selected Expert Indices
B: int, # Batch size
S: int, # Sequence length (T in the paper)
top_k: int, # K_r
aux_loss_alpha: float, # Alpha
) -> torch.Tensor:
"""
Computes Sequence-Wise Auxiliary Loss (DeepSeek-V3 Equations 17-20).

Args:
scores: The dense affinity scores (s_{i,t}) for routed experts.
Should be the output of Sigmoid, shape (B*S, N).
indices: The top-k selected expert indices. Shape (B*S, K).
"""
if aux_loss_alpha <= 0:
return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)

# N_r: Total number of routed experts
N = scores.size(-1)

# 1. Reshape inputs to handle each sequence separately: (B, S, N)
# This ensures we calculate P_i and f_i per sequence (Eq 20 & 18).
scores_per_seq = scores.view(B, S, N)
indices_per_seq = indices.view(B, S, top_k)

# 2. Eq 19: Normalize affinity scores s_{i,t} to get s'_{i,t}
# DeepSeek-V3 uses Sigmoid, so scores don't sum to 1.
# Eq 19 explicitly requires dividing by the sum of all affinities.
# denominator shape: (B, S, 1)
denominator = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20
probs_per_seq = scores_per_seq / denominator # This is s'_{i,t}

# 3. Eq 20: Calculate P_i (Average probability per expert for each sequence)
# P_i = (1/T) * sum_{t=1}^T (s'_{i,t})
# We average over the Sequence dimension (dim=1).
# P_i shape: (B, N)
P_i = probs_per_seq.mean(dim=1)

# 4. Eq 18: Calculate f_i (Fraction of tokens selecting expert i per sequence)
# f_i = (N / (K * T)) * count_i

# Flatten the top-k dimension to count hits per sequence: (B, S*K)
flat_indices_per_seq = indices_per_seq.view(B, -1)
selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype)
src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype)
selection_counts.scatter_add_(1, flat_indices_per_seq, src)

# Calculate f_i for each sequence, T (tokens in sequence) is S
f_i = selection_counts * (N / (top_k * S))

# 5. Eq 17: Calculate Balance Loss
loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_alpha

return loss_per_seq.mean()

@staticmethod
@torch.compile(fullgraph=True)
def batch_wise_aux_loss(
scores: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
top_k: int,
aux_loss_alpha: float,
) -> torch.Tensor:
"""
Computes Batch-Wise Auxiliary Loss.
Args:
scores: Dense probabilities (BS, N).
num_tokens_per_expert: Token counts (N).
top_k: Number of experts selected per token.
aux_loss_alpha: Scaling factor for the loss.
"""
if aux_loss_alpha <= 0:
return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)

# Total number of routed experts (N)
N = scores.size(1)
# Total number of tokens across the whole batch (T = B * S)
T = scores.size(0)

P_i = scores.mean(dim=0)

f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T))

loss = (f_i * P_i).sum() * aux_loss_alpha

return loss
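
For reference, the sequence-wise helper above implements the DeepSeek-V3 balance loss, Equations (17)-(20) of the paper, with \alpha = load_balance_loss_weight, N_r routed experts, K_r = top_k, and T tokens per sequence:

\mathcal{L}_{\mathrm{Bal}} = \alpha \sum_{i=1}^{N_r} f_i P_i \qquad (17)
f_i = \frac{N_r}{K_r T} \sum_{t=1}^{T} \mathbb{1}\!\left( s_{i,t} \in \operatorname{Topk}\big(\{ s_{j,t} \mid 1 \le j \le N_r \},\, K_r\big) \right) \qquad (18)
s'_{i,t} = \frac{s_{i,t}}{\sum_{j=1}^{N_r} s_{j,t}} \qquad (19)
P_i = \frac{1}{T} \sum_{t=1}^{T} s'_{i,t} \qquad (20)

As a quick smoke test of the two static helpers (a sketch only; it assumes torchtitan with this PR applied is importable, and the shapes are toy values):

import torch
from torchtitan.models.moe.moe import MoE

B, S, N, K = 2, 4, 8, 2                    # batch, sequence length, experts, top-k
scores = torch.rand(B * S, N)              # stand-in sigmoid affinities s_{i,t}
indices = scores.topk(K, dim=-1).indices   # (B*S, K) selected expert ids
num_tokens_per_expert = torch.bincount(indices.flatten(), minlength=N)

seq_loss = MoE.sequence_wise_aux_loss(scores, indices.long(), B, S, K, 1e-2)
batch_loss = MoE.batch_wise_aux_loss(scores, num_tokens_per_expert, K, 1e-2)
print(seq_loss.item(), batch_loss.item())
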
8 changes: 7 additions & 1 deletion torchtitan/train.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import importlib
import os
import time
@@ -18,7 +19,7 @@
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.dataloader import DataloaderExhaustedError
from torchtitan.components.ft import FTManager, maybe_semi_sync_training
from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.components.loss import moe_loss, rescale_accumulated_loss
from torchtitan.components.metrics import (
build_metrics_processor,
ensure_pp_loss_visible,
@@ -184,6 +185,11 @@ def __init__(self, job_config: JobConfig):
job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
)

self.loss_fn = functools.partial(
Review comment (Contributor Author): we can add a condition here to wrap the loss or not for MoE. For now, all models in torchtitan only return a single output, so it's OK for now.
moe_loss,
loss_fn=self.loss_fn,
)

# verify batch sizes
global_batch_size = job_config.training.global_batch_size
if global_batch_size < 0:
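
Relating to the review comment above: because moe_loss falls back to the plain path when pred is a single tensor, wrapping the trainer's loss unconditionally does not change behavior for non-MoE models today. A toy check of both paths (a sketch only; it assumes torchtitan with this PR applied is importable):

import functools
import torch
import torch.nn.functional as F
from torchtitan.components.loss import moe_loss

wrapped = functools.partial(moe_loss, loss_fn=F.mse_loss)

pred = torch.randn(2, 4, requires_grad=True)
labels = torch.randn(2, 4)
aux = torch.tensor(0.05, requires_grad=True)

print(wrapped(pred, labels).item())          # plain-tensor path (non-MoE models)
print(wrapped((pred, aux), labels).item())   # tuple path: same value, aux still gets gradients
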