diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index 6fb80f39cb..b8b52993e1 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -78,3 +78,21 @@ def build_mse_loss(job_config: JobConfig, **kwargs): logger.info("Compiling the loss function with torch.compile") loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend) return loss_fn + + +def moe_loss( + pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor, + labels: torch.Tensor, + loss_fn: LossFunction, +) -> torch.Tensor: + """Sequence-wise auxiliary load balance loss function for MoE + model training. + """ + if isinstance(pred, tuple): + pred, load_balance_loss = pred + loss = loss_fn(pred, labels) + # USE STE to make the magnitude of loss remain the same + loss = loss + (load_balance_loss - load_balance_loss.detach()) + else: + loss = loss_fn(pred, labels) + return loss diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 95588d2c3b..c4d40108c8 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -97,6 +97,18 @@ class Metrics: """Whether to log metrics to Weights & Biases""" +@dataclass +class ExtraLosses: + load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise" + """Type of load balance loss to use""" + + load_balance_loss_weight: float = 0 + """Weight of load balance loss""" + + load_balance_coeff: float | None = 1e-3 + """Coefficient of bias update for aux-loss-free load balancing""" + + @dataclass class Model: name: str = "llama3" @@ -130,6 +142,9 @@ class Model: converters have been applied. """ + extra_losses: ExtraLosses = field(default_factory=ExtraLosses) + """Extra losses to use""" + @dataclass class Optimizer: diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 48d4b5ece1..fab9862b91 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -95,6 +95,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len + losses_config = job_config.model.extra_losses + self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type + self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight + self.moe_args.load_balance_coeff = losses_config.load_balance_coeff + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( "Failed to use grouped mm, which is only supported on SM90 or later", diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 3cf56eb1b2..25b168806a 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -309,6 +309,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + accumulated_load_balance_loss: torch.Tensor, attention_masks: AttentionMasksType | None, ): """ @@ -323,10 +324,15 @@ def forward( """ x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) if self.moe_enabled: - x = x + self.moe(self.ffn_norm(x)) + ffn_moe_output, load_balance_loss = self.moe(self.ffn_norm(x)) + accumulated_load_balance_loss = ( + accumulated_load_balance_loss + load_balance_loss + ) else: - x = x + self.feed_forward(self.ffn_norm(x)) - return x + ffn_moe_output = self.feed_forward(self.ffn_norm(x)) + + x = x + ffn_moe_output + return x, accumulated_load_balance_loss def init_weights(self, buffer_device: torch.device): for norm in (self.attention_norm, self.ffn_norm): @@ -410,6 +416,7 @@ def get_attention_masks( def forward( self, tokens: torch.Tensor, + accumulated_load_balance_loss: torch.Tensor | None = None, attention_masks: AttentionMasksType | None = None, ): """ @@ -427,8 +434,16 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + accumulated_load_balance_loss = ( + torch.zeros((), device=h.device, dtype=torch.float32) + if accumulated_load_balance_loss is None + else accumulated_load_balance_loss + ) + for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks) + h, accumulated_load_balance_loss = layer( + h, self.freqs_cis, accumulated_load_balance_loss, attention_masks + ) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h - return output + return output, accumulated_load_balance_loss diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 295e2193a5..471fc7c076 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -30,7 +30,8 @@ class MoEArgs: top_k: int = 1 use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation load_balance_coeff: float | None = 1e-3 - + load_balance_loss_weight: float = 0 + load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise" _debug_force_load_balance: bool = False # if True, we force each experts get same amount of token via round-robin @@ -287,7 +288,7 @@ def forward( max=self.num_experts, ) - return top_scores, selected_experts_indices, num_tokens_per_expert + return top_scores, scores, selected_experts_indices, num_tokens_per_expert def init_weights(self, init_std: float): nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) @@ -359,6 +360,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): super().__init__() num_experts = moe_args.num_experts + self.topk = moe_args.top_k + self.num_experts = num_experts self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, @@ -386,6 +389,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): # NOTE: tokens_per_expert is accumulated in the model forward pass. # expert_bias is updated outside the model in an optimizer step pre hook # to work with gradient accumulation. + self.load_balance_loss_weight = moe_args.load_balance_loss_weight + self.load_balance_loss_type = moe_args.load_balance_loss_type self.load_balance_coeff = moe_args.load_balance_coeff if self.load_balance_coeff is not None: assert self.load_balance_coeff > 0.0 @@ -418,6 +423,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # num_tokens_per_expert shape (num_experts,) ( top_scores, + scores, selected_experts_indices, num_tokens_per_expert, ) = self.router(x, self.expert_bias) @@ -430,6 +436,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) + if self.training: + if self.load_balance_loss_type == "sequence_wise": + load_balance_loss = MoE.sequence_wise_aux_loss( + scores, + selected_experts_indices.long(), + bs, + slen, + self.topk, + self.load_balance_loss_weight, + ) + elif self.load_balance_loss_type == "batch_wise": + load_balance_loss = MoE.batch_wise_aux_loss( + scores, + num_tokens_per_expert, + self.topk, + self.load_balance_loss_weight, + ) + else: + load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype) + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) # NOTE: the reason we need to compute num_tokens_per_expert again is: @@ -479,7 +505,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim=0, index=token_indices_experts_sorted, src=routed_output ) out = out.reshape(bs, slen, dim) - return out + + return out, load_balance_loss def init_weights( self, @@ -499,3 +526,94 @@ def init_weights( self.expert_bias = torch.zeros( self.experts.num_experts, dtype=torch.float32 ) + + @staticmethod + @torch.compile(fullgraph=True) + def sequence_wise_aux_loss( + scores: torch.Tensor, # Shape: (B*S, N) - Raw Sigmoid Affinities (s_{i,t}) + indices: torch.Tensor, # Shape: (B*S, K) - Selected Expert Indices + B: int, # Batch size + S: int, # Sequence length (T in the paper) + top_k: int, # K_r + aux_loss_alpha: float, # Alpha + ) -> torch.Tensor: + """ + Computes Sequence-Wise Auxiliary Loss (DeepSeek-V3 Equations 17-20). + + Args: + scores: The dense affinity scores (s_{i,t}) for routed experts. + Should be the output of Sigmoid, shape (B*S, N). + indices: The top-k selected expert indices. Shape (B*S, K). + """ + if aux_loss_alpha <= 0: + return torch.tensor(0.0, device=scores.device, dtype=scores.dtype) + + # N_r: Total number of routed experts + N = scores.size(-1) + + # 1. Reshape inputs to handle each sequence separately: (B, S, N) + # This ensures we calculate P_i and f_i per sequence (Eq 20 & 18). + scores_per_seq = scores.view(B, S, N) + indices_per_seq = indices.view(B, S, top_k) + + # 2. Eq 19: Normalize affinity scores s_{i,t} to get s'_{i,t} + # DeepSeek-V3 uses Sigmoid, so scores don't sum to 1. + # Eq 19 explicitly requires dividing by the sum of all affinities. + # denominator shape: (B, S, 1) + denominator = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20 + probs_per_seq = scores_per_seq / denominator # This is s'_{i,t} + + # 3. Eq 20: Calculate P_i (Average probability per expert for each sequence) + # P_i = (1/T) * sum_{t=1}^T (s'_{i,t}) + # We average over the Sequence dimension (dim=1). + # P_i shape: (B, N) + P_i = probs_per_seq.mean(dim=1) + + # 4. Eq 18: Calculate f_i (Fraction of tokens selecting expert i per sequence) + # f_i = (N / (K * T)) * count_i + + # Flatten the top-k dimension to count hits per sequence: (B, S*K) + flat_indices_per_seq = indices_per_seq.view(B, -1) + selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype) + src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype) + selection_counts.scatter_add_(1, flat_indices_per_seq, src) + + # Calculate f_i for each sequence, T (tokens in sequence) is S + f_i = selection_counts * (N / (top_k * S)) + + # 5. Eq 17: Calculate Balance Loss + loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_alpha + + return loss_per_seq.mean() + + @staticmethod + @torch.compile(fullgraph=True) + def batch_wise_aux_loss( + scores: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + top_k: int, + aux_loss_alpha: float, + ) -> torch.Tensor: + """ + Computes Batch-Wise Auxiliary Loss. + Args: + scores: Dense probabilities (BS, N). + num_tokens_per_expert: Token counts (N). + top_k: Number of experts selected per token. + aux_loss_alpha: Scaling factor for the loss. + """ + if aux_loss_alpha <= 0: + return torch.tensor(0.0, device=scores.device, dtype=scores.dtype) + + # Total number of routed experts (N) + N = scores.size(1) + # Total number of tokens (T = BS * S) + T = scores.size(0) + + P_i = scores.mean(dim=0) + + f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T)) + + loss = (f_i * P_i).sum() * aux_loss_alpha + + return loss diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..bab206cb00 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools import importlib import os import time @@ -18,7 +19,7 @@ from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training -from torchtitan.components.loss import rescale_accumulated_loss +from torchtitan.components.loss import moe_loss, rescale_accumulated_loss from torchtitan.components.metrics import ( build_metrics_processor, ensure_pp_loss_visible, @@ -184,6 +185,11 @@ def __init__(self, job_config: JobConfig): job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager ) + self.loss_fn = functools.partial( + moe_loss, + loss_fn=self.loss_fn, + ) + # verify batch sizes global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: