11 changes: 6 additions & 5 deletions docs/source/cpo_trainer.mdx
@@ -5,6 +5,12 @@ avoid generating adequate, but not perfect translations in Machine Translation (

CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
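As a rough sketch of that objective (the notation below is introduced here for illustration and follows the CPO paper's formulation rather than this document), CPO drops the reference model from the DPO preference term and adds a behavior-cloning (BC) regularizer on the chosen responses:

$$
\mathcal{L}_{\text{CPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \Big[ \log \sigma \big( \beta \log \pi_\theta(y_w \mid x) - \beta \log \pi_\theta(y_l \mid x) \big) \Big] \; - \; \mathbb{E}_{(x, y_w) \sim \mathcal{D}} \big[ \log \pi_\theta(y_w \mid x) \big]
$$

where `y_w` and `y_l` are the chosen and rejected completions and `beta` scales the preference term. In the trainer changes later in this diff, the second (NLL) term is weighted by the new `cpo_alpha` hyperparameter.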

## SimPO
The [SimPO](https://arxiv.org/abs/2405.14734) method is also implemented in the `CPOTrainer`. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, set `loss_type="simpo"` and `cpo_alpha=0` in the `CPOConfig`.
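For example, a minimal configuration sketch for pure SimPO training (the hyperparameter values and output directory below are illustrative, not prescriptive) could look like:

```python
from trl import CPOConfig

# Pure SimPO: SimPO loss with the BC (NLL) regularizer disabled via cpo_alpha=0.
training_args = CPOConfig(
    output_dir="simpo-model",  # illustrative output path
    beta=0.1,
    loss_type="simpo",
    simpo_gamma=0.5,   # target reward margin used by the SimPO loss
    cpo_alpha=0.0,     # 0 turns off the BC regularizer, giving pure SimPO
)
```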

## CPO-SimPO
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. See the [CPO-SimPO GitHub repository](https://github.com/fe1ixxu/CPO_SIMPO) for details. To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the `CPOConfig`.
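A corresponding sketch for CPO-SimPO (values again illustrative) only changes `cpo_alpha` to a non-zero value so the BC regularizer is kept:

```python
from trl import CPOConfig

# CPO-SimPO: SimPO loss plus the BC (NLL) regularizer, weighted by a non-zero cpo_alpha.
training_args = CPOConfig(
    output_dir="cpo-simpo-model",  # illustrative output path
    beta=0.1,
    loss_type="simpo",
    simpo_gamma=0.5,
    cpo_alpha=1.0,     # non-zero weight keeps the NLL term on the chosen responses
)
```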

## Expected dataset format

The CPO trainer expects a format identical to that of the DPO trainer, which should include three entries. These entries should be named as follows:
@@ -48,7 +54,6 @@ cpo_dataset_dict = {
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses, and `rejected` contains the corresponding negative (rejected) responses. As can be seen, a prompt can have multiple responses, and this is reflected in the entries being repeated in the dictionary's value arrays.
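For reference, a dictionary in this format can be turned into a training dataset with `datasets.Dataset.from_dict`; the translation pairs below are made up purely for illustration:

```python
from datasets import Dataset

cpo_dataset_dict = {
    "prompt": [
        "Translate to German: Hello, how are you?",
        "Translate to German: Hello, how are you?",
    ],
    "chosen": [
        "Hallo, wie geht es dir?",
        "Hallo, wie geht es Ihnen?",
    ],
    "rejected": [
        "Hola, ¿cómo estás?",
        "Bonjour, comment ça va ?",
    ],
}

# One row per (prompt, chosen, rejected) triple; repeating a prompt yields multiple pairs for it.
train_dataset = Dataset.from_dict(cpo_dataset_dict)
```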


## Expected model format
The CPO trainer expects a model of type `AutoModelForCausalLM`, unlike PPO, which expects `AutoModelForCausalLMWithValueHead` for the value function.
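A minimal end-to-end sketch of wiring this together is shown below. The checkpoint name and the `tokenizer` keyword reflect the TRL API as assumed at the time of writing and are not part of this diff:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any causal LM checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

training_args = CPOConfig(output_dir="cpo-model", beta=0.1)
trainer = CPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # e.g. the dataset built in the previous section
    tokenizer=tokenizer,
)
trainer.train()
```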

@@ -81,9 +86,6 @@ The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithm, identify an issue with overfitting, and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, so the smaller the `beta` the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike CPO, where they are only summed).
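As a small illustrative sketch (output directory names are hypothetical), switching between these loss variants only changes the `loss_type` argument of the `CPOConfig`:

```python
from trl import CPOConfig

# Hinge loss variant referenced in the RSO discussion above.
hinge_args = CPOConfig(output_dir="cpo-hinge", beta=0.1, loss_type="hinge")

# IPO loss variant; per the text above, a smaller beta corresponds to a larger gap
# between the chosen and rejected log-likelihood ratios.
ipo_args = CPOConfig(output_dir="cpo-ipo", beta=0.1, loss_type="ipo")
```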

The [SimPO](https://arxiv.org/abs/2405.14734) is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on`loss_type="simpo"` in the `CPOConfig`.


## Logging

While training and evaluating we record the following reward metrics:
@@ -98,7 +100,6 @@ While training and evaluating we record the following reward metrics:

[[autodoc]] CPOTrainer


## CPOConfig

[[autodoc]] CPOConfig
2 changes: 2 additions & 0 deletions tests/test_cpo_trainer.py
@@ -100,6 +100,7 @@ def test_cpo_trainer(self, name, loss_type):
eval_strategy="steps",
beta=0.1,
loss_type=loss_type,
cpo_alpha=1.0,
)

dummy_dataset = self._init_dummy_dataset()
@@ -156,6 +157,7 @@ def test_cpo_trainer_with_lora(self):
learning_rate=9e-1,
eval_strategy="steps",
beta=0.1,
cpo_alpha=1.0,
)

dummy_dataset = self._init_dummy_dataset()
3 changes: 3 additions & 0 deletions trl/trainer/cpo_config.py
@@ -41,6 +41,8 @@ class CPOConfig(TrainingArguments):
The type of loss to use. This argument is required if you want to use the default data collator.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
cpo_alpha (`float`, defaults to `1.0`):
A hyperparameter that controls the strength of the BC regularizer in CPO training.
simpo_gamma (`float`, defaults to `0.5`):
A target reward margin for the SimPO loss, used only when the "simpo" option is enabled.
padding_value (`int`, defaults to `None`):
@@ -68,6 +70,7 @@ class CPOConfig(TrainingArguments):
label_smoothing: float = 0
loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "sigmoid"
disable_dropout: bool = True
cpo_alpha: float = 1.0
simpo_gamma: float = 0.5

label_pad_token_id: int = -100
16 changes: 12 additions & 4 deletions trl/trainer/cpo_trainer.py
@@ -269,9 +269,17 @@ def make_inputs_require_grad(module, input, output):
self.beta = args.beta
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type
self.cpo_alpha = args.cpo_alpha

if args.loss_type == "simpo":
self.simpo_gamma = args.simpo_gamma
if self.cpo_alpha > 0:
warnings.warn(
"You are using CPO-SimPO method because you set a non-zero cpo_alpha. "
"This will result in the CPO-SimPO method "
"(https://github.com/fe1ixxu/CPO_SIMPO/tree/main). "
"If you want to use a pure SimPO method, please set cpo_alpha to 0."
)

self._stored_metrics = defaultdict(lambda: defaultdict(list))

@@ -706,10 +714,10 @@ def cross_entropy_loss(logits, labels):

labels = concatenated_batch["concatenated_labels"].clone()

if self.loss_type != "simpo":
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
else:
if self.cpo_alpha == 0:
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
else:
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
@@ -749,7 +757,7 @@ def get_batch_loss_metrics(
policy_rejected_logps,
)

loss = losses.mean() + policy_nll_loss
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
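Taken together, the final hunk means the loss actually optimized is, schematically,

$$
\text{loss} = \operatorname{mean}(\text{preference losses}) + \texttt{cpo\_alpha} \cdot \mathcal{L}_{\text{NLL}}
$$

so `cpo_alpha=0` with `loss_type="simpo"` recovers pure SimPO, while a non-zero `cpo_alpha` keeps the BC regularizer (CPO, or CPO-SimPO when the SimPO loss is selected).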