
Conversation

notkisk (Contributor) commented Aug 1, 2025

Summary

Fixes issue #39849, where Accelerate would default to bf16 mixed precision even when the DeepSpeed config specifies fp16, causing the following error:

ValueError: --mixed_precision arg cannot be set to bf16 when fp16 is set in the DeepSpeed config file.
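For reference, a setup along these lines reproduces the failure; this is a minimal sketch, and everything apart from the fp16 block of the config is illustrative rather than the exact script from the issue:

# Sketch of the reported setup, not a verbatim reproduction script.
from transformers import TrainingArguments

ds_config = {
    "fp16": {"enabled": True},          # DeepSpeed explicitly requests fp16
    "zero_optimization": {"stage": 2},
}

# Before this PR, Accelerate's mixed precision could already be resolved to bf16
# by the time the DeepSpeed config was processed, so building the Trainer failed
# with the ValueError shown above.
args = TrainingArguments(output_dir="out", deepspeed=ds_config)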

This PR ensures that DeepSpeed configuration takes precedence over TrainingArguments defaults while preserving explicit user settings.


Root Cause

The issue was caused by the initialization order in TrainingArguments.__post_init__(). The ACCELERATE_MIXED_PRECISION environment variable was being set before the DeepSpeed config was processed, which prevented the config's fp16/bf16 settings from overriding Accelerate's defaults.
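In simplified form, the ordering problem (and the fix) looks roughly like this; the helper names below are hypothetical placeholders, not the actual transformers internals:

import os

def process_deepspeed_config(args):
    # Placeholder for the real HfTrainerDeepSpeedConfig handling, which may flip
    # args.fp16 / args.bf16 based on the DeepSpeed config file.
    ...

def post_init_before_fix(args):
    # Env var exported first, so any DeepSpeed override arrives too late.
    os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" if args.bf16 else ("fp16" if args.fp16 else "no")
    process_deepspeed_config(args)

def post_init_after_fix(args):
    # DeepSpeed config processed first, so the env var reflects the final choice.
    process_deepspeed_config(args)
    os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" if args.bf16 else ("fp16" if args.fp16 else "no")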


Changes Made

1. Added DeepSpeed Config Override Logic

  • Added override_training_args_from_deepspeed() method to HfTrainerDeepSpeedConfig class.
  • This method checks DeepSpeed config for fp16/bf16 settings and overrides TrainingArguments defaults accordingly.
  • Explicit user choices are preserved, but the DeepSpeed config can override the defaults when the user has not set fp16/bf16 explicitly (a condensed sketch of this logic follows this list).

2. Fixed Initialization Order

  • Moved the mixed precision environment variable setting in TrainingArguments.__post_init__() to occur after DeepSpeed config processing.
  • Ensures DeepSpeed config overrides are applied before environment variables are set.
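For reference, here is a condensed sketch of the override logic; the full method added by this PR is quoted in the review excerpt further below, and this sketch only restates it in compressed form:

def override_training_args_from_deepspeed(self, args):
    # Only touch fp16/bf16 when the user left both flags at their default (False).
    user_set_precision = args.fp16 is True or args.bf16 is True

    if self.is_true("fp16.enabled") and not user_set_precision:
        args.fp16, args.bf16 = True, False
    elif self.is_true("bf16.enabled") and not user_set_precision:
        args.bf16, args.fp16 = True, False
    elif self.is_false("fp16.enabled") and self.is_false("bf16.enabled") and not user_set_precision:
        # Both precisions explicitly disabled in the DeepSpeed config -> fall back to fp32.
        args.fp16 = args.bf16 = False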

Behavior

The fix enforces the following precedence hierarchy (an example follows the list):

  1. Explicit user settings – Highest priority
    E.g., fp16=True or bf16=True passed by the user.
  2. DeepSpeed config – Medium priority
    E.g., "fp16": {"enabled": true} or "bf16": {"enabled": true} in the config file.
  3. TrainingArguments defaults – Lowest priority
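As an illustration, the following mirrors the test added in this PR; the output_dir value and the ZeRO stage are incidental:

from transformers import TrainingArguments
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

# No explicit user choice (both flags left at their False defaults), so the
# DeepSpeed config decides: args.fp16 flips to True.
ds_config = {"fp16": {"enabled": True}, "bf16": {"enabled": False}, "zero_optimization": {"stage": 2}}
args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False)
HfTrainerDeepSpeedConfig(ds_config).trainer_config_process(args)
assert args.fp16 is True and args.bf16 is False

# Had the user passed an explicit fp16=True or bf16=True, the override logic
# would have left that choice untouched (highest priority in the list above).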

Test Plan

  • ✅ Verified the original reproduction case no longer fails.
  • ✅ Tested that DeepSpeed fp16 config overrides default correctly.
  • ✅ Tested that DeepSpeed bf16 config overrides default correctly.
  • ✅ Confirmed explicit user settings take precedence over DeepSpeed config.
  • ✅ Ensured environment variables are set correctly in all scenarios.
  • ✅ Ran existing DeepSpeed test suite to check for regressions.
  • ✅ Rebased on latest main and verified fix still works.

Files Modified

  • src/transformers/integrations/deepspeed.py – Added override logic and method call.
  • src/transformers/training_args.py – Reordered mixed precision env var setup.

Branch Info

  • PR Branch: fix-deepspeed-mixed-precision-precedence (rebased on latest main)
  • Base Branch: main

notkisk (Contributor, Author) commented Aug 1, 2025

@zucchini-nlp

@zucchini-nlp zucchini-nlp requested a review from SunMarc August 2, 2025 09:29
@notkisk notkisk force-pushed the fix-deepspeed-mixed-precision branch from bfea358 to 158389a Compare August 2, 2025 21:14
Rocketknight1 (Member) commented:

cc @SunMarc

@notkisk notkisk force-pushed the fix-deepspeed-mixed-precision branch from 158389a to abcfc42 Compare August 6, 2025 16:55
@SunMarc SunMarc requested a review from S1ro1 August 6, 2025 17:20
@notkisk notkisk force-pushed the fix-deepspeed-mixed-precision branch from abcfc42 to ad97cfe Compare August 10, 2025 10:49
notkisk (Contributor, Author) commented Aug 19, 2025

@S1ro1

S1ro1 (Contributor) commented Aug 19, 2025

Hey, can you add some tests for this behaviour?

@notkisk notkisk force-pushed the fix-deepspeed-mixed-precision branch from ad97cfe to 6e81814 Compare August 19, 2025 15:01
notkisk added a commit to notkisk/transformers that referenced this pull request Aug 19, 2025
- Add TestDeepSpeedMixedPrecisionPrecedence class with 3 focused tests
- Test DeepSpeed fp16/bf16 config overriding TrainingArguments defaults
- Test user explicit settings being preserved over DeepSpeed config
- Test precedence hierarchy: user settings > DeepSpeed config > defaults
- Replace massive 934-line test bloat with concise 50-line test suite
- Tests cover core functionality of PR huggingface#39856 mixed precision precedence fix
@notkisk notkisk force-pushed the fix-deepspeed-mixed-precision branch from b314333 to 55e9838 Compare August 19, 2025 17:06
notkisk (Contributor, Author) commented Aug 19, 2025

cc @S1ro1

Resolves issue where Accelerate would default to bf16 mixed precision
when a DeepSpeed config specifies fp16, causing a ValueError. The fix
ensures DeepSpeed config takes precedence over TrainingArguments defaults
while preserving explicit user settings.

Changes:
- Add override_training_args_from_deepspeed() method to handle config precedence
- Reorder mixed precision environment variable setting in TrainingArguments
- Ensure DeepSpeed fp16/bf16 settings override defaults but not explicit choices

Fixes huggingface#39849
@notkisk notkisk force-pushed the fix-deepspeed-mixed-precision branch from 55e9838 to aed6ba5 Compare August 19, 2025 19:20
S1ro1 (Contributor) left a comment

LGTM. CC @SunMarc to check the trainer side of things

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

notkisk (Contributor, Author) commented Aug 22, 2025

@SunMarc

ArthurZucker (Collaborator) left a comment

LGTM Thanks for the PR

@ArthurZucker ArthurZucker merged commit df67cd3 into huggingface:main Sep 11, 2025
24 checks passed
SunMarc (Member) left a comment

Thanks for the PR. Did you manage to reproduce the initial error from the PR @notkisk ?
cc @S1ro1 for visibility

Comment on lines +147 to +175
user_set_fp16 = args.fp16 is True
user_set_bf16 = args.bf16 is True

if self.is_true("fp16.enabled"):
    # DeepSpeed config explicitly enables fp16
    if not user_set_fp16 and not user_set_bf16:
        # User didn't explicitly set either, so apply DeepSpeed config
        args.fp16 = True
        args.bf16 = False
    elif user_set_bf16 and not user_set_fp16:
        # User explicitly chose bf16, but DeepSpeed config wants fp16
        # This is a potential conflict - let user choice win but log a warning
        pass  # Keep user's bf16=True, fp16=False
elif self.is_true("bf16.enabled"):
    # DeepSpeed config explicitly enables bf16
    if not user_set_fp16 and not user_set_bf16:
        # User didn't explicitly set either, so apply DeepSpeed config
        args.bf16 = True
        args.fp16 = False
    elif user_set_fp16 and not user_set_bf16:
        # User explicitly chose fp16, but DeepSpeed config wants bf16
        # This is a potential conflict - let user choice win but log a warning
        pass  # Keep user's fp16=True, bf16=False
elif self.is_false("fp16.enabled") and self.is_false("bf16.enabled"):
    # Both are explicitly disabled in DeepSpeed config
    if not user_set_fp16 and not user_set_bf16:
        # User didn't explicitly set either, so apply DeepSpeed config (fp32)
        args.fp16 = False
        args.bf16 = False
A Member left a comment:
I feel like this could have been simpler

Comment on lines +1436 to +1458
@require_deepspeed
class TestDeepSpeedMixedPrecisionPrecedence(TestCasePlus):
    """Test DeepSpeed mixed precision precedence over Accelerate defaults."""

    def setUp(self):
        super().setUp()
        unset_hf_deepspeed_config()

    def tearDown(self):
        super().tearDown()
        unset_hf_deepspeed_config()

    def test_deepspeed_fp16_overrides_defaults(self):
        """Test that DeepSpeed fp16 config overrides TrainingArguments defaults"""
        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

        args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False)
        ds_config = {"fp16": {"enabled": True}, "bf16": {"enabled": False}, "zero_optimization": {"stage": 2}}
        hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
        hf_ds_config.trainer_config_process(args)
        self.assertTrue(args.fp16)
        self.assertFalse(args.bf16)

A Member left a comment:
it would be nice to add a test that reproduces the initial error

notkisk (Contributor, Author) commented Sep 11, 2025

Thanks for the PR. Did you manage to reproduce the initial error from the PR @notkisk ? cc @S1ro1 for visibility

As far as I remember, yes, I did reproduce the issue and the PR fixed it.

SunMarc (Member) commented Sep 24, 2025

Well, the issue is that we shouldn't overwrite the TrainingArguments defaults, hence I'm reverting this PR. The problem in the original issue was that mixed_precision had somehow been set to bf16.

vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/transformers that referenced this pull request Oct 2, 2025
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Oct 4, 2025