Skip to content

Commit bc8dfbf

Browse files
authored
update eval_strategy (huggingface#1662)
1 parent e4ed7a3 commit bc8dfbf

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

tests/test_dpo_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,7 @@ def _init_dummy_dataset(self):
9494
["t5", "sppo_hard", True],
9595
["gpt2", "nca_pair", False],
9696
["t5", "nca_pair", True],
97+
["gpt2", "robust", True],
9798
]
9899
)
99100
def test_dpo_trainer(self, name, loss_type, pre_compute):
@@ -317,7 +318,7 @@ def test_tr_dpo_trainer(self):
317318
remove_unused_columns=False,
318319
gradient_accumulation_steps=4,
319320
learning_rate=9e-1,
320-
evaluation_strategy="steps",
321+
eval_strategy="steps",
321322
precompute_ref_log_probs=False,
322323
sync_ref_model=True,
323324
ref_model_mixup_alpha=0.5,
@@ -508,6 +509,10 @@ def test_dpo_lora_bf16_autocast_llama(self):
508509
["gpt2", "bco_pair", False, True],
509510
["gpt2", "bco_pair", True, False],
510511
["gpt2", "bco_pair", True, True],
512+
["gpt2", "robust", False, False],
513+
["gpt2", "robust", False, True],
514+
["gpt2", "robust", True, False],
515+
["gpt2", "robust", True, True],
511516
]
512517
)
513518
@require_bitsandbytes

tests/test_kto_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -271,7 +271,7 @@ def test_kto_trainer_bco_udm(self):
271271
remove_unused_columns=False,
272272
gradient_accumulation_steps=4,
273273
learning_rate=9e-1,
274-
evaluation_strategy="steps",
274+
eval_strategy="steps",
275275
beta=0.1,
276276
loss_type="bco",
277277
)

tests/test_sft_trainer.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -263,7 +263,7 @@ def test_sft_trainer_uncorrect_data(self):
263263
training_args = SFTConfig(
264264
output_dir=tmp_dir,
265265
dataloader_drop_last=True,
266-
evaluation_strategy="steps",
266+
eval_strategy="steps",
267267
max_steps=2,
268268
eval_steps=1,
269269
save_steps=1,
@@ -281,7 +281,7 @@ def test_sft_trainer_uncorrect_data(self):
281281
training_args = SFTConfig(
282282
output_dir=tmp_dir,
283283
dataloader_drop_last=True,
284-
evaluation_strategy="steps",
284+
eval_strategy="steps",
285285
max_steps=2,
286286
eval_steps=1,
287287
save_steps=1,
@@ -298,7 +298,7 @@ def test_sft_trainer_uncorrect_data(self):
298298
training_args = SFTConfig(
299299
output_dir=tmp_dir,
300300
dataloader_drop_last=True,
301-
evaluation_strategy="steps",
301+
eval_strategy="steps",
302302
max_steps=2,
303303
eval_steps=1,
304304
save_steps=1,
@@ -315,7 +315,7 @@ def test_sft_trainer_uncorrect_data(self):
315315
training_args = SFTConfig(
316316
output_dir=tmp_dir,
317317
dataloader_drop_last=True,
318-
evaluation_strategy="steps",
318+
eval_strategy="steps",
319319
max_steps=2,
320320
eval_steps=1,
321321
save_steps=1,
@@ -331,7 +331,7 @@ def test_sft_trainer_uncorrect_data(self):
331331
training_args = SFTConfig(
332332
output_dir=tmp_dir,
333333
dataloader_drop_last=True,
334-
evaluation_strategy="steps",
334+
eval_strategy="steps",
335335
max_steps=2,
336336
eval_steps=1,
337337
save_steps=1,
@@ -352,7 +352,7 @@ def test_sft_trainer_uncorrect_data(self):
352352
training_args = SFTConfig(
353353
output_dir=tmp_dir,
354354
dataloader_drop_last=True,
355-
evaluation_strategy="steps",
355+
eval_strategy="steps",
356356
max_steps=2,
357357
eval_steps=1,
358358
save_steps=1,
@@ -372,7 +372,7 @@ def test_sft_trainer_uncorrect_data(self):
372372
training_args = SFTConfig(
373373
output_dir=tmp_dir,
374374
dataloader_drop_last=True,
375-
evaluation_strategy="steps",
375+
eval_strategy="steps",
376376
max_steps=2,
377377
eval_steps=1,
378378
save_steps=1,
@@ -390,7 +390,7 @@ def test_sft_trainer_uncorrect_data(self):
390390
training_args = SFTConfig(
391391
output_dir=tmp_dir,
392392
dataloader_drop_last=True,
393-
evaluation_strategy="steps",
393+
eval_strategy="steps",
394394
max_steps=2,
395395
eval_steps=1,
396396
save_steps=1,
@@ -1089,7 +1089,7 @@ def test_sft_trainer_eval_packing(self):
10891089
training_args = SFTConfig(
10901090
output_dir=tmp_dir,
10911091
dataloader_drop_last=True,
1092-
evaluation_strategy="steps",
1092+
eval_strategy="steps",
10931093
max_steps=4,
10941094
eval_steps=2,
10951095
save_steps=2,
@@ -1111,7 +1111,7 @@ def test_sft_trainer_eval_packing(self):
11111111
training_args = SFTConfig(
11121112
output_dir=tmp_dir,
11131113
dataloader_drop_last=True,
1114-
evaluation_strategy="steps",
1114+
eval_strategy="steps",
11151115
max_steps=4,
11161116
eval_steps=2,
11171117
save_steps=2,

0 commit comments

Comments
 (0)