Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def _step_safety_checker(
scores (List[`torch.FloatTensor`]):
List of tensors containing the scores.
masks (List[`torch.LongTensor`], *optional*):
list of optional tensors containing the masks of shape (`query_length` + `response_length`)
list of optional tensors containing the masks of shape (`response_length`)
Returns:
`tuple`: The input processed data.
"""
Expand Down Expand Up @@ -1033,15 +1033,11 @@ def batched_forward_pass(
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch[j] = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]

masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j]

if return_logits:
all_logits.append(logits)
Expand Down