diff --git a/docs/source/cpo_trainer.mdx b/docs/source/cpo_trainer.mdx
index e3a7f339cf1..08960c64058 100644
--- a/docs/source/cpo_trainer.mdx
+++ b/docs/source/cpo_trainer.mdx
@@ -5,6 +5,12 @@ avoid generating adequate, but not perfect translations in Machine Translation (
 
 CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
 
+## SimPO
+The [SimPO](https://arxiv.org/abs/2405.14734) method is also implemented in the `CPOTrainer`. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, set `loss_type="simpo"` and `cpo_alpha=0` in the `CPOConfig`.
+
+## CPO-SimPO
+We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. For more details, see the [CPO-SimPO GitHub repository](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply set `loss_type="simpo"` and a non-zero `cpo_alpha` in the `CPOConfig`.
+
 ## Expected dataset format
 
 The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
@@ -48,7 +54,6 @@ cpo_dataset_dict = {
 ```
 
 where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
-
 ## Expected model format
 
 The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
@@ -81,9 +86,6 @@ The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss
 
 The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).
 
-The [SimPO](https://arxiv.org/abs/2405.14734) is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on`loss_type="simpo"` in the `CPOConfig`.
-
-
 ## Logging
 
 While training and evaluating we record the following reward metrics:
@@ -98,7 +100,6 @@ While training and evaluating we record the following reward metrics:
 
 [[autodoc]] CPOTrainer
 
-
 ## CPOConfig
 
 [[autodoc]] CPOConfig
\ No newline at end of file
diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py
index e4736a686ad..0b6bdf84b2f 100644
--- a/tests/test_cpo_trainer.py
+++ b/tests/test_cpo_trainer.py
@@ -100,6 +100,7 @@ def test_cpo_trainer(self, name, loss_type):
                 eval_strategy="steps",
                 beta=0.1,
                 loss_type=loss_type,
+                cpo_alpha=1.0,
             )
 
             dummy_dataset = self._init_dummy_dataset()
@@ -156,6 +157,7 @@ def test_cpo_trainer_with_lora(self):
                 learning_rate=9e-1,
                 eval_strategy="steps",
                 beta=0.1,
+                cpo_alpha=1.0,
             )
 
             dummy_dataset = self._init_dummy_dataset()
diff --git a/trl/trainer/cpo_config.py b/trl/trainer/cpo_config.py
index e5a9e4b7ba1..5ba874b7a74 100644
--- a/trl/trainer/cpo_config.py
+++ b/trl/trainer/cpo_config.py
@@ -41,6 +41,8 @@ class CPOConfig(TrainingArguments):
             The type of loss to use. This argument is required if you want to use the default data collator.
         label_pad_token_id (`int`, defaults to `-100`):
             The label pad token id. This argument is required if you want to use the default data collator.
+        cpo_alpha (`float`, defaults to `1.0`):
+            A hyperparameter that controls the strength of the BC regularizer in CPO training.
         simpo_gamma (`float`, defaults to `0.5`):
             A target reward margin for the SimPO loss, used only when the "simpo" option is enabled.
         padding_value (`int`, defaults to `None`):
@@ -68,6 +70,7 @@
     label_smoothing: float = 0
     loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "sigmoid"
     disable_dropout: bool = True
+    cpo_alpha: float = 1.0
     simpo_gamma: float = 0.5
 
     label_pad_token_id: int = -100
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index e0c42f0f6c5..6e5bbd6163d 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -269,9 +269,17 @@ def make_inputs_require_grad(module, input, output):
 
         self.beta = args.beta
         self.label_smoothing = args.label_smoothing
         self.loss_type = args.loss_type
+        self.cpo_alpha = args.cpo_alpha
 
         if args.loss_type == "simpo":
             self.simpo_gamma = args.simpo_gamma
+            if self.cpo_alpha > 0:
+                warnings.warn(
+                    "You are using CPO-SimPO method because you set a non-zero cpo_alpha. "
+                    "This will result in the CPO-SimPO method "
+                    "(https://github.com/fe1ixxu/CPO_SIMPO/tree/main). "
+                    "If you want to use a pure SimPO method, please set cpo_alpha to 0."
+                )
 
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -706,10 +714,10 @@ def cross_entropy_loss(logits, labels):
 
         labels = concatenated_batch["concatenated_labels"].clone()
 
-        if self.loss_type != "simpo":
-            nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
-        else:
+        if self.cpo_alpha == 0:
             nll_loss = torch.tensor(0.0).to(self.accelerator.device)
+        else:
+            nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
         all_logps = self.get_batch_logps(
             all_logits,
@@ -749,7 +757,7 @@ def get_batch_loss_metrics(
             policy_rejected_logps,
         )
 
-        loss = losses.mean() + policy_nll_loss
+        loss = losses.mean() + self.cpo_alpha * policy_nll_loss
 
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
         prefix = "eval_" if train_eval == "eval" else ""
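
As a usage note for this change, below is a minimal sketch of how the new `loss_type="simpo"` and `cpo_alpha` options fit together. It assumes the `CPOConfig`/`CPOTrainer` API as of this patch (including the `tokenizer` keyword argument); the `gpt2` checkpoint and the inline toy dataset are illustrative stand-ins, not part of this PR.

```python
# Sketch only: exercises the options added in this patch (loss_type="simpo", cpo_alpha).
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

# Toy preference dataset in the prompt/chosen/rejected format described in the docs.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["hello", "how are you"],
        "chosen": ["hi, nice to meet you", "I am fine"],
        "rejected": ["leave me alone", "I am not fine"],
    }
)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any causal LM works here
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# Pure SimPO: SimPO loss with the BC (NLL) regularizer disabled via cpo_alpha=0.
simpo_config = CPOConfig(
    output_dir="simpo-model",
    loss_type="simpo",
    cpo_alpha=0.0,
    simpo_gamma=0.5,  # target reward margin used by the SimPO loss
    per_device_train_batch_size=2,
    max_steps=3,
    remove_unused_columns=False,  # keep the tokenized preference columns
)

# CPO-SimPO: same SimPO loss, but the NLL term is kept and scaled by cpo_alpha,
# i.e. loss = preference_loss + cpo_alpha * nll_loss, matching the trainer change above.
cpo_simpo_config = CPOConfig(
    output_dir="cpo-simpo-model",
    loss_type="simpo",
    cpo_alpha=1.0,
    per_device_train_batch_size=2,
    max_steps=3,
    remove_unused_columns=False,
)

trainer = CPOTrainer(
    model=model,
    args=cpo_simpo_config,  # swap in simpo_config for pure SimPO
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```

The non-zero `cpo_alpha` in the second config deliberately triggers the new constructor warning: the NLL term is now gated by `cpo_alpha` rather than by the loss type, so combining `loss_type="simpo"` with `cpo_alpha > 0` yields the CPO-SimPO variant rather than pure SimPO.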