Skip to content

Commit 34d273f

Browse files
authored
Support num_train_epochs (huggingface#1743)
* add a test case for num_train_epochs
* fix ci
* quick change
* disable push to hub
* debug windows ci
* try another fix
* skip subprocess tests on windows
1 parent 3bf9449 commit 34d273f

File tree

8 files changed: +47 additions, -10 deletions

examples/scripts/ppo/ppo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,6 @@ def tokenize(element):
104104
)
105105
trainer.train()
106106
trainer.save_model(config.output_dir)
107-
trainer.push_to_hub()
107+
if config.push_to_hub:
108+
trainer.push_to_hub()
108109
trainer.generate_completions()

examples/scripts/ppo/ppo_tldr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,5 +115,6 @@ def tokenize(element):
115115
)
116116
trainer.train()
117117
trainer.save_model(config.output_dir)
118-
trainer.push_to_hub()
118+
if config.push_to_hub:
119+
trainer.push_to_hub()
119120
trainer.generate_completions()

examples/scripts/rloo/rloo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,6 @@ def tokenize(element):
104104
)
105105
trainer.train()
106106
trainer.save_model(config.output_dir)
107-
trainer.push_to_hub()
107+
if config.push_to_hub:
108+
trainer.push_to_hub()
108109
trainer.generate_completions()

examples/scripts/rloo/rloo_tldr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,5 +115,6 @@ def tokenize(element):
115115
)
116116
trainer.train()
117117
trainer.save_model(config.output_dir)
118-
trainer.push_to_hub()
118+
if config.push_to_hub:
119+
trainer.push_to_hub()
119120
trainer.generate_completions()

tests/test_ppov2_trainer.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,49 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import platform
1415
import subprocess
1516

1617

1718
def test():
1819
command = """\
19-
python -i examples/scripts/ppo/ppo.py \
20+
python examples/scripts/ppo/ppo.py \
2021
--learning_rate 3e-6 \
2122
--output_dir models/minimal/ppo \
22-
--per_device_train_batch_size 5 \
23+
--per_device_train_batch_size 4 \
2324
--gradient_accumulation_steps 1 \
2425
--total_episodes 10 \
2526
--model_name_or_path EleutherAI/pythia-14m \
2627
--non_eos_penalty \
2728
--stop_token eos \
2829
"""
30+
if platform.system() == "Windows":
31+
# windows CI does not work with subprocesses for some reason
32+
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
33+
return
34+
subprocess.run(
35+
command,
36+
shell=True,
37+
check=True,
38+
)
39+
40+
41+
def test_num_train_epochs():
42+
command = """\
43+
python examples/scripts/ppo/ppo.py \
44+
--learning_rate 3e-6 \
45+
--output_dir models/minimal/ppo \
46+
--per_device_train_batch_size 4 \
47+
--gradient_accumulation_steps 1 \
48+
--num_train_epochs 0.003 \
49+
--model_name_or_path EleutherAI/pythia-14m \
50+
--non_eos_penalty \
51+
--stop_token eos \
52+
"""
53+
if platform.system() == "Windows":
54+
# windows CI does not work with subprocesses for some reason
55+
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
56+
return
2957
subprocess.run(
3058
command,
3159
shell=True,

tests/test_rloo_trainer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,28 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import platform
1415
import subprocess
1516

1617
import torch
1718

1819

1920
def test():
2021
command = """\
21-
python -i examples/scripts/rloo/rloo.py \
22+
python examples/scripts/rloo/rloo.py \
2223
--learning_rate 3e-6 \
2324
--output_dir models/minimal/rloo \
24-
--per_device_train_batch_size 5 \
25+
--per_device_train_batch_size 4 \
2526
--gradient_accumulation_steps 1 \
2627
--total_episodes 10 \
2728
--model_name_or_path EleutherAI/pythia-14m \
2829
--non_eos_penalty \
2930
--stop_token eos \
3031
"""
32+
if platform.system() == "Windows":
33+
# windows CI does not work with subprocesses for some reason
34+
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
35+
return
3136
subprocess.run(
3237
command,
3338
shell=True,

trl/trainer/ppov2_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
# calculate various batch sizes
102102
#########
103103
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
104-
args.total_episodes = args.num_train_epochs * self.train_dataset_len
104+
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
105105
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
106106
self.accelerator = accelerator
107107
args.world_size = accelerator.num_processes

trl/trainer/rloo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
# calculate various batch sizes
8484
#########
8585
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
86-
args.total_episodes = args.num_train_epochs * self.train_dataset_len
86+
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
8787
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
8888
self.accelerator = accelerator
8989
args.world_size = accelerator.num_processes

0 commit comments

Comments
 (0)