Commit 821bd2b

Update

[ghstack-poisoned]

1 parent a9694a5 commit 821bd2b

File tree

2 files changed: +60 -3 lines changed

benchmarks/float8/float8_roofline.py

Lines changed: 20 additions & 2 deletions

@@ -180,7 +180,7 @@ def get_gemm_times(
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
     else:
-        assert False, "TODO add cutlass mx gemm here"
+        assert False, f"unsupported {float8_recipe_name=} {mx_recipe_name=}"
 
     def do_matmul(A, B):
         return torch._scaled_mm(
@@ -233,6 +233,20 @@ def run(
     print(f"mx_recipe_name: {mx_recipe_name}")
     print(f"enable_fusion_modeling: {enable_fusion_modeling}")
 
+    assert mx_recipe_name in (
+        # real mxfp8_cublas recipe
+        "mxfp8_cublas",
+        # real mxfp8_cublas_rceil recipe
+        "mxfp8_cublas_rceil",
+        # modeling of what mxfp8 with 32x32 block size and without gemm
+        # operand layout restrictions would look like
+        "mxfp8_32x32_flexible_gemm_layout",
+        # modeling of what mxfp8 with 32x32 block size for weight
+        "mxfp8_32x32_weight",
+        # real mxfp4_cutlass recipe
+        "mxfp4_cutlass",
+    ), f"unsupported {mx_recipe_name=}"
+
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_float8_mem_sympy(
@@ -309,7 +323,11 @@ def run(
     rb_fp8_gemm_ratio = -1
 
     if do_benchmarks:
-        assert mx_recipe_name != "mxfp4_cutlass", "unsupported"
+        assert mx_recipe_name not in (
+            "mxfp4_cutlass",
+            "mxfp8_32x32_flexible_gemm_layout",
+            "mxfp8_32x32_weight",
+        ), f"do_benchmarks unsupported with {mx_recipe_name=}"
 
     # TODO(future): make the bf16 gemm times exactly match the e2e
     # benchmarks, there is a slight deviation, probably related to gemm
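For context, the do_matmul helper above feeds torch._scaled_mm with one e8m0 scale per 1x32 block along K, matching the scale_a / scale_b lines in get_gemm_times. Below is a minimal sketch of that call pattern, assuming a PyTorch build and GPU with cuBLAS mxfp8 support; the shapes and tensor names are illustrative, not from the commit.

import torch

# Illustrative shapes; K must be a multiple of 32 for 1x32 block scaling.
M, K, N = 1024, 2048, 4096
device = "cuda"

# mxfp8 operands: A row-major, B column-major, which torch._scaled_mm
# typically requires for fp8 inputs.
A = torch.randn(M, K, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
B = torch.randn(N, K, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn).t()

# One e8m0 scale per 1x32 block along K, as in get_gemm_times above;
# all-ones scales (i.e. scale = 1.0) keep the example simple.
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)

out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([1024, 4096])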

torchao/testing/training/roofline_utils.py

Lines changed: 40 additions & 1 deletion

@@ -187,13 +187,52 @@ def get_tensor_memory_traffic_ovhd_s(
         else:
             assert False, "unsupported"
 
+    elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense
+        #    across dim0 and dim1
+        # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in
+        #    PyTorch right now)
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name == "mxfp8_32x32_weight":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 weights, so the format makes sense
+        #    across dim0 and dim1. input and grad_output still 1x32.
+
+        if tensor_role in ("input", "grad_output"):
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_bf16 -> x_mxfp8_dim1
+            if fuse_with_prev:
+                kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+            else:
+                kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        elif tensor_role == "weight":
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
+
+        else:
+            assert False, "unsupported"
+
+        res_bytes = [kernel_1_rw, kernel_2_rw]
+
     else:
         assert mx_recipe_name in (
             "mxfp8_emulated",
             "mxfp8_cublas",
             "mxfp8_cublas_rceil",
             "mxfp4_cutlass",
-        ), "unsupported"
+        ), f"unsupported {mx_recipe_name=}"
         # For now, assume that we can't profitably fuse kernel 1 and kernel 2
         # x_bf16 = ...
         # kernel 1: x_bf16 -> x_mxfp8_dim0
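To make the modeled traffic concrete, here is a small hand check of the "weight" role in the mxfp8_32x32_weight branch above. This is a sketch: the 4096x4096 shape and the bandwidth figure are illustrative assumptions, and the byte-per-element values are the presumed values of the BYTES_PER_EL_BF16 / BYTES_PER_EL_FLOAT8 constants in roofline_utils.

# Hand check of the "weight" role in the mxfp8_32x32_weight branch above.
BYTES_PER_EL_BF16 = 2  # presumed values of the roofline_utils constants
BYTES_PER_EL_FLOAT8 = 1

numel = 4096 * 4096  # hypothetical weight shape, not from the commit

# kernel 1 reads the bf16 weight and writes the dim0 mxfp8 copy
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel  # 3 B/elem
# kernel 2 reads the dim0 mxfp8 copy and writes the dim1 mxfp8 copy
kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2                          # 2 B/elem

total_bytes = kernel_1_rw + kernel_2_rw  # 5 B/elem

# converting bytes to seconds, as the _ovhd_s suffix suggests; the peak
# bandwidth figure is illustrative only
peak_bw_bytes_per_s = 3.35e12
print(total_bytes / peak_bw_bytes_per_s)  # ~2.5e-5 s

The point of the separate weight branch is visible in the arithmetic: because kernel 2 re-quantizes from the already-quantized dim0 copy, it reads 1 byte per element instead of 2, so the weight costs 5 B/elem of modeled traffic versus 6 B/elem for input/grad_output (without fusion), where both kernels must read the bf16 original.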
