
Commit f18253b

intial RPO loss (huggingface#1686)
* intial RPO loss
* fix sign
* clean up
1 parent: 151a452

File tree: 4 files changed (+29, -12 lines)

docs/source/dpo_trainer.mdx

Lines changed: 2 additions & 0 deletions

@@ -119,6 +119,8 @@ The [NCA](https://arxiv.org/abs/2402.05369) authors shows that NCA optimizes the
 
 The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model` flag in the `DPOConfig`.
 
+The [RPO](https://arxiv.org/abs/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://arxiv.org/abs/2405.16436) that essentially consists of the SFT loss on the chosen preferences together with a weighted DPO loss. To use this loss set the `rpo_alpha` in the `DPOConfig` to an appropriate value.
+
 ## Logging
 
 While training and evaluating we record the following reward metrics:
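
As a quick usage illustration of the documentation added above, here is a minimal sketch of enabling the RPO term through `DPOConfig`. The output directory is a placeholder, the 0.5 value simply mirrors the test below, and leaving `rpo_alpha` as `None` (the default) keeps the plain DPO loss:

    from trl import DPOConfig

    training_args = DPOConfig(
        output_dir="./dpo-rpo",  # placeholder path
        beta=0.1,
        # rpo_alpha weights the DPO term; the average NLL of the chosen completion
        # is then added on top (see the dpo_trainer.py hunks below).
        rpo_alpha=0.5,
    )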

tests/test_dpo_trainer.py

Lines changed: 1 addition & 0 deletions

@@ -160,6 +160,7 @@ def test_dpo_trainer_without_providing_ref_model(self):
             eval_strategy="steps",
             beta=0.1,
             precompute_ref_log_probs=True,
+            rpo_alpha=0.5,
         )
 
         dummy_dataset = self._init_dummy_dataset()

trl/trainer/dpo_config.py

Lines changed: 3 additions & 0 deletions

@@ -71,6 +71,8 @@ class DPOConfig(TrainingArguments):
             The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
         ref_model_sync_steps ('int', defaults to 2):
             The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
+        rpo_alpha ('float', defaults to `None`):
+            The alpha parameter from the [RPO](https://arxiv.org/pdf/2404.19733) paper. If None, no weighting is applied and the loss is the same as the DPO loss.
     """
 
     beta: float = 0.1
@@ -98,3 +100,4 @@ class DPOConfig(TrainingArguments):
     sync_ref_model: bool = False
     ref_model_mixup_alpha: float = 0.9
     ref_model_sync_steps: int = 64
+    rpo_alpha: Optional[float] = None

trl/trainer/dpo_trainer.py

Lines changed: 23 additions & 12 deletions

@@ -901,13 +901,15 @@ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
                     reference_rejected_logps,
                     _,
                     _,
+                    _,
                 ) = self.concatenated_forward(self.model, padded_batch)
             else:
                 (
                     reference_chosen_logps,
                     reference_rejected_logps,
                     _,
                     _,
+                    _,
                 ) = self.concatenated_forward(self.ref_model, padded_batch)
 
         return reference_chosen_logps, reference_rejected_logps
@@ -1089,21 +1091,19 @@ def dpo_loss(
     def get_batch_logps(
         logits: torch.FloatTensor,
         labels: torch.LongTensor,
-        average_log_prob: bool = False,
         label_pad_token_id: int = -100,
         is_encoder_decoder: bool = False,
-    ) -> torch.FloatTensor:
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
         """Compute the log probabilities of the given labels under the given logits.
 
         Args:
             logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
             labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
-            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
             label_pad_token_id: The label pad token id.
             is_encoder_decoder: Whether the model is an encoder-decoder model.
 
         Returns:
-            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+            A Tuple of two tensor of shape ((batch_size,), (batch_size,)) containing the sum of log probabilities of the given labels under the given logits in the first tensor and the number of non-masked tokens in the second tensor.
         """
         if logits.shape[:-1] != labels.shape:
             raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
@@ -1118,10 +1118,7 @@ def get_batch_logps(
 
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
 
-        if average_log_prob:
-            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-        else:
-            return (per_token_logps * loss_mask).sum(-1)
+        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
 
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
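
The two hunks above change `get_batch_logps` to return the summed log probabilities together with the number of non-masked tokens, so callers derive per-token averages themselves. Below is a minimal standalone sketch of that contract; the shapes and values are toy assumptions for illustration, and the real method additionally shifts labels for decoder-only models:

    import torch

    # Toy batch: 2 sequences, 4 positions, vocabulary of 5; -100 marks padded labels.
    logits = torch.randn(2, 4, 5)
    labels = torch.tensor([[1, 2, 3, -100], [0, 4, -100, -100]])
    label_pad_token_id = -100

    loss_mask = labels != label_pad_token_id
    gather_labels = labels.clone()
    gather_labels[~loss_mask] = 0  # any valid index; these positions are masked out below

    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=gather_labels.unsqueeze(2)
    ).squeeze(2)

    sum_logps = (per_token_logps * loss_mask).sum(-1)  # first return value: summed log probs
    size_completion = loss_mask.sum(-1)                # second return value: non-masked token count

    # The trainer now computes averages downstream, e.g. for IPO or the chosen-completion NLL:
    avg_logps = sum_logps / size_completion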
@@ -1154,21 +1151,25 @@ def concatenated_forward(
             **model_kwargs,
         ).logits
 
-        all_logps = self.get_batch_logps(
+        all_logps, size_completion = self.get_batch_logps(
             all_logits,
             concatenated_batch["concatenated_labels"],
-            average_log_prob=self.loss_type == "ipo",
+            # average_log_prob=self.loss_type == "ipo",
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
+        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]
+
+        if self.loss_type == "ipo":
+            all_logps = all_logps / size_completion
 
         chosen_logps = all_logps[:len_chosen]
         rejected_logps = all_logps[len_chosen:]
 
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)
 
     def get_batch_loss_metrics(
         self,
@@ -1184,10 +1185,15 @@ def get_batch_loss_metrics(
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
+            policy_chosen_logps_avg,
         ) = self.concatenated_forward(model, batch)
 
         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
-        if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
+        if (
+            "reference_chosen_logps" in batch
+            and "reference_rejected_logps" in batch
+            and self.args.rpo_alpha is not None
+        ):
             reference_chosen_logps = batch["reference_chosen_logps"]
             reference_rejected_logps = batch["reference_rejected_logps"]
         else:
@@ -1199,13 +1205,15 @@ def get_batch_loss_metrics(
                     reference_rejected_logps,
                     _,
                     _,
+                    _,
                 ) = self.concatenated_forward(self.model, batch)
             else:
                 (
                     reference_chosen_logps,
                     reference_rejected_logps,
                     _,
                     _,
+                    _,
                 ) = self.concatenated_forward(self.ref_model, batch)
 
         losses, chosen_rewards, rejected_rewards = self.dpo_loss(
@@ -1216,6 +1224,9 @@ def get_batch_loss_metrics(
         )
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
+        if self.args.rpo_alpha is not None:
+            losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg
+
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
         metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
