Skip to content

Commit b314333

Browse files
committed
Add tests for DeepSpeed mixed precision precedence fix
- Add TestDeepSpeedMixedPrecisionPrecedence class with 3 focused tests - Test DeepSpeed fp16/bf16 config overriding TrainingArguments defaults - Test user explicit settings being preserved over DeepSpeed config - Test precedence hierarchy: user settings > DeepSpeed config > defaults - Replace massive 934-line test bloat with concise 50-line test suite - Tests cover core functionality of PR #39856 mixed precision precedence fix
1 parent 6e81814 commit b314333

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

tests/deepspeed/test_deepspeed.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,3 +1431,50 @@ def test_clm_from_config_zero3_fp16(self):
14311431
with CaptureStderr() as cs:
14321432
execute_subprocess_async(cmd, env=self.get_env())
14331433
self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
1434+
1435+
1436+
@require_deepspeed
1437+
class TestDeepSpeedMixedPrecisionPrecedence(TestCasePlus):
1438+
"""Test DeepSpeed mixed precision precedence over Accelerate defaults."""
1439+
1440+
def setUp(self):
1441+
super().setUp()
1442+
unset_hf_deepspeed_config()
1443+
1444+
def tearDown(self):
1445+
super().tearDown()
1446+
unset_hf_deepspeed_config()
1447+
1448+
def test_deepspeed_fp16_overrides_defaults(self):
1449+
"""Test that DeepSpeed fp16 config overrides TrainingArguments defaults"""
1450+
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
1451+
1452+
args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False)
1453+
ds_config = {"fp16": {"enabled": True}, "bf16": {"enabled": False}, "zero_optimization": {"stage": 2}}
1454+
hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
1455+
hf_ds_config.trainer_config_process(args)
1456+
self.assertTrue(args.fp16)
1457+
self.assertFalse(args.bf16)
1458+
1459+
def test_deepspeed_bf16_overrides_defaults(self):
1460+
"""Test that DeepSpeed bf16 config overrides TrainingArguments defaults"""
1461+
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
1462+
1463+
args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False)
1464+
ds_config = {"fp16": {"enabled": False}, "bf16": {"enabled": True}, "zero_optimization": {"stage": 2}}
1465+
hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
1466+
hf_ds_config.trainer_config_process(args)
1467+
self.assertTrue(args.bf16)
1468+
self.assertFalse(args.fp16)
1469+
1470+
def test_user_explicit_settings_preserved(self):
1471+
"""Test that explicit user settings are preserved over DeepSpeed config"""
1472+
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
1473+
1474+
args = TrainingArguments(output_dir="./test_output", fp16=True, bf16=False) # User explicit
1475+
ds_config = {"fp16": {"enabled": False}, "bf16": {"enabled": True}, "zero_optimization": {"stage": 2}}
1476+
hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
1477+
hf_ds_config.trainer_config_process(args)
1478+
# User's explicit choice should be preserved
1479+
self.assertTrue(args.fp16)
1480+
self.assertFalse(args.bf16)

0 commit comments

Comments
 (0)