This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 7accfa5

bigger sweep

1 parent 017e858

File tree

2 files changed: +37 -24 lines

  benchmarks/bench_padding.py
  test/test_base.py

benchmarks/bench_padding.py

Lines changed: 36 additions & 23 deletions

```diff
@@ -4,9 +4,11 @@
 import fire
 
 import torch
-import torch.utils.benchmark as benchmark
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
+from torch._inductor.utils import do_bench_using_profiling
+from tqdm import tqdm
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
@@ -26,14 +28,9 @@
 
 
 def benchmark_fn_in_usec(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(4):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-    measurement = t0.blocked_autorange()
-    return measurement.mean * 1e6
+    no_args = lambda: f(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3
 
 
 def get_tops_info(tops, time, peak_tops):
```
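
The timing change above swaps torch.utils.benchmark for the inductor profiling helper. A minimal usage sketch, assuming do_bench_using_profiling reports a mean time in milliseconds (which is what the * 1e3 conversion to microseconds implies); the tensor shapes are illustrative only:

```python
# Hypothetical usage of the reworked helper; shapes are illustrative, not from the benchmark.
import torch
from torch._inductor.utils import do_bench_using_profiling

def benchmark_fn_in_usec(f, *args, **kwargs):
    no_args = lambda: f(*args, **kwargs)
    time = do_bench_using_profiling(no_args)  # assumed: mean time in milliseconds
    return time * 1e3  # milliseconds -> microseconds

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
print(f"bf16 matmul: {benchmark_fn_in_usec(torch.mm, a, b):.1f} us")
```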

```diff
@@ -44,23 +41,27 @@ def get_tops_info(tops, time, peak_tops):
 
 
 def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
-    A_fp8 = A.to(fp8_dtype)
-    B_fp8 = B.to(fp8_dtype).t()  # view
-
     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
 
-    A_pad = pad_tensor_for_matmul(A_fp8, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t()  # mem copy
+    a_config = ScaledMMConfig(
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+    )
+    b_config = ScaledMMConfig(
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+    )
 
-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
+    b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)
+
+    return a_fp8 @ b_fp8
 
 
 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
+    # Breaks with compile due to trying to pad on fp8 dtype
+    # return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
     A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B, dims=[0, 1])  # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy
 
     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
```
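
The hunk above replaces manual fp8 casting and padding with the Float8Tensor subclass, whose ScaledMMConfig carries the matmul options. A minimal sketch of that path on one of the awkward shapes from the sweep; the flag meanings in the comments are inferred from the names, not taken from documentation:

```python
# Hedged sketch of the Float8Tensor path; the flag interpretations are assumptions.
import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig

config = ScaledMMConfig(
    emulate=False,        # presumably: run real fp8 kernels, not emulation
    use_fast_accum=True,  # presumably: fast accumulation in torch._scaled_mm
    fp8_output=True,      # presumably: keep the matmul output in fp8
    pad_inner_dim=True,   # presumably: pad K internally for _scaled_mm alignment
)

A = torch.randn(65, 253, device="cuda")  # K = 253 is not a multiple of 16
B = torch.randn(253, 4096, device="cuda")
scale = torch.tensor([1], device="cuda", dtype=torch.float32)

a_fp8 = Float8Tensor.to_float8(A, scale, torch.float8_e4m3fn, mm_config=config)
b_fp8 = Float8Tensor.to_float8(B, scale, torch.float8_e4m3fn, mm_config=config)
out = a_fp8 @ b_fp8  # padding and the scaled mm are handled inside the subclass
```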

```diff
@@ -70,9 +71,9 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
 
     B_pad = B_pad.t().contiguous().t()  # mem copy
 
-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(
+        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+    )
 
 
 def do_hp_matmul(A, B):
```
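
This hunk also shows why the old [: A.shape[0], : B.shape[1]] slice could be dropped: only the shared inner dimension K is now padded on both operands (the old code padded B along N as well), so the result already has shape (M, N). A small arithmetic sketch, assuming pad_tensor_for_matmul rounds the selected dims up to the next multiple of 16:

```python
# Sketch of the padding arithmetic; the multiple-of-16 target is an assumption.
def pad_to_multiple_of_16(x: int) -> int:
    return (x + 15) // 16 * 16

M, K, N = 8193, 2501, 5008                     # first shape in the new sweep
A_pad_shape = (M, pad_to_multiple_of_16(K))    # (8193, 2512): dims=1 pads K only
B_pad_shape = (pad_to_multiple_of_16(K), N)    # (2512, 5008): dims=0 pads K only
out_shape = (A_pad_shape[0], B_pad_shape[1])   # (8193, 5008) == (M, N), no slice needed
```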

```diff
@@ -92,7 +93,18 @@ def __iter__(self):
 
 
 def gen_configs():
-    shapes = [(8192, 2500, 5000), (64, 255, 4096)]
+    shapes = shapes = [
+        (8193, 2501, 5008),
+        (65, 253, 4096),
+        (1023, 1029, 2512),
+        (4095, 511, 10000),
+        (2047, 3073, 8192),
+        (511, 769, 7504),
+        (127, 4097, 12288),
+        (32769, 15, 15024),
+        (9217, 8191, 20480),
+        (16385, 1025, 25008),
+    ]
     output_dtype = torch.bfloat16
     fp8_dtype = torch.float8_e4m3fn
     return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
```
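
One pattern in the enlarged sweep worth calling out: every M and K sits off a 16-element boundary while every N sits on one, so each experiment forces the inner-dim padding path. A quick check, assuming 16 is the alignment torch._scaled_mm cares about:

```python
# Quick sanity check over the new sweep; the 16-element alignment is an assumption.
shapes = [
    (8193, 2501, 5008), (65, 253, 4096), (1023, 1029, 2512), (4095, 511, 10000),
    (2047, 3073, 8192), (511, 769, 7504), (127, 4097, 12288), (32769, 15, 15024),
    (9217, 8191, 20480), (16385, 1025, 25008),
]
assert all(M % 16 != 0 and K % 16 != 0 and N % 16 == 0 for M, K, N in shapes)
```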

```diff
@@ -112,7 +124,8 @@ def run(compile: bool = False, n_limit: Optional[int] = None):
         "Ref % Peak",
         "FP8 % Peak",
     ]
-    for experiment in experiments:
+
+    for experiment in tqdm(experiments):
         M, K, N, output_dtype, fp8_dtype = experiment
         tops = 2 * M * N * K
 
```
test/test_base.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -314,7 +314,7 @@ class TestScaledMM:
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_fast_accum", [True, False])
-    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum, padded):
+    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
         input_dtype = e4m3_dtype
         output_dtype = base_dtype
```