Skip to content

Commit 8799952

Browse files
authored
visualize rm prediction (huggingface#1636)
* visualize rm prediction
* quick update
* quick check
* quick fix
* update eval steps
1 parent 3b4c249 commit 8799952

File tree

3 files changed

+69
-10
lines changed

3 files changed

+69
-10
lines changed

examples/scripts/reward_modeling.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@
1515
python examples/scripts/reward_modeling.py \
1616
--model_name_or_path=facebook/opt-350m \
1717
--output_dir="reward_modeling_anthropic_hh" \
18-
--per_device_train_batch_size=64 \
18+
--per_device_train_batch_size=16 \
1919
--num_train_epochs=1 \
20-
--gradient_accumulation_steps=16 \
20+
--gradient_accumulation_steps=2 \
2121
--gradient_checkpointing=True \
2222
--learning_rate=1.41e-5 \
2323
--report_to="wandb" \
2424
--remove_unused_columns=False \
2525
--optim="adamw_torch" \
2626
--logging_steps=10 \
2727
--evaluation_strategy="steps" \
28+
--eval_steps=500 \
2829
--max_length=512 \
2930
"""
3031
import warnings
@@ -42,8 +43,8 @@
4243

4344
if __name__ == "__main__":
4445
parser = HfArgumentParser((RewardConfig, ModelConfig))
45-
reward_config, model_config = parser.parse_args_into_dataclasses()
46-
reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
46+
config, model_config = parser.parse_args_into_dataclasses()
47+
config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
4748

4849
################
4950
# Model & Tokenizer
@@ -103,8 +104,7 @@ def preprocess_function(examples):
103104
num_proc=4,
104105
)
105106
raw_datasets = raw_datasets.filter(
106-
lambda x: len(x["input_ids_chosen"]) <= reward_config.max_length
107-
and len(x["input_ids_rejected"]) <= reward_config.max_length
107+
lambda x: len(x["input_ids_chosen"]) <= config.max_length and len(x["input_ids_rejected"]) <= config.max_length
108108
)
109109
train_dataset = raw_datasets["train"]
110110
eval_dataset = raw_datasets["test"]
@@ -115,10 +115,14 @@ def preprocess_function(examples):
115115
trainer = RewardTrainer(
116116
model=model,
117117
tokenizer=tokenizer,
118-
args=reward_config,
118+
args=config,
119119
train_dataset=train_dataset,
120120
eval_dataset=eval_dataset,
121121
peft_config=get_peft_config(model_config),
122122
)
123123
trainer.train()
124-
trainer.save_model(reward_config.output_dir)
124+
trainer.save_model(config.output_dir)
125+
trainer.push_to_hub()
126+
metrics = trainer.evaluate()
127+
trainer.log_metrics("eval", metrics)
128+
print(metrics)

trl/trainer/reward_trainer.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
# limitations under the License.
1414
import inspect
1515
import warnings
16+
from collections import defaultdict
1617
from dataclasses import FrozenInstanceError, replace
1718
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1819

20+
import pandas as pd
1921
import torch
2022
import torch.nn as nn
23+
from accelerate.utils import gather_object
2124
from datasets import Dataset
2225
from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
2326
from transformers.trainer_callback import TrainerCallback
@@ -26,7 +29,7 @@
2629

2730
from ..import_utils import is_peft_available
2831
from .reward_config import RewardConfig
29-
from .utils import RewardDataCollatorWithPadding, compute_accuracy
32+
from .utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table
3033

3134

3235
if is_peft_available():
@@ -279,3 +282,39 @@ def prediction_step(
279282
labels = self._prepare_inputs(labels)
280283

281284
return loss, logits, labels
285+
286+
def evaluate(self, *args, **kwargs):
    """Print a visual sample of reward-model predictions, then run the standard evaluation loop.

    Accepts an extra keyword argument ``num_print_samples`` (default ``4``), which is
    consumed here (popped from ``kwargs``) and forwarded to ``visualize_samples``;
    all remaining arguments are passed through to ``Trainer.evaluate``.
    """
    sample_count = kwargs.pop("num_print_samples", 4)
    self.visualize_samples(sample_count)
    return super().evaluate(*args, **kwargs)
290+
291+
def visualize_samples(self, num_print_samples: int):
    """
    Visualize the reward model's logits predictions on the evaluation set.

    Decodes the chosen/rejected pairs from the eval dataloader, runs a prediction
    step per batch, and gathers text and (rounded) logits across processes into a
    table. The table is printed (and optionally logged to wandb) on the main
    process only.

    Args:
        num_print_samples (`int`, defaults to `4`):
            The number of samples to print. Set to `-1` to print all samples.
    """
    eval_dataloader = self.get_eval_dataloader()
    table = defaultdict(list)
    for inputs in eval_dataloader:
        # Only the logits are needed for visualization; loss/labels are discarded.
        _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
        chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True)
        rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True)
        # gather_object collects the per-process lists so the table is complete
        # under distributed evaluation.
        table["chosen_text"].extend(gather_object(chosen_text))
        table["rejected_text"].extend(gather_object(rejected_text))
        table["logits"].extend(
            gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
        )
        # Stop once enough samples are collected; -1 means iterate the whole set.
        if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
            break
    df = pd.DataFrame(table)
    # Fix: the previous version also called print_rich_table(pd.DataFrame(table))
    # unconditionally, which printed the FULL table once per process (ignoring both
    # the process-0 guard and num_print_samples). Print only on the main process.
    if self.accelerator.process_index == 0:
        print_rich_table(df[:num_print_samples])
        if "wandb" in self.args.report_to:
            import wandb

            if wandb.run is not None:
                wandb.log({"completions": wandb.Table(dataframe=df)})

trl/trainer/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,21 @@
1818
from typing import Any, Dict, List, Optional, Tuple, Union
1919

2020
import numpy as np
21+
import pandas as pd
2122
import torch
2223
from accelerate import PartialState
2324
from rich.console import Console, Group
2425
from rich.live import Live
2526
from rich.panel import Panel
2627
from rich.progress import Progress
28+
from rich.table import Table
2729
from torch.nn.utils.rnn import pad_sequence
2830
from torch.utils.data import IterableDataset
29-
from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
31+
from transformers import (
32+
BitsAndBytesConfig,
33+
DataCollatorForLanguageModeling,
34+
PreTrainedTokenizerBase,
35+
)
3036
from transformers.trainer import TrainerCallback
3137
from transformers.trainer_utils import has_length
3238

@@ -815,3 +821,13 @@ def on_train_end(self, args, state, control, **kwargs):
815821
self.rich_console = None
816822
self.training_status = None
817823
self.current_step = None
824+
825+
826+
def print_rich_table(df: pd.DataFrame) -> None:
    """Pretty-print a pandas DataFrame to the console as a `rich` table.

    Every cell is rendered via ``str``; one table column is created per DataFrame
    column and one row per DataFrame row.

    Args:
        df (`pd.DataFrame`):
            The frame to display.

    Fix: the return annotation previously claimed ``-> Table`` but the function
    only prints and implicitly returns ``None``; the annotation now matches.
    """
    console = Console()
    table = Table(show_lines=True)
    for column in df.columns:
        table.add_column(column)
    for _, row in df.iterrows():
        # Coerce the whole row to strings so rich can render arbitrary dtypes.
        table.add_row(*row.astype(str).tolist())
    console.print(table)

0 commit comments

Comments
 (0)