|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +import argparse |
| 7 | +import copy |
| 8 | +from dataclasses import dataclass |
| 9 | +from itertools import product |
| 10 | +from pathlib import Path |
| 11 | +from typing import Callable, List, Optional, Tuple |
| 12 | + |
| 13 | +import pandas as pd |
| 14 | + |
| 15 | +import torch |
| 16 | +import torch.utils.benchmark as benchmark |
| 17 | +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType |
| 18 | +from torchao.float8.float8_linear import Float8Linear |
| 19 | +from torchao.float8.float8_linear_utils import ( |
| 20 | + linear_requires_sync, |
| 21 | + sync_float8_amax_and_scale_history, |
| 22 | +) |
| 23 | +from torchao.float8.float8_tensor import ScaledMMConfig |
| 24 | +from tqdm import tqdm |
| 25 | + |
| 26 | +# estimating TOPs for matmuls in fp32, fp16, fp8 |
| 27 | +# assuming A * B = C, with A being M * K, B being K * N, C being M * N |
| 28 | + |
| 29 | +# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ |
| 30 | +h100_peak_flops_float32 = 67e12 |
| 31 | +h100_peak_flops_fp16_tc = 1979e12 |
| 32 | +h100_peak_tops_float8_tc = 3958e12 |
| 33 | + |
| 34 | +dtype_to_peak_tops = { |
| 35 | + torch.float32: h100_peak_flops_float32, |
| 36 | + torch.float16: h100_peak_flops_fp16_tc, |
| 37 | + torch.bfloat16: h100_peak_flops_fp16_tc, |
| 38 | + torch.float8_e4m3fn: h100_peak_tops_float8_tc, |
| 39 | + torch.float8_e5m2: h100_peak_tops_float8_tc, |
| 40 | +} |
| 41 | + |
| 42 | +# prevent splitting columns when printing a data frame |
| 43 | +pd.set_option("display.expand_frame_repr", False) |
| 44 | +# print the entire data frame |
| 45 | +pd_print_full_ctx = pd.option_context( |
| 46 | + "display.max_rows", None, "display.max_columns", None |
| 47 | +) |
| 48 | + |
| 49 | + |
| 50 | +def benchmark_torch_function_in_microseconds( |
| 51 | + func: Callable, |
| 52 | + *args, |
| 53 | + **kwargs, |
| 54 | +) -> float: |
| 55 | + t0 = benchmark.Timer( |
| 56 | + stmt="func(*args, **kwargs)", |
| 57 | + globals={"args": args, "kwargs": kwargs, "func": func}, |
| 58 | + ) |
| 59 | + return t0.blocked_autorange().median * 1e6 |
| 60 | + |
| 61 | + |
| 62 | +@dataclass |
| 63 | +class Experiment: |
| 64 | + name: str |
| 65 | + shape: Tuple[int, int, int] |
| 66 | + ref_time_sec: float |
| 67 | + float8_time_sec: float |
| 68 | + dtype: torch.dtype |
| 69 | + compiled: bool |
| 70 | + use_fast_accum: bool |
| 71 | + scaling_repr: str |
| 72 | + |
| 73 | + # 3 Times since we are calculating forward backward |
| 74 | + @property |
| 75 | + def ref_tops_sec(self): |
| 76 | + M, K, N = self.shape |
| 77 | + return float(3 * (2 * M * K * N)) / self.ref_time_sec |
| 78 | + |
| 79 | + @property |
| 80 | + def ref_pct_top_peak(self): |
| 81 | + return self.ref_tops_sec / dtype_to_peak_tops[self.dtype] |
| 82 | + |
| 83 | + @property |
| 84 | + def float8_tops_sec(self): |
| 85 | + M, K, N = self.shape |
| 86 | + return float(3 * (2 * M * K * N)) / self.float8_time_sec |
| 87 | + |
| 88 | + @property |
| 89 | + def float8_pct_top_peak(self): |
| 90 | + return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn] |
| 91 | + |
| 92 | + |
| 93 | +def main( |
| 94 | + sweep_path: Optional[Path] = None, |
| 95 | + compile: bool = True, |
| 96 | + n_limit: Optional[int] = None, |
| 97 | + fast_accum_filter: Optional[bool] = None, |
| 98 | + shape_name_filter: Optional[str] = None, |
| 99 | + scaling_type_input: str = "dynamic", |
| 100 | + scaling_type_weight: str = "dynamic", |
| 101 | + scaling_type_grad_output: str = "dynamic", |
| 102 | +): |
| 103 | + device = "cuda" |
| 104 | + print(f"Compile is set to | {compile}") |
| 105 | + |
| 106 | + scaling_type_input = ScalingType(scaling_type_input) |
| 107 | + scaling_type_weight = ScalingType(scaling_type_weight) |
| 108 | + scaling_type_grad_output = ScalingType(scaling_type_grad_output) |
| 109 | + config = Float8LinearConfig( |
| 110 | + cast_config_input=CastConfig(scaling_type=scaling_type_input), |
| 111 | + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), |
| 112 | + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), |
| 113 | + ) |
| 114 | + |
| 115 | + # LLaMa 2 70B single-node weight shapes |
| 116 | + # assumes fused attn.wqkv and ffn.w13 |
| 117 | + name_to_shapes_70b = { |
| 118 | + "attn.wqkv": (8192, 1280), |
| 119 | + "attn.w0": (1024, 8192), |
| 120 | + "ffn.w13": (8192, 7168), |
| 121 | + "ffn.w2": (3584, 8192), |
| 122 | + } |
| 123 | + input_bias = False |
| 124 | + if fast_accum_filter is not None: |
| 125 | + use_fast_accum = [fast_accum_filter] |
| 126 | + else: |
| 127 | + use_fast_accum = [True, False] |
| 128 | + if shape_name_filter is not None: |
| 129 | + k = shape_name_filter |
| 130 | + name_to_shapes_70b = {k: name_to_shapes_70b[k]} |
| 131 | + experiment_list: List[Experiment] = [] |
| 132 | + dtype = torch.bfloat16 |
| 133 | + for idx, (fast_accum, (name, (K, N))) in enumerate( |
| 134 | + tqdm(list(product(use_fast_accum, name_to_shapes_70b.items()))) |
| 135 | + ): |
| 136 | + if n_limit is not None and idx >= n_limit: |
| 137 | + break |
| 138 | + linear_ref = torch.nn.Linear(K, N, bias=input_bias).to( |
| 139 | + device=device, dtype=dtype |
| 140 | + ) |
| 141 | + |
| 142 | + linear_float8 = Float8Linear.from_float( |
| 143 | + copy.deepcopy(linear_ref), |
| 144 | + config=config, |
| 145 | + ) |
| 146 | + scaling_repr = linear_float8.scaling_repr() |
| 147 | + |
| 148 | + if fast_accum: |
| 149 | + linear_float8.forward_config = ScaledMMConfig(False, True, False) |
| 150 | + else: |
| 151 | + linear_float8.forward_config = ScaledMMConfig(False, False, False) |
| 152 | + |
| 153 | + bsz, seq_len = 4, 4096 |
| 154 | + M = bsz * seq_len |
| 155 | + input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True) |
| 156 | + ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() |
| 157 | + |
| 158 | + def float8_forw_backward(): |
| 159 | + if linear_requires_sync(config): |
| 160 | + sync_float8_amax_and_scale_history(linear_float8) |
| 161 | + linear_float8(input_tensor).sum().backward() |
| 162 | + |
| 163 | + def n_times(n, fn, *args, **kwargs): |
| 164 | + def wrapper(*args, **kwargs): |
| 165 | + for _ in range(n): |
| 166 | + fn(*args, **kwargs) |
| 167 | + |
| 168 | + return wrapper |
| 169 | + |
| 170 | + REPEAT_N = 100 |
| 171 | + |
| 172 | + ref_forw_backward = n_times(REPEAT_N, ref_forw_backward) |
| 173 | + float8_forw_backward = n_times(REPEAT_N, float8_forw_backward) |
| 174 | + |
| 175 | + if compile: |
| 176 | + ref_forw_backward = torch.compile(ref_forw_backward) |
| 177 | + float8_forw_backward = torch.compile(float8_forw_backward) |
| 178 | + |
| 179 | + for _ in range(5): |
| 180 | + ref_forw_backward() |
| 181 | + float8_forw_backward() |
| 182 | + |
| 183 | + ref_time = ( |
| 184 | + benchmark_torch_function_in_microseconds(ref_forw_backward) |
| 185 | + * 1e-6 |
| 186 | + / REPEAT_N |
| 187 | + ) |
| 188 | + float8_time = ( |
| 189 | + benchmark_torch_function_in_microseconds(float8_forw_backward) |
| 190 | + * 1e-6 |
| 191 | + / REPEAT_N |
| 192 | + ) |
| 193 | + experiment = Experiment( |
| 194 | + name, |
| 195 | + (M, K, N), |
| 196 | + ref_time, |
| 197 | + float8_time, |
| 198 | + dtype, |
| 199 | + compile, |
| 200 | + use_fast_accum=fast_accum, |
| 201 | + scaling_repr=scaling_repr, |
| 202 | + ) |
| 203 | + print(experiment) |
| 204 | + print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) |
| 205 | + experiment_list.append(experiment) |
| 206 | + torch._dynamo.reset() |
| 207 | + |
| 208 | + headers = [ |
| 209 | + "name", |
| 210 | + "M", |
| 211 | + "K", |
| 212 | + "N", |
| 213 | + "scaling_repr", |
| 214 | + "ref_dtype", |
| 215 | + "compiled", |
| 216 | + "use_fast_accum", |
| 217 | + "ref_time_sec", |
| 218 | + "pt_fp8_time_sec", |
| 219 | + "ref_tops_sec", |
| 220 | + "ref_pct_top_peak", |
| 221 | + "pt_fp8_tops_sec", |
| 222 | + "pt_fp8_pct_top_peak", |
| 223 | + ] |
| 224 | + data = [] |
| 225 | + for experiment in experiment_list: |
| 226 | + data.append( |
| 227 | + [ |
| 228 | + experiment.name, |
| 229 | + experiment.shape[0], |
| 230 | + experiment.shape[1], |
| 231 | + experiment.shape[2], |
| 232 | + experiment.scaling_repr, |
| 233 | + experiment.dtype, |
| 234 | + experiment.compiled, |
| 235 | + experiment.use_fast_accum, |
| 236 | + experiment.ref_time_sec, |
| 237 | + experiment.float8_time_sec, |
| 238 | + experiment.ref_tops_sec, |
| 239 | + experiment.ref_pct_top_peak, |
| 240 | + experiment.float8_tops_sec, |
| 241 | + experiment.float8_pct_top_peak, |
| 242 | + ] |
| 243 | + ) |
| 244 | + |
| 245 | + data_pd = pd.DataFrame(data, columns=headers) |
| 246 | + data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"] |
| 247 | + data_pd["shape"] = ( |
| 248 | + "(" |
| 249 | + + data_pd["M"].astype(str) |
| 250 | + + ", " |
| 251 | + + data_pd["K"].astype(str) |
| 252 | + + ", " |
| 253 | + + data_pd["N"].astype(str) |
| 254 | + + ")" |
| 255 | + ) |
| 256 | + |
| 257 | + data_pd_simple = data_pd[ |
| 258 | + [ |
| 259 | + "name", |
| 260 | + "shape", |
| 261 | + "scaling_repr", |
| 262 | + "compiled", |
| 263 | + "use_fast_accum", |
| 264 | + "ref_time_sec", |
| 265 | + "pt_fp8_time_sec", |
| 266 | + "pt_fp8_speedup", |
| 267 | + ] |
| 268 | + ] |
| 269 | + with pd_print_full_ctx: |
| 270 | + print(data_pd_simple) |
| 271 | + |
| 272 | + if sweep_path is not None: |
| 273 | + sweep_path = sweep_path.with_suffix(".csv") |
| 274 | + data_pd.to_csv(sweep_path) |
| 275 | + |
| 276 | + |
| 277 | +def invoke_main() -> None: |
| 278 | + parser = argparse.ArgumentParser() |
| 279 | + parser.add_argument("-o", "--output_path", type=str, required=False) |
| 280 | + parser.add_argument("--disable_compile", action="store_true") |
| 281 | + parser.add_argument("-n", "--n_limit", type=int, required=False) |
| 282 | + parser.add_argument("--fast_accum_filter", type=bool, required=False) |
| 283 | + parser.add_argument("--shape_name_filter", type=str, required=False) |
| 284 | + parser.add_argument("--scaling_type_input", type=str, required=False) |
| 285 | + parser.add_argument("--scaling_type_weight", type=str, required=False) |
| 286 | + parser.add_argument("--scaling_type_grad_output", type=str, required=False) |
| 287 | + args = parser.parse_args() |
| 288 | + output_path = Path(args.output_path) if args.output_path is not None else None |
| 289 | + kwargs = {} |
| 290 | + if args.scaling_type_input is not None: |
| 291 | + kwargs["scaling_type_input"] = args.scaling_type_input |
| 292 | + if args.scaling_type_weight is not None: |
| 293 | + kwargs["scaling_type_weight"] = args.scaling_type_weight |
| 294 | + if args.scaling_type_grad_output is not None: |
| 295 | + kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output |
| 296 | + main( |
| 297 | + output_path, |
| 298 | + not args.disable_compile, |
| 299 | + args.n_limit, |
| 300 | + args.fast_accum_filter, |
| 301 | + args.shape_name_filter, |
| 302 | + **kwargs, |
| 303 | + ) |
| 304 | + |
| 305 | + |
| 306 | +if __name__ == "__main__": |
| 307 | + invoke_main() # pragma: no cover |
0 commit comments