
Commit d2a0328

add inference only roofline

1 parent 596da93 · commit d2a0328

File tree

3 files changed: 167 additions, 33 deletions

.gitignore
benchmarks/float8/float8_roofline.py
torchao/testing/training/roofline_utils.py

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ aten/build/
 aten/src/ATen/Config.h
 aten/src/ATen/cuda/CUDAConfig.h
 benchmarks/.data
+benchmarks/data
 caffe2/cpp_test/
 dist/
 docs/build/

benchmarks/float8/float8_roofline.py

Lines changed: 57 additions & 21 deletions

@@ -66,6 +66,8 @@
 from torchao.testing.training.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
+    get_inference_float8_mem_sympy,
+    get_inference_gemm_time_sympy,
 )
 from torchao.utils import is_MI300

@@ -206,21 +208,32 @@ def run(
     n_limit: Optional[int] = None,
     float8_recipe_name: Optional[str] = None,
     mx_recipe_name: Optional[str] = None,
+    nvfp4_recipe_name: Optional[str] = None,
     enable_fusion_modeling: bool = False,
+    inference_only: bool = False,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `float8_recipe_name (optional)`: float8 quantization recipe
+    * `mx_recipe_name (optional)`: MX format recipe
+    * `nvfp4_recipe_name (optional)`: NVFP4 format recipe
     * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
+    * `inference_only`: if True, only models inference (forward pass), not training
     """

-    assert not ((float8_recipe_name is not None) and (mx_recipe_name is not None)), (
-        "unsupported"
+    # Handle recipe specification
+    recipe_count = sum(
+        x is not None for x in [float8_recipe_name, mx_recipe_name, nvfp4_recipe_name]
     )
-    if float8_recipe_name is None and mx_recipe_name is None:
+
+    # Ensure only one recipe type is specified for single runs
+    assert recipe_count <= 1, "Only one recipe type can be specified at a time"
+
+    if recipe_count == 0:
         float8_recipe_name = "tensorwise"

     print(f"GPU: {torch.cuda.get_device_name(0)}")
@@ -230,28 +243,48 @@ def run(
     print(f"shape_gen_name: {shape_gen_name}")
     print(f"float8_recipe_name: {float8_recipe_name}")
     print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"nvfp4_recipe_name: {nvfp4_recipe_name}")
     print(f"enable_fusion_modeling: {enable_fusion_modeling}")
+    print(f"inference_only: {inference_only}")

     M, K, N = sympy.symbols("M K N")

-    fp8_ovhd_time_sympy = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        float8_recipe_name,
-        mx_recipe_name,
-        enable_fusion_modeling,
-    )
-    bf16_gemm_time_sympy = get_gemm_time_sympy(
-        M, K, N, torch.bfloat16, None, None, None
-    )
-    lowp_input_dtype = torch.float8_e4m3fn
-    if mx_recipe_name == "mxfp4_cutlass":
-        lowp_input_dtype = torch.float4_e2m1fn_x2
+    # Choose functions based on inference_only flag
+    if inference_only:
+        fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
+            M, K, N, float8_recipe_name, mx_recipe_name, nvfp4_recipe_name
+        )
+        bf16_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.bfloat16, None, None
+        )
+        if nvfp4_recipe_name is not None:
+            # Use FP4 for NVFP4 format
+            fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+                M, K, N, torch.float4_e2m1fn_x2, float8_recipe_name, nvfp4_recipe_name
+            )
+        else:
+            fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+                M, K, N, torch.float8_e4m3fn, float8_recipe_name, None
+            )
+    else:
+        fp8_ovhd_time_sympy = get_float8_mem_sympy(
+            M,
+            K,
+            N,
+            float8_recipe_name,
+            mx_recipe_name,
+            enable_fusion_modeling,
+        )
+        bf16_gemm_time_sympy = get_gemm_time_sympy(
+            M, K, N, torch.bfloat16, None, None, None
+        )
+        lowp_input_dtype = torch.float8_e4m3fn
+        if mx_recipe_name == "mxfp4_cutlass":
+            lowp_input_dtype = torch.float4_e2m1fn_x2

-    fp8_gemm_time_sympy = get_gemm_time_sympy(
-        M, K, N, lowp_input_dtype, float8_recipe_name, mx_recipe_name, None
-    )
+        fp8_gemm_time_sympy = get_gemm_time_sympy(
+            M, K, N, lowp_input_dtype, float8_recipe_name, mx_recipe_name, None
+        )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
     print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
@@ -397,6 +430,9 @@ def run(
             m_fp8_dyn = torch.compile(m_fp8_dyn)
             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

+        # Calculate roofline speedup
+        roofline_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+
         results.append(
             [
                 M_val,
@@ -409,7 +445,7 @@ def run(
                 r_fp8_ovhd_time_s,
                 # roofline - gemm + overhead, and speedup
                 r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
-                r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
+                roofline_speedup,
                 # benchmarks - gemm
                 b_bf16_gemm_time_s,
                 b_fp8_gemm_time_s,
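For a sense of how the new flag is meant to be driven, here is a sketch of an inference-only call to run(). All keyword names appear in the diff above; any other arguments the script may require (for example an output file) are omitted, and the values are illustrative. If the script exposes run() through a fire-style command-line interface, the same names would be passed as flags.

# sketch only: roofline-only, inference-only sweep over llama shapes
run(
    do_benchmarks=False,  # skip GPU kernel benchmarks, keep the analytical roofline
    shape_gen_name="llama",
    float8_recipe_name="rowwise",  # at most one of float8/mx/nvfp4 recipes, per the new assert
    inference_only=True,
)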

torchao/testing/training/roofline_utils.py

Lines changed: 109 additions & 12 deletions

@@ -12,6 +12,9 @@
 BYTES_PER_EL_FLOAT4 = 0.5
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2
+BYTES_PER_EL_FLOAT8_E8M0 = 1
+BYTES_PER_EL_FLOAT32 = 4
+BYTES_PER_EL_FLOAT4 = 0.5

 gpu_name_to_specs = {
     "NVIDIA H100": {
@@ -241,7 +244,7 @@ def get_individual_gemm_time_sympy(
     elif dtype is torch.float4_e2m1fn_x2:
         peak_tops = specs["fp4_peak_tops"]
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]

     # memory bound
@@ -274,7 +277,7 @@ def get_individual_gemm_time_sympy(
     elif dtype is torch.float4_e2m1fn_x2:
         bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     mem_gemm_time_s = (
         bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
     )
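The two hunks above touch the two sides of the roofline model inside get_individual_gemm_time_sympy: a compute-bound time from peak matmul throughput and a memory-bound time from peak bandwidth. The combination step is outside these hunks, but the standard roofline estimate takes whichever bound is slower, along the lines of this sketch (an assumption, not shown in this diff):

import sympy

compute_gemm_time_s, mem_gemm_time_s = sympy.symbols(
    "compute_gemm_time_s mem_gemm_time_s", positive=True
)

# assumed roofline combination: the kernel can run no faster than either
# its compute bound or its memory bound, so take the larger of the two times
gemm_time_s = sympy.Max(compute_gemm_time_s, mem_gemm_time_s)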
@@ -376,27 +379,56 @@ def get_inference_tensor_memory_traffic_ovhd_s(
     dim1,
     tensor_role: str,
     float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
     fuse_with_prev=False,
 ) -> List[Union[sympy.Symbol, float]]:
     """
     Inference version of `get_tensor_memory_traffic_ovhd_s`.
     The only thing happening here is we quantize the activation.
     """
-    assert float8_recipe_name == "rowwise", "unsupported"
     assert fuse_with_prev is False, "unsupported"
+    assert tensor_role == "input", "inference only quantizes input activations"

     # assumes input bf16, output f8
     numel = dim0 * dim1

     res_bytes = None

-    assert tensor_role == "input"
-    # x_bf16 = ...
-    # kernel 1: x_bf16 -> x_fp8
-    kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
-    res_bytes = [
-        kernel_1_rw,
-    ]
+    if float8_recipe_name == "tensorwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
+        # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs
+        # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+        # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel
+        # kernel 3: read in bf16, write in float8
+        kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw, kernel_3_rw]
+
+    elif float8_recipe_name == "rowwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_fp8 (with per-row scaling)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        # add in the bytes for scale writes
+        kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil"):
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8 (block-wise scaling for inference)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        # add in the bytes for scale writes
+        kernel_1_rw += BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
+        res_bytes = [kernel_1_rw]
+
+    else:
+        # For NVFP4, assume minimal overhead since it's primarily a compute format
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_nvfp4 (per-tensor scaling for inference)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
+        # add minimal scaling overhead (per-tensor scale)
+        kernel_1_rw += BYTES_PER_EL_FLOAT32  # single scale factor
+        res_bytes = [kernel_1_rw]

     # convert from bytes to seconds
     res_s = [
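To make the per-recipe traffic model above concrete, the following sketch (not part of the commit) plugs an illustrative activation shape into the same byte formulas; the hunk then divides these byte counts by achievable memory bandwidth to turn them into seconds.

# illustrative activation: dim0 x dim1 = 16384 x 4096, quantized from bf16
BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_FLOAT32 = 4
BYTES_PER_EL_FLOAT8_E8M0 = 1

dim0, dim1 = 16384, 4096
numel = dim0 * dim1  # 67,108,864 elements

# rowwise float8: read bf16, write fp8, plus one fp32 scale per row
rowwise_bytes = (
    BYTES_PER_EL_BF16 * numel
    + BYTES_PER_EL_FLOAT8 * numel
    + BYTES_PER_EL_FLOAT32 * dim0
)  # 201,392,128 bytes

# mxfp8: read bf16, write fp8, plus one e8m0 scale per 32-element block
mxfp8_bytes = (
    BYTES_PER_EL_BF16 * numel
    + BYTES_PER_EL_FLOAT8 * numel
    + BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
)  # 203,423,744 bytes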
@@ -410,11 +442,75 @@ def get_inference_tensor_memory_traffic_ovhd_s(
     return res_s


+# def get_inference_tensor_memory_traffic_ovhd_bytes(
+#     dim0,
+#     dim1,
+#     tensor_role: str,
+#     float8_recipe_name: Optional[str],
+#     mx_recipe_name: Optional[str],
+#     fuse_with_prev=False,
+# ) -> int:
+#     """
+#     Get total bytes transferred for inference quantization overhead (bytes only, no time conversion).
+#     """
+#     assert fuse_with_prev is False, "unsupported"
+#     assert tensor_role == "input", "inference only quantizes input activations"
+
+#     numel = dim0 * dim1
+
+#     if float8_recipe_name == "tensorwise":
+#         # kernel 1: read numel in bf16
+#         kernel_1_rw = BYTES_PER_EL_BF16 * numel
+#         # kernel 3: read in bf16, write in float8
+#         kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+#         total_bytes = kernel_1_rw + kernel_3_rw
+
+#     elif float8_recipe_name == "rowwise":
+#         # kernel 1: read bf16, write fp8 + scales
+#         kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+#         kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
+#         total_bytes = kernel_1_rw
+
+#     elif mx_recipe_name in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil"):
+#         # kernel 1: read bf16, write fp8 + block scales
+#         kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+#         kernel_1_rw += BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
+#         total_bytes = kernel_1_rw
+
+#     else:
+#         raise ValueError(f"Unsupported recipe for inference roofline: float8={float8_recipe_name}, mx={mx_recipe_name}")
+
+#     return total_bytes
+
+
+# def get_inference_float8_mem_bytes_sympy(
+#     M,
+#     K,
+#     N,
+#     float8_recipe_name: Optional[str],
+#     mx_recipe_name: Optional[str] = None,
+# ):
+#     """Get total bytes transferred for inference FP8 quantization overhead."""
+#     # input @ weight_t = output
+#     # MxK @ KxN => MxN
+#     total_bytes = get_inference_tensor_memory_traffic_ovhd_bytes(
+#         M,
+#         K,
+#         tensor_role="input",
+#         float8_recipe_name=float8_recipe_name,
+#         mx_recipe_name=mx_recipe_name,
+#         fuse_with_prev=False,
+#     )
+#     return total_bytes
+
+
 def get_inference_float8_mem_sympy(
     M,
     K,
     N,
     float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str] = None,
+    nvfp4_recipe_name: Optional[str] = None,
     gpu_name: Optional[str] = None,
 ):
     specs = get_specs(gpu_name)
@@ -426,6 +522,7 @@ def get_inference_float8_mem_sympy(
         K,
         tensor_role="input",
         float8_recipe_name=float8_recipe_name,
+        mx_recipe_name=mx_recipe_name,
         fuse_with_prev=False,
     )
     res = sum([*fwd_fp8_input_mem])
@@ -438,9 +535,9 @@ def get_inference_gemm_time_sympy(
     N: sympy.Symbol,
     dtype,
     float8_recipe_name: Optional[str],
-    gpu_name: Optional[str],
+    nvfp4_recipe_name: Optional[str] = None,
+    gpu_name: Optional[str] = None,
 ):
-    assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported"
     # note: this function is currently not super accurate for small shapes:
     # when M,K,N <= 1k,1k,1k it undercounts by around 2x
     gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name)
