8 files changed: +47 -10 lines changed
Original file line number Diff line number Diff line change @@ -104,5 +104,6 @@ def tokenize(element):
104104 )
105105 trainer .train ()
106106 trainer .save_model (config .output_dir )
107- trainer .push_to_hub ()
107+ if config .push_to_hub :
108+ trainer .push_to_hub ()
108109 trainer .generate_completions ()
Original file line number Diff line number Diff line change @@ -115,5 +115,6 @@ def tokenize(element):
115115 )
116116 trainer .train ()
117117 trainer .save_model (config .output_dir )
118- trainer .push_to_hub ()
118+ if config .push_to_hub :
119+ trainer .push_to_hub ()
119120 trainer .generate_completions ()
Original file line number Diff line number Diff line change @@ -104,5 +104,6 @@ def tokenize(element):
104104 )
105105 trainer .train ()
106106 trainer .save_model (config .output_dir )
107- trainer .push_to_hub ()
107+ if config .push_to_hub :
108+ trainer .push_to_hub ()
108109 trainer .generate_completions ()
Original file line number Diff line number Diff line change @@ -115,5 +115,6 @@ def tokenize(element):
115115 )
116116 trainer .train ()
117117 trainer .save_model (config .output_dir )
118- trainer .push_to_hub ()
118+ if config .push_to_hub :
119+ trainer .push_to_hub ()
119120 trainer .generate_completions ()
Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import platform
1415import subprocess
1516
1617
1718def test ():
1819 command = """\
19- python -i examples/scripts/ppo/ppo.py \
20+ python examples/scripts/ppo/ppo.py \
2021 --learning_rate 3e-6 \
2122 --output_dir models/minimal/ppo \
22- --per_device_train_batch_size 5 \
23+ --per_device_train_batch_size 4 \
2324 --gradient_accumulation_steps 1 \
2425 --total_episodes 10 \
2526 --model_name_or_path EleutherAI/pythia-14m \
2627 --non_eos_penalty \
2728 --stop_token eos \
2829 """
30+ if platform .system () == "Windows" :
31+ # windows CI does not work with subprocesses for some reason
32+ # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
33+ return
34+ subprocess .run (
35+ command ,
36+ shell = True ,
37+ check = True ,
38+ )
39+
40+
41+ def test_num_train_epochs ():
42+ command = """\
43+ python examples/scripts/ppo/ppo.py \
44+ --learning_rate 3e-6 \
45+ --output_dir models/minimal/ppo \
46+ --per_device_train_batch_size 4 \
47+ --gradient_accumulation_steps 1 \
48+ --num_train_epochs 0.003 \
49+ --model_name_or_path EleutherAI/pythia-14m \
50+ --non_eos_penalty \
51+ --stop_token eos \
52+ """
53+ if platform .system () == "Windows" :
54+ # windows CI does not work with subprocesses for some reason
55+ # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
56+ return
2957 subprocess .run (
3058 command ,
3159 shell = True ,
Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import platform
1415import subprocess
1516
1617import torch
1718
1819
1920def test ():
2021 command = """\
21- python -i examples/scripts/rloo/rloo.py \
22+ python examples/scripts/rloo/rloo.py \
2223 --learning_rate 3e-6 \
2324 --output_dir models/minimal/rloo \
24- --per_device_train_batch_size 5 \
25+ --per_device_train_batch_size 4 \
2526 --gradient_accumulation_steps 1 \
2627 --total_episodes 10 \
2728 --model_name_or_path EleutherAI/pythia-14m \
2829 --non_eos_penalty \
2930 --stop_token eos \
3031 """
32+ if platform .system () == "Windows" :
33+ # windows CI does not work with subprocesses for some reason
34+ # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
35+ return
3136 subprocess .run (
3237 command ,
3338 shell = True ,
Original file line number Diff line number Diff line change @@ -101,7 +101,7 @@ def __init__(
101101 # calculate various batch sizes
102102 #########
103103 if args .total_episodes is None : # allow the users to define episodes in terms of epochs.
104- args .total_episodes = args .num_train_epochs * self .train_dataset_len
104+ args .total_episodes = int ( args .num_train_epochs * self .train_dataset_len )
105105 accelerator = Accelerator (gradient_accumulation_steps = args .gradient_accumulation_steps )
106106 self .accelerator = accelerator
107107 args .world_size = accelerator .num_processes
Original file line number Diff line number Diff line change @@ -83,7 +83,7 @@ def __init__(
8383 # calculate various batch sizes
8484 #########
8585 if args .total_episodes is None : # allow the users to define episodes in terms of epochs.
86- args .total_episodes = args .num_train_epochs * self .train_dataset_len
86+ args .total_episodes = int ( args .num_train_epochs * self .train_dataset_len )
8787 accelerator = Accelerator (gradient_accumulation_steps = args .gradient_accumulation_steps )
8888 self .accelerator = accelerator
8989 args .world_size = accelerator .num_processes
You can’t perform that action at this time.
0 commit comments