Commit 6db4332

[Trainer] use output.loss when using liger-kernel (#42444)
* use output.loss when using liger: handle loss computation for models using Liger-kernel (fixes #42414)
* Clarify Liger-kernel loss computation in comments
* Both standard transformers and Liger models handle shift_labels correctly via **kwargs
* removed unused shift_labels reference in loss computation
* Remove unused model unwrapping
1 parent c95d4af commit 6db4332
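The core of the change is the contract the commit message describes: when `shift_labels` is already present in `inputs`, the model's forward receives it via `**kwargs`, hands it to its loss function, and the returned `loss` is therefore already computed against the pre-shifted labels, whether or not the model is Liger-patched. A minimal sketch of that contract, using a hypothetical `ToyCausalLM` (illustrative only, not transformers code) to stand in for a real model:

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class ToyOutput:
    loss: Optional[torch.Tensor]
    logits: torch.Tensor


class ToyCausalLM(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)
        self.vocab_size = vocab_size

    def loss_function(self, logits, labels=None, shift_labels=None, vocab_size=None, **kwargs):
        # If pre-shifted labels were supplied (as DeepSpeed SP does), use them as-is;
        # otherwise derive them from `labels` the way a causal LM normally would.
        if shift_labels is None:
            shift_labels = labels[:, 1:].contiguous()
            logits = logits[:, :-1, :].contiguous()
        return nn.functional.cross_entropy(
            logits.reshape(-1, vocab_size or self.vocab_size),
            shift_labels.reshape(-1),
            ignore_index=-100,
        )

    def forward(self, input_ids, labels=None, **kwargs):
        logits = self.lm_head(self.embed(input_ids))
        loss = None
        if labels is not None or "shift_labels" in kwargs:
            # shift_labels travels through **kwargs straight into the loss function,
            # so callers only need to read `output.loss`.
            loss = self.loss_function(logits=logits, labels=labels, **kwargs)
        return ToyOutput(loss=loss, logits=logits)


if __name__ == "__main__":
    model = ToyCausalLM()
    input_ids = torch.randint(0, 32, (2, 8))
    shift_labels = torch.randint(0, 32, (2, 8))
    out = model(input_ids=input_ids, shift_labels=shift_labels)
    print(out.loss)  # already computed against the pre-shifted labels; no manual recomputation needed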

File tree

1 file changed (+8, -11 lines)

src/transformers/trainer.py

Lines changed: 8 additions & 11 deletions
@@ -3909,7 +3909,7 @@ def compute_loss(
 
     def _deepspeed_sp_compute_loss(self, model, inputs, return_outputs, pc):
         """
-        How the loss is computed by Trainer under sequence parallelism with sp_backend=="deepspeed" and sp_size>1.
+        How the loss is computed by the Trainer under sequence parallelism with sp_backend=="deepspeed" and sp_size>1.
         Performs weighted loss aggregation across SP ranks, accounting for varying numbers of valid tokens per rank
         (e.g., when some ranks receive only padding or prompt tokens that are masked with -100).
 
@@ -3927,23 +3927,20 @@ def _deepspeed_sp_compute_loss(self, model, inputs, return_outputs, pc):
             The loss of the model along with its output if return_outputs was set to True
         """
 
-        unwrapped_model = self.accelerator.unwrap_model(model)
-
+        # DeepSpeed SP automatically injects shift_labels into inputs (pre-shifted labels for SP).
+        # The model's forward pass receives shift_labels via **kwargs and passes it to the loss function.
+        # Both standard transformer models and Liger-patched models handle shift_labels correctly,
+        # so we can directly use the computed loss from the model output.
+        # See: https://huggingface.co/docs/accelerate/en/concept_guides/sequence_parallelism
         outputs = model(**inputs)
-        shift_labels = inputs["shift_labels"]
-        loss = unwrapped_model.loss_function(
-            logits=outputs.logits,
-            labels=None,
-            shift_labels=shift_labels,
-            vocab_size=unwrapped_model.config.vocab_size,
-        )
+        loss = outputs.loss
 
         sp_group = self.accelerator.torch_device_mesh["sp"].get_group()
         sp_world_size = pc.sp_size
         # differentiable weighted per-shard-loss aggregation across ranks
         losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
         # special dealing with SFT that has prompt tokens that aren't used in loss computation
-        good_tokens = (shift_labels != -100).view(-1).sum()
+        good_tokens = (inputs["shift_labels"] != -100).view(-1).sum()
         good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
         # Skip ranks with zero valid tokens
         total_loss = sum(
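Downstream of the changed lines, the surrounding aggregation (unchanged by this commit) weights each rank's loss by its count of valid tokens and skips ranks that saw only masked tokens. A standalone sketch of that arithmetic, with the all_gather results replaced by made-up local lists and with the final normalization assumed, since the diff is truncated before the closing lines:

import torch

# Per-rank mean losses and counts of tokens that actually contribute to the loss
# (shift_labels != -100); in this made-up example rank 2 received only prompt/padding tokens.
losses_per_rank = [torch.tensor(2.0), torch.tensor(1.0), torch.tensor(0.0)]
good_tokens_per_rank = [torch.tensor(100), torch.tensor(50), torch.tensor(0)]

# Weight each rank's loss by its valid-token count and skip ranks with zero valid tokens,
# so empty shards contribute neither loss nor weight.
total_loss = sum(loss * n for loss, n in zip(losses_per_rank, good_tokens_per_rank) if n > 0)
total_tokens = sum(n for n in good_tokens_per_rank if n > 0)
global_loss = total_loss / total_tokens  # (2.0 * 100 + 1.0 * 50) / 150 = 1.6667
print(global_loss)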
