 import fire

 import torch
-import torch.utils.benchmark as benchmark
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
+from torch._inductor.utils import do_bench_using_profiling
+from tqdm import tqdm

 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
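To make the estimate concrete: the script counts 2 * M * N * K ops per matmul (each of the M * N outputs needs K multiplies and K adds). A rough worked example; the shape and runtime below are made up for illustration, not taken from the benchmark:

    M, K, N = 8192, 4096, 4096           # hypothetical GEMM shape
    ops = 2 * M * K * N                  # ~2.75e11 floating-point ops
    time_s = 300e-6                      # hypothetical measured runtime (300 us)
    achieved_tops = ops / time_s / 1e12  # ~916 TOPS, later compared against peak_tops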
@@ -26,14 +28,9 @@


 def benchmark_fn_in_usec(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(4):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-    measurement = t0.blocked_autorange()
-    return measurement.mean * 1e6
+    no_args = lambda: f(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3


 def get_tops_info(tops, time, peak_tops):
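The timing helper now routes through do_bench_using_profiling, which (judging by the conversion factor) reports milliseconds, so the * 1e3 keeps the return value in microseconds just as the old mean * 1e6 did for seconds. A minimal usage sketch; the tensors and the torch.mm call are placeholders, not part of the script:

    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    usec = benchmark_fn_in_usec(torch.mm, a, b)  # profiles f(*args, **kwargs), returns microseconds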
@@ -44,23 +41,27 @@ def get_tops_info(tops, time, peak_tops):


 def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
-    A_fp8 = A.to(fp8_dtype)
-    B_fp8 = B.to(fp8_dtype).t()  # view
-
     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

-    A_pad = pad_tensor_for_matmul(A_fp8, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t()  # mem copy
+    a_config = ScaledMMConfig(
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+    )
+    b_config = ScaledMMConfig(
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+    )

-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
+    b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)
+
+    return a_fp8 @ b_fp8


 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
+    # Breaks with compile due to trying to pad on fp8 dtype
+    # return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
     A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B, dims=[0, 1])  # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy

     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
@@ -70,9 +71,9 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):

     B_pad = B_pad.t().contiguous().t()  # mem copy

-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(
+        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+    )


 def do_hp_matmul(A, B):
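A note on the padding path above: torch._scaled_mm is picky about operand dimensions being multiples of 16, which is the constraint pad_tensor_for_matmul works around. With B now padded only along dim 0, the product already has shape (M, N), so the old [: A.shape[0], : B.shape[1]] slice is no longer needed. A small shape-bookkeeping sketch using the first shape from gen_configs below; the round-up-to-16 rule is my reading of the padding helper, not something stated in this diff:

    M, K, N = 8193, 2501, 5008
    pad16 = lambda d: (d + 15) // 16 * 16  # assumed: round up to the next multiple of 16
    a_pad_shape = (M, pad16(K))            # (8193, 2512) -- dims=1 pads only K
    b_pad_shape = (pad16(K), N)            # (2512, 5008) -- dims=0 pads only K; N is already 16-aligned
    out_shape = (M, N)                     # unchanged, so the result needs no slicing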
@@ -92,7 +93,18 @@ def __iter__(self):


 def gen_configs():
-    shapes = [(8192, 2500, 5000), (64, 255, 4096)]
+    shapes = [
+        (8193, 2501, 5008),
+        (65, 253, 4096),
+        (1023, 1029, 2512),
+        (4095, 511, 10000),
+        (2047, 3073, 8192),
+        (511, 769, 7504),
+        (127, 4097, 12288),
+        (32769, 15, 15024),
+        (9217, 8191, 20480),
+        (16385, 1025, 25008),
+    ]
     output_dtype = torch.bfloat16
     fp8_dtype = torch.float8_e4m3fn
     return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
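The new shape list appears chosen so that M and K are never multiples of 16 while every N already is; since only the inner dimension gets padded here, each experiment exercises the pad path without changing the output shape. Assuming the round-up-to-16 behaviour noted earlier, the inner dimensions pad as, for example:

    2501 -> 2512, 253 -> 256, 1029 -> 1040, 15 -> 16  # next multiples of 16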
@@ -112,7 +124,8 @@ def run(compile: bool = False, n_limit: Optional[int] = None):
             "Ref % Peak",
             "FP8 % Peak",
         ]
-        for experiment in experiments:
+
+        for experiment in tqdm(experiments):
             M, K, N, output_dtype, fp8_dtype = experiment
             tops = 2 * M * N * K

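With ten shapes instead of two, the loop now carries a tqdm progress bar. The script imports fire at the top, so run() is presumably exposed on the command line via something like the following (the wiring itself is not shown in this diff, and the flag behaviour is inferred from the run() signature in the hunk header):

    if __name__ == "__main__":
        fire.Fire(run)  # assumed entry point; would expose --compile and --n_limit as CLI flags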