
Commit 13454d2

PPO / Reinforce Trainers (huggingface#1540)
* Add ppov2 trainer
* make eos trick optional, remove unused args
* quick fix
* precommit
* update debugging script
* fix out of bound `drop_last=True`; use built-in scheduler
* Add PPO examples
* push changes
* quick change
* quick change
* various bug fixes
* remove unnecessary grad accumulation setting
* push new changes
* fix DS3 model saving
* update ppo.py
* refactor
* quick change
* refactor
* update ppo trainer
* refactor
* quick test
* add ds2 /ds3 7 processes config
* add vllm trainer
* quick change
* experiment with reward normalization
* push changes
* quick push
* push changes
* push various changes
* refactor to use ModelConfig
* quick change
* refactor
* refactor
* Simplify DS logic
* quick update
* remove unnecessary files
* precommit
* deepspeed fix; handle edge case when eos_token_id = 0
* add PPO tldr example
* add TL;DR example
* fix undefined var
* utilize all samples in rloo
* quick setting
* remove the unnecessary `value_model`
* use exact_div
* allow saving the deepspeed model
* refactor
* remove dead code
* Use some shared utilities
* add some end-to-end test cases
* add PPOv2 docs and RLOO docs / tests
* update docs
* quikc push
* fix ci
* fix type annotation for ci
* quick update
* update trainer docs
1 parent 99f2c94 commit 13454d2

23 files changed, +3114 -11 lines changed

docs/source/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,10 @@
     title: Supervised Fine-Tuning
   - local: ppo_trainer
     title: PPO Trainer
+  - local: ppov2_trainer
+    title: PPOv2 Trainer
+  - local: rloo_trainer
+    title: RLOO Trainer
   - local: best_of_n
     title: Best of N Sampling
   - local: dpo_trainer

docs/source/ppov2_trainer.md

Lines changed: 257 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/rloo_trainer.md

Lines changed: 301 additions & 0 deletions
Large diffs are not rendered by default.

examples/accelerate_configs/deepspeed_zero2.yaml

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@ compute_environment: LOCAL_MACHINE
 debug: false
 deepspeed_config:
   deepspeed_multinode_launcher: standard
-  gradient_accumulation_steps: 1
   offload_optimizer_device: none
   offload_param_device: none
   zero3_init_flag: false

examples/accelerate_configs/deepspeed_zero3.yaml

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@ compute_environment: LOCAL_MACHINE
 debug: false
 deepspeed_config:
   deepspeed_multinode_launcher: standard
-  gradient_accumulation_steps: 1
   offload_optimizer_device: none
   offload_param_device: none
   zero3_init_flag: true
@@ -12,7 +11,7 @@ distributed_type: DEEPSPEED
 downcast_bf16: 'no'
 machine_rank: 0
 main_training_function: main
-mixed_precision: 'bf16'
+mixed_precision: bf16
 num_machines: 1
 num_processes: 8
 rdzv_backend: static
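
Both accelerate config diffs above drop the hard-coded `gradient_accumulation_steps: 1`, matching the "remove unnecessary grad accumulation setting" entry in the commit message. As a rough sketch of the intended effect (illustrative, not part of this commit; the argument values are hypothetical), the accumulation count is then driven by the training arguments passed in the script rather than being pinned by the YAML:

# Illustrative sketch, not from this commit: with gradient_accumulation_steps
# removed from the accelerate/DeepSpeed YAML, the value set on the training
# arguments is the one that takes effect.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                # hypothetical output directory
    per_device_train_batch_size=8,   # hypothetical per-device batch size
    gradient_accumulation_steps=4,   # no longer overridden by the YAML config
)
print(training_args.gradient_accumulation_steps)  # -> 4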

examples/datasets/tldr_preference.py

Lines changed: 76 additions & 3 deletions
@@ -27,6 +27,9 @@ class ScriptArguments:
     hf_repo_id: Optional[str] = field(
         default="tldr-preference-trl-style", metadata={"help": "The Hugging Face repository ID"}
     )
+    sft_hf_repo_id: Optional[str] = field(
+        default="tldr-preference-sft-trl-style", metadata={"help": "The Hugging Face repository ID"}
+    )
     revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
     update_main_revision: Optional[bool] = field(
         default=True, metadata={"help": "Update the main revision of the repository"}
@@ -39,7 +42,11 @@ class ScriptArguments:
     if args.hf_entity is None:
         args.hf_entity = api.whoami()["name"]
     full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
+    full_sft_repo_id = f"{args.hf_entity}/{args.sft_hf_repo_id}"
 
+    ################
+    # Preference dataset
+    ################
     ds = load_dataset("openai/summarize_from_feedback", "comparisons")
     if args.debug:
         for key in ds:
@@ -92,22 +99,88 @@ def process(row):
                 repo_type="dataset",
             )
 
-        sft_card = RepoCard.load(
+        preference_card = RepoCard.load(
             full_repo_id,
             repo_type="dataset",
         )
-        sft_card.text = f"""\
+        preference_card.text = f"""\
 # TRL's TL;DR Preference Dataset
 
 We preprocess the dataset using our standard `prompt, chosen, rejected` format.
 
+## Source of the dataset
+
+We take the dataset from https://huggingface.co/datasets/openai/summarize_from_feedback.
 
 ## Reproduce this dataset
 
 1. Download the `{file_name}` from the {repo_full_url}.
 2. Run `{run_command}`
 """
-        sft_card.push_to_hub(
+        preference_card.push_to_hub(
             full_repo_id,
             repo_type="dataset",
         )
+
+    ################
+    # SFT dataset
+    ################
+    sft_ds = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered")
+    if args.debug:
+        for key in sft_ds:
+            sft_ds[key] = sft_ds[key].select(range(50))
+
+    def sft_process(row):
+        row["prompt"] = tldr_format_str.format(**row)
+        row["messages"] = [
+            {"role": "user", "content": row["prompt"]},
+            {"role": "assistant", "content": row["summary"]},
+        ]
+        return row
+
+    sft_ds = sft_ds.map(
+        sft_process,
+        num_proc=1 if args.debug else multiprocessing.cpu_count(),
+        load_from_cache_file=False,
+    )
+    for key in sft_ds:  # reorder columns
+        sft_ds[key] = sft_ds[key].select_columns(["prompt", "messages", "id", "subreddit", "title", "post", "summary"])
+    if args.push_to_hub:
+        revisions = ["main"] if args.update_main_revision else []
+        revisions.append(args.revision)
+
+        # get the commnad used to run the script
+        run_command = " ".join(["python"] + sys.argv)
+
+        for revision in revisions:
+            sft_ds.push_to_hub(full_sft_repo_id, revision=revision)
+            repo_full_url = f"https://huggingface.co/datasets/{full_sft_repo_id}/tree/{revision}"
+
+            # get the name of the current file
+            file_name = __file__.split("/")[-1]
+            api.upload_file(
+                path_or_fileobj=__file__,
+                path_in_repo=file_name,
+                revision=revision,
+                repo_id=full_sft_repo_id,
+                repo_type="dataset",
+            )
+
+        sft_card = RepoCard.load(
+            full_sft_repo_id,
+            repo_type="dataset",
+        )
+        sft_card.text = f"""\
+# TRL's TL;DR SFT Dataset
+
+We preprocess the dataset using our standard `prompt, messages` format.
+
+## Source of the dataset
+
+We take the dataset from https://huggingface.co/datasets/vwxyzjn/summarize_from_feedback_tldr_3_filtered.
+
+## Reproduce this dataset
+
+1. Download the `{file_name}` from the {repo_full_url}.
+2. Run `{run_command}`
+"""
examples/scripts/evals/generate_tldr.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+import shlex
+import subprocess
+import sys
+from collections import defaultdict
+from dataclasses import dataclass
+
+import pandas as pd
+from datasets import load_dataset
+from gpt_tldr_judge import LLMJudgeConfig, llm_judge
+from transformers import AutoTokenizer, HfArgumentParser
+from vllm import SamplingParams, SingleGPULLM
+
+
+"""
+python -i examples/scripts/evals/generate_tldr.py \
+    --model_name_or_path vwxyzjn/rloo_tldr \
+    --output_path examples/scripts/minimal/evals/rloo_tldr.csv \
+    --n 1000
+python -i examples/scripts/evals/generate_tldr.py \
+    --model_name_or_path vwxyzjn/ppo_tldr \
+    --output_path examples/scripts/minimal/evals/ppo_tldr.csv \
+    --n 1000
+"""
+
+
+@dataclass
+class Args:
+    output_path: str
+    model_name_or_path: str
+    model_revision: str = "main"
+    n: int = 1000
+
+
+def run_command(command: str):
+    command_list = shlex.split(command)
+    print(f"running {command}")
+    subprocess.run(command_list, stderr=sys.stderr, stdout=sys.stdout)
+
+
+MAX_TOKENS = 200  # a very generous max token length
+parser = HfArgumentParser(Args)
+args = parser.parse_args_into_dataclasses()[0]
+tokenizer = AutoTokenizer.from_pretrained(
+    args.model_name_or_path,
+    revision=args.model_revision,
+)
+raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
+prompts = raw_datasets["test"]["prompt"]
+if args.n is not None:
+    prompts = prompts[: args.n]
+reference_summaries = [message[-1]["content"] for message in raw_datasets["test"]["messages"]]
+sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=MAX_TOKENS)
+llm = SingleGPULLM(
+    model=args.model_name_or_path,
+    revision=args.model_revision,
+    tensor_parallel_size=1,
+    device="cuda:0",
+)
+outputs = llm.generate(prompts, sampling_params)
+table = defaultdict(list)
+
+# Print the outputs.
+for output, reference_response in zip(outputs, reference_summaries):
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    table["prompt"].append(prompt)
+    table["model_response"].append(generated_text.strip())  # need `strip()` because of the leading space
+    table["model_response_len"].append(len(output.outputs[0].token_ids))
+    table["reference_response"].append(reference_response)
+    table["reference_response_len"].append(
+        len(tokenizer(f" {reference_response}")["input_ids"])
+    )  # prepend leading space
+
+df = pd.DataFrame(table)
+df.to_csv(args.output_path)
+
+#####
+# GPT as a judge
+####
+df["response0"] = df["model_response"]
+df["response1"] = df["reference_response"]
+judged_df = llm_judge(
+    LLMJudgeConfig(
+        n=args.n,
+        model="gpt-3.5-turbo-0125",
+    ),
+    df,
+)
+judged_df.to_csv(args.output_path.replace(".csv", "_judged.csv"))
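
Note that this evaluation script imports `SingleGPULLM` from vLLM, which is not a class in standard vLLM releases and appears to require a patched build. A rough equivalent under a stock vLLM install (an assumption, not what the commit ships) would use the regular `LLM` class, which exposes the same `generate(prompts, sampling_params)` interface:

# Sketch assuming a stock vLLM install; `SingleGPULLM` above likely comes from
# a patched vLLM, while the standard `LLM` class offers the same generate API.
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=200)
llm = LLM(
    model="vwxyzjn/ppo_tldr",  # one of the checkpoints evaluated above
    revision="main",
    tensor_parallel_size=1,
)
outputs = llm.generate(["SUBREDDIT: r/test\nTITLE: ...\nPOST: ...\nTL;DR:"], sampling_params)
print(outputs[0].outputs[0].text)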
examples/scripts/evals/gpt_tldr_judge.py (file path inferred from the `gpt_tldr_judge` import in the script above)

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+# you can download the CSV from https://wandb.ai/costa-huang/tldr_summarize/runs/gb2dian5
+
+import asyncio
+import random
+import time
+from dataclasses import dataclass
+from typing import Optional
+
+import pandas as pd
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+from transformers import HfArgumentParser
+
+
+@dataclass
+class LLMJudgeConfig:
+    n: int = 64
+    model: str = "gpt-3.5-turbo-0125"
+    max_parallel_requests: Optional[int] = None
+
+    def __post_init__(self):
+        if "gpt-3.5" in self.model:
+            # gpt-3.5 generates so fast that it will exceeds the
+            # token limit per minute
+            self.max_parallel_requests = 11
+        elif "gpt-4" in self.model:
+            self.max_parallel_requests = 13
+
+
+@dataclass
+class Args:
+    csv: str = "trained_response.csv"
+    output_path: Optional[str] = None
+    num_trails: int = 1
+
+
+TEMPLATE = r"""
+Which of the following summaries does a better job of summarizing the most important points in the given forum post, without including unimportant or irrelevant details? Judge based on accuracy, coverage, and coherence.
+
+### Post:
+{{post}}
+
+### Summary A:
+{{response0}}
+
+### Summary B:
+{{response1}}
+
+### Instructions:
+FIRST provide a one-sentence comparison of the two summaries, explaining which \
+you prefer and why. SECOND, on a new line, state only "A" or "B" to indicate your choice. Your response should use the format:
+Comparison: <one-sentence comparison and explanation>
+Preferred: <"A" or "B">
+"""
+
+
+def llm_judge(ljc: LLMJudgeConfig, df: pd.DataFrame):
+    limiter = asyncio.Semaphore(ljc.max_parallel_requests)
+    async_client = AsyncOpenAI()
+
+    async def process_text(post: str, response0: str, response1: str, i: int):
+        text = TEMPLATE.replace("{{post}}", post)
+        text = text.replace("{{response0}}", response0)
+        text = text.replace("{{response1}}", response1)  # Ensure this split logic is correct for your data
+
+        async with limiter:
+            response = None
+            while response is None:
+                try:
+                    response = await async_client.chat.completions.create(
+                        model=ljc.model,
+                        messages=[
+                            {"role": "system", "content": "You are a helpful assistant."},
+                            {"role": "user", "content": text},
+                        ],
+                    )
+                    r = response.choices[0].message.content
+                except Exception as e:
+                    print(f"error in {i}: {e}")
+                    time.sleep(30)  # deal with rate limit
+                    continue
+
+            try:
+                comparison = r.split("Comparison:")[1].split("Preferred:")[0].strip()
+                preferred = r.split("Preferred:")[1].strip()
+                return comparison, preferred, i, text + r
+            except Exception as e:
+                print(f"error in {i} {e}")
+                return "", random.choice(["A", "B"]), i, text + r
+
+    async def main(ljc: LLMJudgeConfig, df: pd.DataFrame):
+        """`df` should have columns: `prompt`, `response0`, `response1`"""
+        tasks = []
+        df["explanation"] = [None for _ in range(len(df))]
+        df["preferred"] = [None for _ in range(len(df))]
+        df["shuffled_index"] = [None for _ in range(len(df))]
+        df["entire_conversation"] = [None for _ in range(len(df))]
+        r = range(min(ljc.n, len(df)))
+        if ljc.n == -1:
+            r = range(len(df))
+        for i in r:
+            post = df["prompt"].iloc[i].strip()
+            # shuffled the index to avoid GPT4's preference bias in the content's order
+            shuffled_index = random.randint(0, 1)
+            df.at[i, "shuffled_index"] = shuffled_index
+            responses = [
+                df["response0"].iloc[i].strip(),
+                df["response1"].iloc[i].strip(),
+            ]
+            response0 = responses[shuffled_index]
+            response1 = responses[1 - shuffled_index]
+            task = asyncio.create_task(process_text(post, response0, response1, i))
+            tasks.append(task)
+
+        results = await tqdm_asyncio.gather(*tasks)
+
+        for _, (comparison, preferred, i, entire_conversation) in enumerate(results):
+            df.at[i, "explanation"] = comparison
+            df.at[i, "entire_conversation"] = entire_conversation
+            preferred_label = (
+                "response0"
+                if (df.at[i, "shuffled_index"] == 0 and preferred == "A")
+                or (df.at[i, "shuffled_index"] == 1 and preferred == "B")
+                else "response1"
+            )
+            df.at[i, "preferred"] = preferred_label
+        print(df["preferred"].value_counts())
+        return df
+
+    return asyncio.run(main(ljc, df))
+
+
+if __name__ == "__main__":
+    args, ljc = HfArgumentParser((Args, LLMJudgeConfig)).parse_args_into_dataclasses()
+    df = pd.read_csv(args.csv)
+    df["reference_response"] = df["reference_response"].map(lambda x: x.split("<|endoftext|>")[0].strip())
+    df["prompt"] = df["query"].map(lambda x: x.strip())
+    df["response0"] = df["model_response"].map(lambda x: x.strip())
+    df["response1"] = df["reference_response"].map(lambda x: x.strip())
+    judge_df = llm_judge(ljc, df)
+    judge_df.to_csv(args.output_path)
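
For reference, `llm_judge` can also be called directly from Python instead of through the `__main__` entry point above. A minimal usage sketch, assuming `OPENAI_API_KEY` is set in the environment; the post and summary strings below are placeholders, not data from this commit:

# Minimal usage sketch for llm_judge: the input DataFrame needs `prompt`,
# `response0`, and `response1` columns, and the judge adds `preferred`,
# `explanation`, `shuffled_index`, and `entire_conversation` columns.
import pandas as pd
from gpt_tldr_judge import LLMJudgeConfig, llm_judge  # module name as imported by generate_tldr.py

df = pd.DataFrame(
    {
        "prompt": ["SUBREDDIT: r/test\nPOST: A long forum post...\nTL;DR:"],
        "response0": ["Model summary of the post."],
        "response1": ["Reference summary of the post."],
    }
)
judged = llm_judge(LLMJudgeConfig(n=1, model="gpt-3.5-turbo-0125"), df)
print(judged[["preferred", "explanation"]])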
