From 821bd2b7985f26743ef7644a60e7380cb16e8c26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 07:41:27 -0700 Subject: [PATCH 1/4] Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 22 ++++++++++-- torchao/testing/training/roofline_utils.py | 41 +++++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 4bf54538df..547b0a40e4 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -180,7 +180,7 @@ def get_gemm_times( scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) else: - assert False, "TODO add cutlass mx gemm here" + assert False, f"unsupported {float8_recipe_name=} {mx_recipe_name=}" def do_matmul(A, B): return torch._scaled_mm( @@ -233,6 +233,20 @@ def run( print(f"mx_recipe_name: {mx_recipe_name}") print(f"enable_fusion_modeling: {enable_fusion_modeling}") + assert mx_recipe_name in ( + # real mxfp8_cublas recipe + "mxfp8_cublas", + # real mxfp8_cublas_rceil recipe + "mxfp8_cublas_rceil", + # modeling of what mxfp8 with 32x32 block size and without gemm + # operand layout restrictions would look like + "mxfp8_32x32_flexible_gemm_layout", + # modeling of what mxfp8 with 32x32 block size for weight + "mxfp8_32x32_weight", + # real mxfp4_cutlass recipe + "mxfp4_cutlass", + ), f"unsupported {mx_recipe_name=}" + M, K, N = sympy.symbols("M K N") fp8_ovhd_time_sympy = get_float8_mem_sympy( @@ -309,7 +323,11 @@ def run( rb_fp8_gemm_ratio = -1 if do_benchmarks: - assert mx_recipe_name != "mxfp4_cutlass", "unsupported" + assert mx_recipe_name not in ( + "mxfp4_cutlass", + "mxfp8_32x32_flexible_gemm_layout", + "mxfp8_32x32_weight", + ), f"do_benchmarks unsupported with {mx_recipe_name=}" # TODO(future): make the bf16 gemm times exactly match the e2e # benchmarks, there is a slight deviation, probably related to gemm diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index f57705333a..6610654bf1 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -187,13 +187,52 @@ def get_tensor_memory_traffic_ovhd_s( else: assert False, "unsupported" + elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout": + # modeling the following: + # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense + # across dim0 and dim1 + # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in + # PyTorch right now) + # x_bf16 = ... + # kernel 1: x_bf16 -> x_mxfp8_dim0 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw] + + elif mx_recipe_name == "mxfp8_32x32_weight": + # modeling the following: + # 1. mxfp8 scaling with 32x32 weights, so the format makes sense + # across dim0 and dim1. input and grad_output still 1x32. + + if tensor_role in ("input", "grad_output"): + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_bf16 -> x_mxfp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + + elif tensor_role == "weight": + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1 + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2 + + else: + assert False, "unsupported" + + res_bytes = [kernel_1_rw, kernel_2_rw] + else: assert mx_recipe_name in ( "mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil", "mxfp4_cutlass", - ), "unsupported" + ), f"unsupported {mx_recipe_name=}" # For now, assume that we can't profitably fuse kernel 1 and kernel 2 # x_bf16 = ... # kernel 1: x_bf16 -> x_mxfp8_dim0 From 5bd4e3b4ff6617d6bb7eec8b13f6be99b1aeb40d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 13:32:59 -0700 Subject: [PATCH 2/4] Update [ghstack-poisoned] --- torchao/testing/training/roofline_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index 6610654bf1..e391a4d44b 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -207,6 +207,7 @@ def get_tensor_memory_traffic_ovhd_s( # across dim0 and dim1. input and grad_output still 1x32. if tensor_role in ("input", "grad_output"): + # TODO(future): update all of the mx rooflines to just read once # kernel 1: x_bf16 -> x_mxfp8_dim0 # kernel 2: x_bf16 -> x_mxfp8_dim1 if fuse_with_prev: From ea2d54f578ef0fb39d0556699429598419ce8927 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 14:09:19 -0700 Subject: [PATCH 3/4] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 106 +++++++++++++----- 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index fbfead161a..6c8113e8cb 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -38,6 +38,14 @@ ) import torchao +from torchao.prototype.mx_formats.config import ( + MXGemmKernelChoice, +) +from torchao.prototype.mx_formats.inference_workflow import ( + MXFPInferenceConfig, + NVFP4InferenceConfig, + NVFP4MMConfig, +) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, PerRow, @@ -80,40 +88,67 @@ def get_gemm_times( fast_accum: bool, recipe_name: Optional[str], ): - assert recipe_name in {"rowwise"}, ( - "Only support real benchmarks for 'rowwise' recipe for now" - ) device = torch.device("cuda") # bf16 time x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) - # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t() w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16) - e4m3_dtype = torch.float8_e4m3fn - if torch.version.hip and torch.cuda.is_available() and is_MI300(): - e4m3_dtype = torch.float8_e4m3fnuz - d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 - A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) - B = ( - torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) - .view(d2) - .t() - .contiguous() - .t() - ) + if recipe_name in ("mxfp4_cutlass", "nvfp4"): + d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16 + A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view( + d1 + ) + B = ( + torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8) + .t() + .contiguous() + .t() + .view(d2) + ) + else: + e4m3_dtype = torch.float8_e4m3fn + if torch.version.hip and torch.cuda.is_available() and is_MI300(): + e4m3_dtype = torch.float8_e4m3fnuz + d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 + A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) + B = ( + torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) + .view(d2) + .t() + .contiguous() + .t() + ) + if recipe_name == "rowwise": scale_a = torch.ones(M, 1, device=device) scale_b = torch.ones(1, N, device=device) + elif recipe_name == "mxfp8_cublas": + scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) + scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) + elif recipe_name == "mxfp4_cutlass": + scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) + scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) + elif recipe_name == "nvfp4": + scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn) + scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn) + else: assert False, "unsupported" def do_matmul(A, B): - return torch._scaled_mm( - A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum - ) + if recipe_name == "mxfp4_cutlass": + return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b) + if recipe_name == "nvfp4": + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False + ) + else: + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum + ) f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) @@ -259,12 +294,33 @@ def run( # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() - config = Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - # for now, use TORCH. In the future might be interesting - # to benchmark AUTO and FBGEMM. - kernel_preference=KernelPreference.TORCH, - ) + if recipe_name == "rowwise": + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + # for now, use TORCH. In the future might be interesting + # to benchmark AUTO and FBGEMM. + kernel_preference=KernelPreference.TORCH, + ) + elif recipe_name == "mxfp8_cublas": + config = MXFPInferenceConfig( + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + ) + elif recipe_name == "mxfp4_cutlass": + config = MXFPInferenceConfig( + activation_dtype=torch.float4_e2m1fn_x2, + weight_dtype=torch.float4_e2m1fn_x2, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + ) + elif recipe_name == "nvfp4": + config = NVFP4InferenceConfig( + mm_config=NVFP4MMConfig.DYNAMIC, + use_dynamic_per_tensor_scale=False, + ) + else: + assert False, "unsupported" + m_fp8_dyn = copy.deepcopy(m_orig) quantize_(m_fp8_dyn, config) From b88850f0d83a7cac38b83868da00ddfaf2f9ab26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 17:44:34 -0700 Subject: [PATCH 4/4] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 6c8113e8cb..3365fba923 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -60,7 +60,7 @@ @torch.no_grad() -def get_gpu_kernel_time(m, x): +def get_gpu_kernel_time(m, x, trace_filename=None): # warm up for _ in range(2): __ = m(x) @@ -72,6 +72,12 @@ def get_gpu_kernel_time(m, x): for _ in range(n_iter): __ = m(x) torch.cuda.synchronize() + + # save a trace, if requested + if trace_filename is not None: + print(f"exporting trace to {trace_filename}") + prof.export_chrome_trace(trace_filename) + # get the gpu kernel time and aggregate it num_leaf_tensors = 1 + len(list(m.parameters())) ref_times = profiler_output_to_filtered_time_by_kernel_name( @@ -161,6 +167,7 @@ def run( do_benchmarks: bool = True, shape_gen_name: str = "pow2", n_limit: Optional[int] = None, + save_profile_traces: bool = False, ): """ Args: @@ -168,6 +175,7 @@ def run( * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep` * `n_limit (optional)`: if specified, only runs `n_limit` iterations + # `save_profile_traces (optional)`: if True, saves profiling traces """ config_table = [ ["GPU", torch.cuda.get_device_name(0)], @@ -289,7 +297,11 @@ def run( # get the bf16 gpu kernel time torch._dynamo.reset() m_bf16 = torch.compile(copy.deepcopy(m_orig)) - b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x) + + bf16_trace_filename = None + if save_profile_traces: + bf16_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_bf16.json" + b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, bf16_trace_filename) # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() @@ -325,7 +337,11 @@ def run( quantize_(m_fp8_dyn, config) m_fp8_dyn = torch.compile(m_fp8_dyn) - b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x) + + fp8_trace_filename = None + if save_profile_traces: + fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json" + b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename) r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)