Commit b069f70

move float8 callsites to torchao.float8 (#492)
Summary:

The `float8_experimental` repository moved to `torchao.float8` in pytorch/ao#551. This PR updates `torchtitan` to use float8 from the new location.

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 9cf4b2f commit b069f70
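
The change itself is mechanical: every `float8_experimental` import becomes a `torchao.float8` import. A minimal before/after sketch, using one of the names that appears in the diffs below:

```python
# Before this commit, float8 utilities came from the standalone repo:
#   from float8_experimental import convert_to_float8_training
# After this commit, the same utility is imported from torchao:
from torchao.float8 import convert_to_float8_training

print(convert_to_float8_training)  # sanity check that the new path resolves
```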

File tree

4 files changed, +14 −12 lines changed


.github/workflows/integration_test_4gpu.yaml

Lines changed: 1 addition & 1 deletion

@@ -39,6 +39,6 @@ jobs:

           python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
           python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
-          python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
+          USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
           mkdir artifacts-to-be-uploaded
           python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
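
A quick way to confirm that the `torchao` nightly installed in this step exposes the new module path (a sanity-check sketch, not part of the workflow):

```python
# Assumes torchao was installed as in the CI step above; the float8 APIs
# torchtitan imports should now resolve under torchao.float8.
import importlib

float8 = importlib.import_module("torchao.float8")
assert hasattr(float8, "convert_to_float8_training")
```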

torchtitan/config_manager.py

Lines changed: 2 additions & 2 deletions

@@ -353,8 +353,8 @@ def __init__(self):
             action="store_true",
             help="""
                 If true, swaps `torch.nn.Linear` with `Float8Linear`.
-                This feature requires you to install 'float8_experimental' which can be found
-                here: https://github.com/pytorch-labs/float8_experimental
+                This feature requires you to install 'torchao' which can be found
+                here: https://github.com/pytorch/ao
             """,
         )
         self.parser.add_argument(
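
The help text above describes what the `--training.enable_float8_linear` flag ultimately does: swap `torch.nn.Linear` modules for `Float8Linear`. A hedged sketch of that swap using the `torchao.float8` API imported elsewhere in this PR (the toy model is illustrative, not torchtitan's):

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# A toy model standing in for the real one; convert_to_float8_training
# replaces eligible nn.Linear children with Float8Linear.
model = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
model = convert_to_float8_training(model)
print(type(model[0]).__name__)  # expected: Float8Linear
```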

torchtitan/float8_linear.py

Lines changed: 8 additions & 8 deletions

@@ -4,11 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# [Note] Getting the 'float8_experimental' package:
-# This script requires the 'float8_experimental' package to function correctly.
+# [Note] Getting the 'torchao' package:
+# This script requires the 'torchao' package to function correctly.
 # Please ensure you have this package installed from the appropriate repository.
-# You can obtain it from https://github.com/pytorch-labs/float8_experimental.
-# Either clone and run `pip install .` or run `pip install git+https://github.com/pytorch-labs/float8_experimental.git`
+# You can obtain it from https://github.com/pytorch/ao by following the
+# installation instructions.

 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance

@@ -48,7 +48,7 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental import (
+        from torchao.float8 import (
             CastConfig,
             convert_to_float8_training,
             Float8LinearConfig,

@@ -83,7 +83,7 @@ def maybe_build_fp8_linear(
         )
     except ImportError as exc:
         raise ImportError(
-            "float8_experimental is not installed. Please install it to use fp8 linear layers."
+            "torchao is not installed. Please install it to use fp8 linear layers."
         ) from exc


@@ -102,7 +102,7 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
             "Skipped precomputing fp8 scales because SM90 or later is not available",
         )
         return
-    from float8_experimental import precompute_float8_dynamic_scale_for_fsdp
+    from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

     precompute_float8_dynamic_scale_for_fsdp(model)

@@ -121,7 +121,7 @@ def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobCo
     ):
         return

-    from float8_experimental import sync_float8_amax_and_scale_history
+    from torchao.float8 import sync_float8_amax_and_scale_history

     # TODO(future): see if precalculating the modules to sync over is going to
     # meaningfully help performance
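
The two wrappers touched above (`maybe_precompute_fp8_dynamic_scale_for_fsdp` and `maybe_sync_float8_amax_and_scale_history`) gate calls into torchao functions that belong at specific points in the training loop. A hedged sketch of where those calls typically sit; `model`, `optimizer`, and `batch` are placeholders, not torchtitan objects:

```python
from torchao.float8 import (
    precompute_float8_dynamic_scale_for_fsdp,
    sync_float8_amax_and_scale_history,
)


def train_step(model, optimizer, batch):
    loss = model(batch).sum()
    loss.backward()
    # With delayed scaling, amax/scale history is synced after backward and
    # before the optimizer step.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
    optimizer.zero_grad()
    # With dynamic scaling under FSDP, per-parameter scales can be precomputed
    # after the step so their communication is batched.
    precompute_float8_dynamic_scale_for_fsdp(model)
```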

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 3 additions & 1 deletion

@@ -129,7 +129,9 @@ def get_tp_parallel_strategy_for_transformer_block(
         # TODO(future PR): once float8 configuration supports delayed
         # scaling, add a check here to enforce supported float8 all-gather
         # configurations
-        from float8_experimental.float8_tensor_parallel import (
+        # TODO(future PR): add the items below to __init__.py of torchao.float8,
+        # and import from there
+        from torchao.float8.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
             PrepareFloat8ModuleInput,
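
These float8-aware styles stand in for the plain `ColwiseParallel`/`RowwiseParallel` (and `PrepareModuleInput`) entries of a tensor-parallel plan when the linears have been swapped for `Float8Linear`. A hedged sketch of how they might be applied; the submodule names are illustrative, not torchtitan's exact plan:

```python
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


def apply_float8_tp(transformer_block, tp_mesh):
    # Hypothetical plan: column-parallel for an attention input projection,
    # row-parallel for the output projection.
    layer_plan = {
        "attention.wq": Float8ColwiseParallel(),
        "attention.wo": Float8RowwiseParallel(),
    }
    return parallelize_module(transformer_block, tp_mesh, layer_plan)
```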
