Commit abcfc42

Fix DeepSpeed mixed precision precedence over Accelerate defaults
Resolves an issue where Accelerate would default to bf16 mixed precision when a DeepSpeed config specifies fp16, causing a ValueError. The fix ensures the DeepSpeed config takes precedence over TrainingArguments defaults while preserving explicit user settings.

Changes:
- Add override_training_args_from_deepspeed() method to handle config precedence
- Reorder mixed precision environment variable setting in TrainingArguments
- Ensure DeepSpeed fp16/bf16 settings override defaults but not explicit choices

Fixes #39849
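For illustration, a minimal sketch of the end-to-end behavior this commit targets (not part of the commit itself). The DeepSpeed config dict and output_dir below are placeholders, and running it assumes deepspeed and accelerate are installed and that no conflicting ACCELERATE_MIXED_PRECISION value was already exported by an Accelerate launch config.

import os
from transformers import TrainingArguments

# A DeepSpeed config that requests fp16; TrainingArguments is left at its defaults.
ds_config = {
    "zero_optimization": {"stage": 2},
    "fp16": {"enabled": True},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

args = TrainingArguments(output_dir="out", deepspeed=ds_config)

# With this fix, the DeepSpeed config wins over the unset defaults: fp16 is enabled
# on the TrainingArguments and the Accelerate env var agrees with it, instead of
# Accelerate falling back to bf16 and raising a ValueError.
assert args.fp16 and not args.bf16
assert os.environ["ACCELERATE_MIXED_PRECISION"] == "fp16"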
1 parent cb2e0df commit abcfc42

2 files changed: +59, -8 lines


src/transformers/integrations/deepspeed.py

Lines changed: 47 additions & 0 deletions
@@ -130,11 +130,58 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
 
     fill_only = partialmethod(fill_match, must_match=False)
 
+    def override_training_args_from_deepspeed(self, args):
+        """
+        Override TrainingArguments based on DeepSpeed config values to ensure compatibility.
+
+        This method ensures that the DeepSpeed config takes precedence over TrainingArguments
+        defaults when there are conflicts, particularly for mixed precision settings.
+
+        Args:
+            args: TrainingArguments object to potentially modify
+        """
+        # Check precision settings in DeepSpeed config and override TrainingArguments accordingly
+        # Only override defaults, not explicit user settings
+
+        # Check if user explicitly set precision options (we assume defaults are False)
+        user_set_fp16 = args.fp16 is True
+        user_set_bf16 = args.bf16 is True
+
+        if self.is_true("fp16.enabled"):
+            # DeepSpeed config explicitly enables fp16
+            if not user_set_fp16 and not user_set_bf16:
+                # User didn't explicitly set either, so apply DeepSpeed config
+                args.fp16 = True
+                args.bf16 = False
+            elif user_set_bf16 and not user_set_fp16:
+                # User explicitly chose bf16, but DeepSpeed config wants fp16
+                # This is a potential conflict - let user choice win but log a warning
+                pass  # Keep user's bf16=True, fp16=False
+        elif self.is_true("bf16.enabled"):
+            # DeepSpeed config explicitly enables bf16
+            if not user_set_fp16 and not user_set_bf16:
+                # User didn't explicitly set either, so apply DeepSpeed config
+                args.bf16 = True
+                args.fp16 = False
+            elif user_set_fp16 and not user_set_bf16:
+                # User explicitly chose fp16, but DeepSpeed config wants bf16
+                # This is a potential conflict - let user choice win but log a warning
+                pass  # Keep user's fp16=True, bf16=False
+        elif self.is_false("fp16.enabled") and self.is_false("bf16.enabled"):
+            # Both are explicitly disabled in DeepSpeed config
+            if not user_set_fp16 and not user_set_bf16:
+                # User didn't explicitly set either, so apply DeepSpeed config (fp32)
+                args.fp16 = False
+                args.bf16 = False
+
     def trainer_config_process(self, args, auto_find_batch_size=False):
         """
         Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
         creation.
         """
+        # First, override TrainingArguments based on DeepSpeed config to ensure compatibility
+        self.override_training_args_from_deepspeed(args)
+
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
         train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
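As a usage note (not part of the diff): a small, hedged sketch of how the new method resolves the two common cases. The config values and output_dir are hypothetical; it assumes deepspeed is installed (HfTrainerDeepSpeedConfig checks for it) and, for the second case, hardware that supports bf16.

from transformers import TrainingArguments
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

# DeepSpeed config that asks for fp16.
ds = HfTrainerDeepSpeedConfig({"fp16": {"enabled": True}})

# Defaults (fp16=False, bf16=False): the DeepSpeed setting is applied.
args = TrainingArguments(output_dir="out")
ds.override_training_args_from_deepspeed(args)
assert args.fp16 and not args.bf16

# An explicit user choice (bf16=True) is preserved despite the fp16 request.
args = TrainingArguments(output_dir="out", bf16=True)
ds.override_training_args_from_deepspeed(args)
assert args.bf16 and not args.fp16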

src/transformers/training_args.py

Lines changed: 12 additions & 8 deletions
@@ -1853,14 +1853,8 @@ def __post_init__(self):
18531853
torch.backends.cudnn.allow_tf32 = False
18541854
# no need to assert on else
18551855

1856-
# if training args is specified, it will override the one specified in the accelerate config
1857-
if self.half_precision_backend != "apex":
1858-
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
1859-
if self.fp16:
1860-
mixed_precision_dtype = "fp16"
1861-
elif self.bf16:
1862-
mixed_precision_dtype = "bf16"
1863-
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
1856+
# NOTE: Mixed precision environment variable setting moved to after DeepSpeed processing
1857+
# to ensure DeepSpeed config can override TrainingArguments defaults
18641858

18651859
if self.report_to is None:
18661860
logger.info(
@@ -2070,6 +2064,16 @@ def __post_init__(self):
20702064
self.deepspeed_plugin.set_mixed_precision(mixed_precision)
20712065
self.deepspeed_plugin.set_deepspeed_weakref()
20722066

2067+
# Set mixed precision environment variable after DeepSpeed processing
2068+
# This ensures DeepSpeed config overrides have been applied to fp16/bf16 settings
2069+
if self.half_precision_backend != "apex":
2070+
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
2071+
if self.fp16:
2072+
mixed_precision_dtype = "fp16"
2073+
elif self.bf16:
2074+
mixed_precision_dtype = "bf16"
2075+
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
2076+
20732077
if self.use_cpu:
20742078
self.dataloader_pin_memory = False
20752079
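As a quick sanity check on the reordering (again not part of the diff, and hedged): on a CUDA machine with accelerate installed, an explicit fp16=True without any DeepSpeed config should still end up in the Accelerate environment variable, since the stamp now simply runs later in __post_init__. The output_dir is a placeholder.

import os
from transformers import TrainingArguments

# No DeepSpeed involved: the explicit fp16 flag is still exported for Accelerate,
# just after the (skipped) DeepSpeed processing instead of before it.
args = TrainingArguments(output_dir="out", fp16=True)
assert os.environ["ACCELERATE_MIXED_PRECISION"] == "fp16"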
