BYTES_PER_EL_FLOAT4 = 0.5
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2
+BYTES_PER_EL_FLOAT8_E8M0 = 1
+BYTES_PER_EL_FLOAT32 = 4

gpu_name_to_specs = {
    "NVIDIA H100": {
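For orientation, a minimal sketch of how these per-element sizes turn into tensor footprints. The shape is an illustrative assumption; the 32-element mx block size matches the `dim1 // 32` used later in this diff.

# rough footprints for a single (M, K) activation, using the constants above
M, K = 4096, 4096
bf16_bytes = BYTES_PER_EL_BF16 * M * K                     # ~33.6e6 bytes
fp8_bytes = BYTES_PER_EL_FLOAT8 * M * K                    # ~16.8e6 bytes
# mxfp8 also stores one e8m0 scale per 32-element block
mxfp8_bytes = fp8_bytes + BYTES_PER_EL_FLOAT8_E8M0 * M * (K // 32)
# fp4 packs two elements per byte
fp4_bytes = BYTES_PER_EL_FLOAT4 * M * K                    # ~8.4e6 bytes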
@@ -241,7 +244,7 @@ def get_individual_gemm_time_sympy(
    elif dtype is torch.float4_e2m1fn_x2:
        peak_tops = specs["fp4_peak_tops"]
    else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
    compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]

    # memory bound
@@ -274,7 +277,7 @@ def get_individual_gemm_time_sympy(
    elif dtype is torch.float4_e2m1fn_x2:
        bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
    else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
    mem_gemm_time_s = (
        bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
    )
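As a sanity check on the compute-bound and memory-bound terms above, a worked example for one shape. The peak-TOPS, bandwidth, and efficiency numbers are illustrative assumptions (the real values come from gpu_name_to_specs), gemm_ops is assumed to be 2*M*K*N, and taking the max of the two times is the usual roofline combination rather than a quote of this file's code.

# illustrative roofline for an fp8 gemm with bf16 output, M = K = N = 4096
M = K = N = 4096
gemm_ops = 2 * M * K * N                                   # ~1.37e11 FLOPs
compute_gemm_time_s = gemm_ops / 989e12 / 0.8              # assumed fp8 peak tops and efficiency

num_reads = M * K + K * N                                  # input + weight elements
num_writes = M * N                                         # output elements
bytes_rw = num_reads * BYTES_PER_EL_FLOAT8 + num_writes * BYTES_PER_EL_BF16
mem_gemm_time_s = bytes_rw / 3.35e12 / 0.8                 # assumed HBM bandwidth and efficiency

gemm_time_s = max(compute_gemm_time_s, mem_gemm_time_s)    # ~1.7e-4 s, compute-bound at this shape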
@@ -376,27 +379,56 @@ def get_inference_tensor_memory_traffic_ovhd_s(
    dim1,
    tensor_role: str,
    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
    fuse_with_prev=False,
) -> List[Union[sympy.Symbol, float]]:
    """
    Inference version of `get_tensor_memory_traffic_ovhd_s`.
    The only thing happening here is we quantize the activation.
    """
-    assert float8_recipe_name == "rowwise", "unsupported"
    assert fuse_with_prev is False, "unsupported"
+    assert tensor_role == "input", "inference only quantizes input activations"

    # assumes input bf16, output f8
    numel = dim0 * dim1

    res_bytes = None

-    assert tensor_role == "input"
-    # x_bf16 = ...
-    # kernel 1: x_bf16 -> x_fp8
-    kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
-    res_bytes = [
-        kernel_1_rw,
-    ]
+    if float8_recipe_name == "tensorwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
+        # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs
+        # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+        # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel
+        # kernel 3: read in bf16, write in float8
+        kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw, kernel_3_rw]
+
+    elif float8_recipe_name == "rowwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_fp8 (with per-row scaling)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        # add in the bytes for scale writes
+        kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil"):
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8 (block-wise scaling for inference)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        # add in the bytes for scale writes (one e8m0 scale per 32-element block)
+        kernel_1_rw += BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
+        res_bytes = [kernel_1_rw]
+
+    else:
+        # otherwise assume NVFP4: read bf16, write packed fp4; only a single
+        # per-tensor fp32 scale is modeled, per-block scale traffic is ignored,
+        # so the quantization overhead estimate is close to minimal
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_nvfp4 (per-tensor scaling for inference)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
+        # add minimal scaling overhead (single per-tensor scale factor)
+        kernel_1_rw += BYTES_PER_EL_FLOAT32
+        res_bytes = [kernel_1_rw]

    # convert from bytes to seconds
    res_s = [
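To make the branches above concrete, worked byte counts for a single (4096, 4096) input under each recipe, using the same formulas as the code. The bandwidth used to convert the rowwise case to seconds is an illustrative assumption.

# quantization-overhead bytes for one bf16 input activation
dim0, dim1 = 4096, 4096
numel = dim0 * dim1
tensorwise_bytes = BYTES_PER_EL_BF16 * numel + (BYTES_PER_EL_BF16 + BYTES_PER_EL_FLOAT8) * numel
rowwise_bytes = (BYTES_PER_EL_BF16 + BYTES_PER_EL_FLOAT8) * numel + BYTES_PER_EL_FLOAT32 * dim0
mxfp8_bytes = (BYTES_PER_EL_BF16 + BYTES_PER_EL_FLOAT8) * numel + BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
# ~83.9e6, ~50.3e6, and ~50.9e6 bytes respectively; at an assumed
# 3.35e12 B/s * 0.8 achievable bandwidth, the rowwise case is roughly 19 microseconds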
@@ -410,11 +442,75 @@ def get_inference_tensor_memory_traffic_ovhd_s(
    return res_s


+# def get_inference_tensor_memory_traffic_ovhd_bytes(
+#     dim0,
+#     dim1,
+#     tensor_role: str,
+#     float8_recipe_name: Optional[str],
+#     mx_recipe_name: Optional[str],
+#     fuse_with_prev=False,
+# ) -> int:
+#     """
+#     Get total bytes transferred for inference quantization overhead (bytes only, no time conversion).
+#     """
+#     assert fuse_with_prev is False, "unsupported"
+#     assert tensor_role == "input", "inference only quantizes input activations"
+
+#     numel = dim0 * dim1
+
+#     if float8_recipe_name == "tensorwise":
+#         # kernel 1: read numel in bf16
+#         kernel_1_rw = BYTES_PER_EL_BF16 * numel
+#         # kernel 3: read in bf16, write in float8
+#         kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+#         total_bytes = kernel_1_rw + kernel_3_rw
+
+#     elif float8_recipe_name == "rowwise":
+#         # kernel 1: read bf16, write fp8 + scales
+#         kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+#         kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
+#         total_bytes = kernel_1_rw
+
+#     elif mx_recipe_name in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil"):
+#         # kernel 1: read bf16, write fp8 + block scales
+#         kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+#         kernel_1_rw += BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
+#         total_bytes = kernel_1_rw
+
+#     else:
+#         raise ValueError(f"Unsupported recipe for inference roofline: float8={float8_recipe_name}, mx={mx_recipe_name}")
+
+#     return total_bytes
+
+
+# def get_inference_float8_mem_bytes_sympy(
+#     M,
+#     K,
+#     N,
+#     float8_recipe_name: Optional[str],
+#     mx_recipe_name: Optional[str] = None,
+# ):
+#     """Get total bytes transferred for inference FP8 quantization overhead."""
+#     # input @ weight_t = output
+#     # MxK @ KxN => MxN
+#     total_bytes = get_inference_tensor_memory_traffic_ovhd_bytes(
+#         M,
+#         K,
+#         tensor_role="input",
+#         float8_recipe_name=float8_recipe_name,
+#         mx_recipe_name=mx_recipe_name,
+#         fuse_with_prev=False,
+#     )
+#     return total_bytes
+
+
def get_inference_float8_mem_sympy(
    M,
    K,
    N,
    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str] = None,
+    nvfp4_recipe_name: Optional[str] = None,
    gpu_name: Optional[str] = None,
):
    specs = get_specs(gpu_name)
@@ -426,6 +522,7 @@ def get_inference_float8_mem_sympy(
        K,
        tensor_role="input",
        float8_recipe_name=float8_recipe_name,
+        mx_recipe_name=mx_recipe_name,
        fuse_with_prev=False,
    )
    res = sum([*fwd_fp8_input_mem])
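A possible usage sketch for the updated helper, assuming (as the code above suggests) that it returns a sympy expression for the quantization overhead in seconds; the shape and recipe choice are illustrative.

import sympy

M, K, N = sympy.symbols("M K N", positive=True)
ovhd_s = get_inference_float8_mem_sympy(
    M, K, N,
    float8_recipe_name="rowwise",
    gpu_name="NVIDIA H100",
)
# substitute a concrete shape to get a number for a 4k x 4k x 4k gemm
print(ovhd_s.subs({M: 4096, K: 4096, N: 4096}))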
@@ -438,9 +535,9 @@ def get_inference_gemm_time_sympy(
    N: sympy.Symbol,
    dtype,
    float8_recipe_name: Optional[str],
-    gpu_name: Optional[str],
+    nvfp4_recipe_name: Optional[str] = None,
+    gpu_name: Optional[str] = None,
):
-    assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported"
    # note: this function is currently not super accurate for small shapes:
    # when M,K,N <= 1k,1k,1k it undercounts by around 2x
    gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name)
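And a hypothetical end-to-end comparison built from these helpers, under the assumptions that both functions return sympy expressions in seconds and that torch.bfloat16 and torch.float8_e4m3fn are among the dtypes handled by get_individual_gemm_time_sympy; the shape, recipe, and GPU name are illustrative.

import sympy
import torch

M, K, N = sympy.symbols("M K N", positive=True)

bf16_gemm_s = get_inference_gemm_time_sympy(
    M, K, N, torch.bfloat16, float8_recipe_name=None, gpu_name="NVIDIA H100"
)
fp8_gemm_s = get_inference_gemm_time_sympy(
    M, K, N, torch.float8_e4m3fn, float8_recipe_name="rowwise", gpu_name="NVIDIA H100"
)
fp8_ovhd_s = get_inference_float8_mem_sympy(
    M, K, N, float8_recipe_name="rowwise", gpu_name="NVIDIA H100"
)

# projected speedup of fp8 inference over bf16, including quantization overhead
speedup = bf16_gemm_s / (fp8_gemm_s + fp8_ovhd_s)
print(speedup.subs({M: 4096, K: 4096, N: 4096}))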