Commit be2cf1f

extend inference roofline with real benchmarks
Summary: as titled, this hooks up the real linear benchmarks so we can compare them against the roofline. There is a lot of low-hanging fruit for mxfp8 and mxfp4, as we haven't really optimized them yet; nvfp4 is looking better.

Test Plan:
```
https://www.internalfb.com/phabricator/paste/view/P1995615109
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 9e2e905
ghstack-comment-id: 3412858656
Pull-Request: #3194
1 parent: 2eceb9c
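Side note (not part of the commit): the roofline numbers the benchmark compares against follow the usual GEMM roofline model, where a kernel can be no faster than the slower of its compute-bound and memory-bound limits. A rough sketch of that model, with the peak-throughput and bandwidth inputs left as placeholders:

```python
def gemm_roofline_time_s(M, K, N, bytes_per_elem, peak_flops_per_s, peak_mem_bw_bytes_per_s):
    # 2*M*K*N multiply-adds for an (M x K) @ (K x N) matmul
    compute_bound_s = (2 * M * K * N) / peak_flops_per_s
    # read A and B, write the output, at the given element width
    memory_bound_s = bytes_per_elem * (M * K + K * N + M * N) / peak_mem_bw_bytes_per_s
    # roofline: bounded by whichever limit is slower
    return max(compute_bound_s, memory_bound_s)
```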

1 file changed: benchmarks/float8/float8_inference_roofline.py (+81 / -25 lines)
```diff
@@ -38,6 +38,14 @@
 )
 
 import torchao
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
```
```diff
@@ -80,40 +88,67 @@ def get_gemm_times(
     fast_accum: bool,
     recipe_name: Optional[str],
 ):
-    assert recipe_name in {"rowwise"}, (
-        "Only support real benchmarks for 'rowwise' recipe for now"
-    )
     device = torch.device("cuda")
 
     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
     w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
 
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
-    e4m3_dtype = torch.float8_e4m3fn
-    if torch.version.hip and torch.cuda.is_available() and is_MI300():
-        e4m3_dtype = torch.float8_e4m3fnuz
-    d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
-    A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
-    B = (
-        torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
-        .view(d2)
-        .t()
-        .contiguous()
-        .t()
-    )
+    if recipe_name in ("mxfp4_cutlass", "nvfp4"):
+        d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
+        A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
+            d1
+        )
+        B = (
+            torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8)
+            .t()
+            .contiguous()
+            .t()
+            .view(d2)
+        )
+    else:
+        e4m3_dtype = torch.float8_e4m3fn
+        if torch.version.hip and torch.cuda.is_available() and is_MI300():
+            e4m3_dtype = torch.float8_e4m3fnuz
+        d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
+        A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
+        B = (
+            torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
+            .view(d2)
+            .t()
+            .contiguous()
+            .t()
+        )
+
     if recipe_name == "rowwise":
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
+    elif recipe_name == "mxfp8_cublas":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "mxfp4_cutlass":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "nvfp4":
+        scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
+
     else:
         assert False, "unsupported"
 
     def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+        if recipe_name == "mxfp4_cutlass":
+            return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b)
+        if recipe_name == "nvfp4":
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+            )
+        else:
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
```
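Side note (not part of the commit): the shape bookkeeping in `get_gemm_times` above follows the packed-fp4 and block-scale layouts, i.e. two e2m1 values per byte and one scale per 32 elements of K for the mx recipes, or per 16 elements for nvfp4. A minimal standalone sketch of those shapes, assuming a CUDA device and a recent PyTorch build that exposes these dtypes:

```python
import torch

M, K = 1024, 4096
device = "cuda"

# fp4: two e2m1 values packed per uint8 byte, so the packed K dimension is K // 2
A_fp4 = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
    torch.float4_e2m1fn_x2
)

# mx recipes (mxfp8_cublas, mxfp4_cutlass): one e8m0 scale per 32-element block of K
scale_a_mx = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)

# nvfp4: one fp8 (e4m3) scale per 16-element block of K
scale_a_nv = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
```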
```diff
@@ -259,12 +294,33 @@ def run(
     # get the float8 dynamic scaling gpu kernel time
     torch._dynamo.reset()
 
-    config = Float8DynamicActivationFloat8WeightConfig(
-        granularity=PerRow(),
-        # for now, use TORCH. In the future might be interesting
-        # to benchmark AUTO and FBGEMM.
-        kernel_preference=KernelPreference.TORCH,
-    )
+    if recipe_name == "rowwise":
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+            # for now, use TORCH. In the future might be interesting
+            # to benchmark AUTO and FBGEMM.
+            kernel_preference=KernelPreference.TORCH,
+        )
+    elif recipe_name == "mxfp8_cublas":
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+            gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+        )
+    elif recipe_name == "mxfp4_cutlass":
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float4_e2m1fn_x2,
+            weight_dtype=torch.float4_e2m1fn_x2,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+        )
+    elif recipe_name == "nvfp4":
+        config = NVFP4InferenceConfig(
+            mm_config=NVFP4MMConfig.DYNAMIC,
+            use_dynamic_per_tensor_scale=False,
+        )
+    else:
+        assert False, "unsupported"
+
     m_fp8_dyn = copy.deepcopy(m_orig)
     quantize_(m_fp8_dyn, config)
 
```
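For reference, a minimal usage sketch (not part of the commit; the toy model shape and the CUDA/bfloat16 setup are assumptions) showing how one of the new recipe configs is applied outside the benchmark, mirroring the `quantize_(m_fp8_dyn, config)` call above:

```python
import copy

import torch
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization.quant_api import quantize_

# toy single-linear model standing in for m_orig in the benchmark
m_orig = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
m_orig = m_orig.to(device="cuda", dtype=torch.bfloat16)

# the mxfp8 recipe from the diff, dispatching the quantized GEMM to cuBLAS
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)

m_mxfp8 = copy.deepcopy(m_orig)
quantize_(m_mxfp8, config)  # replaces the Linear weight with an mxfp8 quantized tensor

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
y = m_mxfp8(x)  # forward pass now runs the mxfp8 scaled GEMM
```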