
Commit b012237

make float8 scaling type configurable (#489)
Summary:

Adds config options to configure the float8 scaling type for input, weight, and grad_output. Performance is not ideal yet, but that's because we have not optimized it.

Test Plan:

```
// repeat for input, weight, grad_output
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.float8_scaling_type_weight delayed --training.compile
```
1 parent f13fe3f commit b012237

3 files changed: +67, -4 lines changed

torchtitan/config_manager.py

Lines changed: 20 additions & 2 deletions
```
@@ -352,8 +352,7 @@ def __init__(self):
             "--training.enable_float8_linear",
             action="store_true",
             help="""
-                If true, swaps `torch.nn.Linear` with `Float8Linear` with
-                default settings (dynamic scaling).
+                If true, swaps `torch.nn.Linear` with `Float8Linear`.
                 This feature requires you to install 'float8_experimental' which can be found
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
```
```
@@ -370,6 +369,25 @@ def __init__(self):
             default=False,
             help="Whether precompute float8 scales dynamically for FSDP",
         )
+        self.parser.add_argument(
+            "--training.float8_scaling_type_input",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for input, dynamic (default) or delayed",
+            choices=["dynamic", "delayed"],
+        )
+        self.parser.add_argument(
+            "--training.float8_scaling_type_weight",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for weight, dynamic (default) or delayed",
+        )
+        self.parser.add_argument(
+            "--training.float8_scaling_type_grad_output",
+            type=str,
+            default="dynamic",
+            help="float8 scaling for grad_output, dynamic (default) or delayed",
+        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
```
torchtitan/float8_linear.py

Lines changed: 42 additions & 1 deletion
```
@@ -59,9 +59,19 @@ def maybe_build_fp8_linear(
         enable_fsdp_float8_all_gather = (
             job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
+        scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input)
+        scaling_type_weight = ScalingType(
+            job_config.training.float8_scaling_type_weight
+        )
+        scaling_type_grad_output = ScalingType(
+            job_config.training.float8_scaling_type_grad_output
+        )
         float8_config = Float8LinearConfig(
             enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-            cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+            cast_config_input=CastConfig(scaling_type=scaling_type_input),
+            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+            enable_pre_and_post_forward=False,
         )
         convert_to_float8_training(
             model,
```
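The `ScalingType(...)` calls above work because `ScalingType` in float8_experimental is, as far as I can tell, a string-valued enum. A minimal stand-in sketch (an assumption for illustration, not the library's actual definition):

```
from enum import Enum


class ScalingType(Enum):
    DELAYED = "delayed"
    DYNAMIC = "dynamic"


# Constructing from the config string selects the member by value,
# and an unrecognized string fails fast with a ValueError.
assert ScalingType("delayed") is ScalingType.DELAYED
assert ScalingType("dynamic") is ScalingType.DYNAMIC
```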
```
@@ -95,3 +105,34 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
     from float8_experimental import precompute_float8_dynamic_scale_for_fsdp

     precompute_float8_dynamic_scale_for_fsdp(model)
+
+
+_sync_float8_amax_and_scale_history = None
+
+
+def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig):
+    if not (
+        job_config.training.enable_float8_linear
+        and (
+            job_config.training.float8_scaling_type_input == "delayed"
+            or job_config.training.float8_scaling_type_weight == "delayed"
+            or job_config.training.float8_scaling_type_grad_output == "delayed"
+        )
+    ):
+        return
+
+    from float8_experimental import sync_float8_amax_and_scale_history
+
+    # TODO(future): see if precalculating the modules to sync over is going to
+    # meaningfully help performance
+
+    global _sync_float8_amax_and_scale_history
+    if _sync_float8_amax_and_scale_history is None:
+        if job_config.training.compile:
+            _sync_float8_amax_and_scale_history = torch.compile(
+                sync_float8_amax_and_scale_history
+            )
+        else:
+            _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history
+
+    _sync_float8_amax_and_scale_history(model)
```
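The new helper uses a compile-once caching pattern: the (optionally `torch.compile`d) sync function is stored in a module-level global the first time it is needed, and the cached callable is what gets invoked on every later step, so `torch.compile` is not re-triggered per step. A generic sketch of the same pattern (hypothetical names, not the torchtitan API):

```
import torch

_cached_sync_fn = None  # hypothetical module-level cache


def _get_sync_fn(fn, use_compile: bool):
    """Return fn, compiling it at most once if use_compile is set."""
    global _cached_sync_fn
    if _cached_sync_fn is None:
        _cached_sync_fn = torch.compile(fn) if use_compile else fn
    return _cached_sync_fn
```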

train.py

Lines changed: 5 additions & 1 deletion
```
@@ -30,6 +30,7 @@
 from torchtitan.float8_linear import (
     maybe_build_fp8_linear,
     maybe_precompute_fp8_dynamic_scale_for_fsdp,
+    maybe_sync_float8_amax_and_scale_history,
 )
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_schedulers
```
```
@@ -417,12 +418,15 @@ def loss_fn(pred, labels):
                 model.parameters(), job_config.training.max_norm, foreach=True
             )

+            # if float8 is enabled, sync float8 amaxes and scales
+            maybe_sync_float8_amax_and_scale_history(model, job_config)
+
             # optimizer step
             checkpoint.wait_for_staging()
             optimizers.step()
             lr_schedulers.step()

-            # when fp8 config is on,
+            # when float8 config is on,
             # calculate float8 dynamic amax/scale for all-parameter for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
             maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)
```
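Putting the train.py pieces together, the per-step ordering matters: the amax/scale sync runs after gradient clipping and before the optimizer step, while the FSDP2 dynamic-scale precompute runs after the weights have been updated. A schematic wrapper over the calls used above (hypothetical, not torchtitan's actual loop, which also handles checkpoint staging and more):

```
import torch

from torchtitan.float8_linear import (
    maybe_precompute_fp8_dynamic_scale_for_fsdp,
    maybe_sync_float8_amax_and_scale_history,
)


def float8_step(model, optimizers, lr_schedulers, job_config):
    # clip gradients produced by the backward pass
    torch.nn.utils.clip_grad_norm_(
        model.parameters(), job_config.training.max_norm, foreach=True
    )

    # delayed scaling only: sync float8 amaxes and scales
    maybe_sync_float8_amax_and_scale_history(model, job_config)

    # update weights
    optimizers.step()
    lr_schedulers.step()

    # dynamic scaling + FSDP2 all-gather: precompute scales for the freshly
    # updated weights, issuing a single all-reduce for all parameters
    maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)
```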
