
Commit a89914c

move float8_experimental to torchao/float8
Summary:

This PR moves https://github.com/pytorch-labs/float8_experimental to torchao/float8. There are no logic changes here.

Here is how to reproduce this PR:
* copy float8_experimental/float8_experimental/* to torchao/float8
* copy float8_experimental/test/* to test/float8
* copy float8_experimental/benchmarks/* to benchmarks/float8
* copy the README over and delete sections which no longer apply (license, installation)
* replace `float8_experimental` with `torchao.float8` everywhere

Test Plan:

```
// run local tests, they pass
./test/float8/test_everything.sh

// run every benchmark in `benchmarks/float8`, they still work
```

Reviewers:

Subscribers:

Tasks:

Tags:
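As a hedged illustration of the last bullet (not part of this commit's diff): downstream code only needs its import paths updated, assuming the old package mirrored the module layout now under torchao/float8, which follows from the copy steps above.

```python
# Hypothetical before/after for downstream code; the "before" path assumes the
# old float8_experimental package exposed the same module names.

# before: from float8_experimental.float8_linear import Float8Linear
# after:
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
```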
1 parent 8fa11a6 commit a89914c

35 files changed: 7628 additions & 0 deletions

README.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -85,6 +85,12 @@ In some cases we rewrote popular GenAI models to be significantly faster in nati
 
 ### Training
 
+#### Float8
+
+[torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
+
+#### Sparsity
+
 We've added support for semi-structured 2:4 sparsity with 6% end to end speedups on ViT-L
 
 The code change is a 1 liner with the full example available [here](torchao/sparsity/training/)
```
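The new README entry points readers at `torchao/float8`. As a minimal sketch of what that looks like in code (not from this commit; it reuses only APIs exercised by the benchmark file added below, and assumes a CUDA GPU with float8 support, e.g. an H100), a single nn.Linear can be converted like this:

```python
# Minimal sketch: convert one nn.Linear to float8 training using the new
# torchao.float8 import paths. Shapes match the benchmark's "attn.wqkv" entry.
import copy

import torch
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear import Float8Linear

config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType("dynamic")),
    cast_config_weight=CastConfig(scaling_type=ScalingType("dynamic")),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType("dynamic")),
)

linear_ref = torch.nn.Linear(8192, 1280, bias=False).to(device="cuda", dtype=torch.bfloat16)
linear_float8 = Float8Linear.from_float(copy.deepcopy(linear_ref), config=config)

# forward + backward run through the float8 path; with all-dynamic scaling the
# benchmark below skips the amax/scale-history sync it guards behind
# linear_requires_sync(config)
x = torch.randn(16, 8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)
linear_float8(x).sum().backward()
```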
Lines changed: 307 additions & 0 deletions
@@ -0,0 +1,307 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import copy
from dataclasses import dataclass
from itertools import product
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import pandas as pd

import torch
import torch.utils.benchmark as benchmark
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_tensor import ScaledMMConfig
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}

# prevent splitting columns when printing a data frame
pd.set_option("display.expand_frame_repr", False)
# print the entire data frame
pd_print_full_ctx = pd.option_context(
    "display.max_rows", None, "display.max_columns", None
)


def benchmark_torch_function_in_microseconds(
    func: Callable,
    *args,
    **kwargs,
) -> float:
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.blocked_autorange().median * 1e6


@dataclass
class Experiment:
    name: str
    shape: Tuple[int, int, int]
    ref_time_sec: float
    float8_time_sec: float
    dtype: torch.dtype
    compiled: bool
    use_fast_accum: bool
    scaling_repr: str

    # 3 Times since we are calculating forward backward
    @property
    def ref_tops_sec(self):
        M, K, N = self.shape
        return float(3 * (2 * M * K * N)) / self.ref_time_sec

    @property
    def ref_pct_top_peak(self):
        return self.ref_tops_sec / dtype_to_peak_tops[self.dtype]

    @property
    def float8_tops_sec(self):
        M, K, N = self.shape
        return float(3 * (2 * M * K * N)) / self.float8_time_sec

    @property
    def float8_pct_top_peak(self):
        return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]


def main(
    sweep_path: Optional[Path] = None,
    compile: bool = True,
    n_limit: Optional[int] = None,
    fast_accum_filter: Optional[bool] = None,
    shape_name_filter: Optional[str] = None,
    scaling_type_input: str = "dynamic",
    scaling_type_weight: str = "dynamic",
    scaling_type_grad_output: str = "dynamic",
):
    device = "cuda"
    print(f"Compile is set to | {compile}")

    scaling_type_input = ScalingType(scaling_type_input)
    scaling_type_weight = ScalingType(scaling_type_weight)
    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
    config = Float8LinearConfig(
        cast_config_input=CastConfig(scaling_type=scaling_type_input),
        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
    )

    # LLaMa 2 70B single-node weight shapes
    # assumes fused attn.wqkv and ffn.w13
    name_to_shapes_70b = {
        "attn.wqkv": (8192, 1280),
        "attn.w0": (1024, 8192),
        "ffn.w13": (8192, 7168),
        "ffn.w2": (3584, 8192),
    }
    input_bias = False
    if fast_accum_filter is not None:
        use_fast_accum = [fast_accum_filter]
    else:
        use_fast_accum = [True, False]
    if shape_name_filter is not None:
        k = shape_name_filter
        name_to_shapes_70b = {k: name_to_shapes_70b[k]}
    experiment_list: List[Experiment] = []
    dtype = torch.bfloat16
    for idx, (fast_accum, (name, (K, N))) in enumerate(
        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
    ):
        if n_limit is not None and idx >= n_limit:
            break
        linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
            device=device, dtype=dtype
        )

        linear_float8 = Float8Linear.from_float(
            copy.deepcopy(linear_ref),
            config=config,
        )
        scaling_repr = linear_float8.scaling_repr()

        if fast_accum:
            linear_float8.forward_config = ScaledMMConfig(False, True, False)
        else:
            linear_float8.forward_config = ScaledMMConfig(False, False, False)

        bsz, seq_len = 4, 4096
        M = bsz * seq_len
        input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
        ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

        def float8_forw_backward():
            if linear_requires_sync(config):
                sync_float8_amax_and_scale_history(linear_float8)
            linear_float8(input_tensor).sum().backward()

        def n_times(n, fn, *args, **kwargs):
            def wrapper(*args, **kwargs):
                for _ in range(n):
                    fn(*args, **kwargs)

            return wrapper

        REPEAT_N = 100

        ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
        float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)

        if compile:
            ref_forw_backward = torch.compile(ref_forw_backward)
            float8_forw_backward = torch.compile(float8_forw_backward)

        for _ in range(5):
            ref_forw_backward()
            float8_forw_backward()

        ref_time = (
            benchmark_torch_function_in_microseconds(ref_forw_backward)
            * 1e-6
            / REPEAT_N
        )
        float8_time = (
            benchmark_torch_function_in_microseconds(float8_forw_backward)
            * 1e-6
            / REPEAT_N
        )
        experiment = Experiment(
            name,
            (M, K, N),
            ref_time,
            float8_time,
            dtype,
            compile,
            use_fast_accum=fast_accum,
            scaling_repr=scaling_repr,
        )
        print(experiment)
        print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
        experiment_list.append(experiment)
        torch._dynamo.reset()

    headers = [
        "name",
        "M",
        "K",
        "N",
        "scaling_repr",
        "ref_dtype",
        "compiled",
        "use_fast_accum",
        "ref_time_sec",
        "pt_fp8_time_sec",
        "ref_tops_sec",
        "ref_pct_top_peak",
        "pt_fp8_tops_sec",
        "pt_fp8_pct_top_peak",
    ]
    data = []
    for experiment in experiment_list:
        data.append(
            [
                experiment.name,
                experiment.shape[0],
                experiment.shape[1],
                experiment.shape[2],
                experiment.scaling_repr,
                experiment.dtype,
                experiment.compiled,
                experiment.use_fast_accum,
                experiment.ref_time_sec,
                experiment.float8_time_sec,
                experiment.ref_tops_sec,
                experiment.ref_pct_top_peak,
                experiment.float8_tops_sec,
                experiment.float8_pct_top_peak,
            ]
        )

    data_pd = pd.DataFrame(data, columns=headers)
    data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
    data_pd["shape"] = (
        "("
        + data_pd["M"].astype(str)
        + ", "
        + data_pd["K"].astype(str)
        + ", "
        + data_pd["N"].astype(str)
        + ")"
    )

    data_pd_simple = data_pd[
        [
            "name",
            "shape",
            "scaling_repr",
            "compiled",
            "use_fast_accum",
            "ref_time_sec",
            "pt_fp8_time_sec",
            "pt_fp8_speedup",
        ]
    ]
    with pd_print_full_ctx:
        print(data_pd_simple)

    if sweep_path is not None:
        sweep_path = sweep_path.with_suffix(".csv")
        data_pd.to_csv(sweep_path)


def invoke_main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output_path", type=str, required=False)
    parser.add_argument("--disable_compile", action="store_true")
    parser.add_argument("-n", "--n_limit", type=int, required=False)
    parser.add_argument("--fast_accum_filter", type=bool, required=False)
    parser.add_argument("--shape_name_filter", type=str, required=False)
    parser.add_argument("--scaling_type_input", type=str, required=False)
    parser.add_argument("--scaling_type_weight", type=str, required=False)
    parser.add_argument("--scaling_type_grad_output", type=str, required=False)
    args = parser.parse_args()
    output_path = Path(args.output_path) if args.output_path is not None else None
    kwargs = {}
    if args.scaling_type_input is not None:
        kwargs["scaling_type_input"] = args.scaling_type_input
    if args.scaling_type_weight is not None:
        kwargs["scaling_type_weight"] = args.scaling_type_weight
    if args.scaling_type_grad_output is not None:
        kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
    main(
        output_path,
        not args.disable_compile,
        args.n_limit,
        args.fast_accum_filter,
        args.shape_name_filter,
        **kwargs,
    )


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
```
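For completeness, a hedged example of driving the benchmark programmatically through the `main()` function defined above (the new file's on-disk name is not shown in this view, so a command-line invocation is omitted); the keyword names follow `main()`'s signature:

```python
from pathlib import Path

# Sweep only the "ffn.w13" shape with torch.compile enabled and write results
# to float8_linear_sweep.csv (main() appends the ".csv" suffix itself).
main(
    sweep_path=Path("float8_linear_sweep"),
    compile=True,
    shape_name_filter="ffn.w13",
)
```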
