Commit 5bcb8ad

RDPO fix nll loss (#1705)

1 parent b8b972f
File tree

1 file changed (+23 -5 lines)


trl/trainer/dpo_trainer.py

Lines changed: 23 additions & 5 deletions
@@ -1122,7 +1122,7 @@ def get_batch_logps(
 
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
 
         We do this to avoid doing two forward passes, because it's faster for FSDP.
@@ -1158,7 +1158,23 @@ def concatenated_forward(
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
-        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]
+
+        def cross_entropy_loss(logits, labels):
+            if not self.is_encoder_decoder:
+                # Shift so that tokens < n predict n
+                logits = logits[..., :-1, :].contiguous()
+                labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = concatenated_batch["concatenated_labels"].clone()
+        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
         if self.loss_type == "ipo":
             all_logps = all_logps / size_completion
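
Note: the new cross_entropy_loss helper computes a standard next-token NLL over the chosen half of the batch (all_logits[:len_chosen]). Below is a minimal standalone sketch of that shift-and-flatten computation, not the trainer's code: the toy shapes are assumptions, and it relies on the default ignore_index=-100 of nn.CrossEntropyLoss, which matches the trainer's default label_pad_token_id of -100.

# Standalone sketch (toy shapes): shift logits/labels so position i predicts
# token i+1, flatten, and take token-level cross entropy. Labels set to -100
# (prompt/padding) are skipped by nn.CrossEntropyLoss's default ignore_index.
import torch
import torch.nn as nn

batch_size, seq_len, vocab_size = 2, 6, 11
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, :3] = -100  # e.g. prompt tokens masked out of the NLL

# Shift so that tokens < n predict n (decoder-only case)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = nn.CrossEntropyLoss()  # mean over unmasked tokens
nll_loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(nll_loss)  # scalar tensor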
@@ -1169,7 +1185,7 @@ def concatenated_forward(
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
 
     def get_batch_loss_metrics(
         self,
@@ -1185,7 +1201,7 @@ def get_batch_loss_metrics(
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
-            policy_chosen_logps_avg,
+            policy_nll_loss,
         ) = self.concatenated_forward(model, batch)
 
         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
@@ -1225,7 +1241,7 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         if self.args.rpo_alpha is not None:
-            losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg
+            losses = losses * self.args.rpo_alpha + policy_nll_loss
 
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
@@ -1236,6 +1252,8 @@ def get_batch_loss_metrics(
         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
         metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
+        if self.args.rpo_alpha is not None:
+            metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
 
         return losses.mean(), metrics
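
For completeness, a hypothetical usage sketch showing how the new metric would surface. It assumes rpo_alpha is exposed on DPOConfig (the diff reads it from self.args.rpo_alpha); the model choice and the tiny in-memory dataset are placeholders, not part of this commit.

# Hypothetical end-to-end sketch: enabling rpo_alpha so "nll_loss" is logged.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# Toy preference data in the prompt/chosen/rejected format DPOTrainer expects.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["The capital of France is", "2 + 2 equals"],
        "chosen": [" Paris.", " 4."],
        "rejected": [" Berlin.", " 5."],
    }
)

# rpo_alpha != None turns on the added NLL term and the new nll_loss metric.
training_args = DPOConfig(output_dir="dpo-rpo-demo", rpo_alpha=1.0, per_device_train_batch_size=2)
trainer = DPOTrainer(model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset)
trainer.train()  # logged metrics now include "nll_loss" (prefixed "eval_" during evaluation)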
