|
13 | 13 | # limitations under the License. |
14 | 14 | import inspect |
15 | 15 | import warnings |
| 16 | +from collections import defaultdict |
16 | 17 | from dataclasses import FrozenInstanceError, replace |
17 | 18 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
18 | 19 |
|
| 20 | +import pandas as pd |
19 | 21 | import torch |
20 | 22 | import torch.nn as nn |
| 23 | +from accelerate.utils import gather_object |
21 | 24 | from datasets import Dataset |
22 | 25 | from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments |
23 | 26 | from transformers.trainer_callback import TrainerCallback |
|
26 | 29 |
|
27 | 30 | from ..import_utils import is_peft_available |
28 | 31 | from .reward_config import RewardConfig |
29 | | -from .utils import RewardDataCollatorWithPadding, compute_accuracy |
| 32 | +from .utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table |
30 | 33 |
|
31 | 34 |
|
32 | 35 | if is_peft_available(): |
@@ -279,3 +282,39 @@ def prediction_step( |
279 | 282 | labels = self._prepare_inputs(labels) |
280 | 283 |
|
281 | 284 | return loss, logits, labels |
| 285 | + |
| 286 | + def evaluate(self, *args, **kwargs): |
| 287 | + num_print_samples = kwargs.pop("num_print_samples", 4) |
| 288 | + self.visualize_samples(num_print_samples) |
| 289 | + return super().evaluate(*args, **kwargs) |
| 290 | + |
| 291 | + def visualize_samples(self, num_print_samples: int): |
| 292 | + """ |
| 293 | + Visualize the reward model logits prediction |
| 294 | +
|
| 295 | + Args: |
| 296 | + num_print_samples (`int`, defaults to `4`): |
| 297 | + The number of samples to print. Set to `-1` to print all samples. |
| 298 | + """ |
| 299 | + eval_dataloader = self.get_eval_dataloader() |
| 300 | + table = defaultdict(list) |
| 301 | + for _, inputs in enumerate(eval_dataloader): |
| 302 | + _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) |
| 303 | + chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True) |
| 304 | + rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True) |
| 305 | + table["chosen_text"].extend(gather_object(chosen_text)) |
| 306 | + table["rejected_text"].extend(gather_object(rejected_text)) |
| 307 | + table["logits"].extend( |
| 308 | + gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]) |
| 309 | + ) |
| 310 | + if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples: |
| 311 | + break |
| 312 | + df = pd.DataFrame(table) |
| 313 | + print_rich_table(pd.DataFrame(table)) |
| 314 | + if self.accelerator.process_index == 0: |
| 315 | + print_rich_table(df[:num_print_samples]) |
| 316 | + if "wandb" in self.args.report_to: |
| 317 | + import wandb |
| 318 | + |
| 319 | + if wandb.run is not None: |
| 320 | + wandb.log({"completions": wandb.Table(dataframe=df)}) |
0 commit comments