@@ -1122,7 +1122,7 @@ def get_batch_logps(

     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

         We do this to avoid doing two forward passes, because it's faster for FSDP.
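The docstring above captures the trick this method relies on: chosen and rejected sequences are stacked along the batch dimension so a single forward pass produces logits for both halves of each preference pair. A minimal, self-contained sketch of that pattern (dummy tensors only; fake_logits merely stands in for the model call and is not part of the trainer):

import torch

batch_size, seq_len, vocab = 2, 5, 11
chosen_ids = torch.randint(0, vocab, (batch_size, seq_len))
rejected_ids = torch.randint(0, vocab, (batch_size, seq_len))

# One forward pass over the concatenated batch instead of two separate passes.
concatenated_ids = torch.cat([chosen_ids, rejected_ids], dim=0)  # (2 * batch_size, seq_len)
fake_logits = torch.randn(concatenated_ids.shape[0], seq_len, vocab)  # stand-in for model(...)

# Split the outputs back into the chosen and rejected halves.
len_chosen = chosen_ids.shape[0]
chosen_logits, rejected_logits = fake_logits[:len_chosen], fake_logits[len_chosen:]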
@@ -1158,7 +1158,23 @@ def concatenated_forward(
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
-        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]
+
+        def cross_entropy_loss(logits, labels):
+            if not self.is_encoder_decoder:
+                # Shift so that tokens < n predict n
+                logits = logits[..., :-1, :].contiguous()
+                labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = concatenated_batch["concatenated_labels"].clone()
+        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

         if self.loss_type == "ipo":
             all_logps = all_logps / size_completion
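The cross_entropy_loss helper added above is the standard causal-LM next-token loss, applied only to the chosen half of the concatenated batch. A standalone sketch of the shift-and-flatten computation on dummy tensors (in the trainer, padded label positions carry label_pad_token_id, which the default CrossEntropyLoss ignore_index of -100 skips; that detail is omitted here):

import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 6, 10
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Shift so that tokens < n predict n, exactly as in cross_entropy_loss above.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)

nll = nn.CrossEntropyLoss()(shift_logits, shift_labels)
print(nll)  # mean negative log-likelihood of the dummy "chosen" tokens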
@@ -1169,7 +1185,7 @@ def concatenated_forward(
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]

-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)

     def get_batch_loss_metrics(
         self,
@@ -1185,7 +1201,7 @@ def get_batch_loss_metrics(
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
-            policy_chosen_logps_avg,
+            policy_nll_loss,
         ) = self.concatenated_forward(model, batch)

         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
@@ -1225,7 +1241,7 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         if self.args.rpo_alpha is not None:
-            losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg
+            losses = losses * self.args.rpo_alpha + policy_nll_loss

         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
@@ -1236,6 +1252,8 @@ def get_batch_loss_metrics(
         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
         metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
+        if self.args.rpo_alpha is not None:
+            metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()

         return losses.mean(), metrics

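Taken together, the change replaces the old average-log-prob term with a true NLL loss on the chosen completions, so with rpo_alpha set the per-example objective becomes rpo_alpha * dpo_loss + nll_loss. Because nll_loss is a negative log-likelihood, adding it keeps the chosen completions likely under the policy while the preference term separates chosen from rejected. A minimal numeric sketch of that combination, assuming the default sigmoid DPO loss; everything below is dummy data, not trainer code:

import torch
import torch.nn.functional as F

beta, rpo_alpha = 0.1, 1.0

# Dummy per-example log-probabilities for the policy and reference models.
policy_chosen_logps = torch.tensor([-12.0, -9.5])
policy_rejected_logps = torch.tensor([-14.0, -13.0])
reference_chosen_logps = torch.tensor([-12.5, -10.0])
reference_rejected_logps = torch.tensor([-13.5, -12.0])

# Standard sigmoid DPO loss on the implicit reward margins.
margins = (policy_chosen_logps - reference_chosen_logps) - (
    policy_rejected_logps - reference_rejected_logps
)
dpo_losses = -F.logsigmoid(beta * margins)

# Dummy value standing in for policy_nll_loss (NLL of the chosen completions).
policy_nll_loss = torch.tensor(2.3)

# As in the diff: scale the preference loss by rpo_alpha and add the NLL term.
losses = dpo_losses * rpo_alpha + policy_nll_loss
print(losses.mean())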