@@ -187,13 +187,53 @@ def get_tensor_memory_traffic_ovhd_s(
         else:
            assert False, "unsupported"
 
+    elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense
+        #    across dim0 and dim1
+        # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in
+        #    PyTorch right now)
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name == "mxfp8_32x32_weight":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 for the weight, so the format makes
+        #    sense across dim0 and dim1; input and grad_output are still 1x32
+
+        if tensor_role in ("input", "grad_output"):
+            # TODO(future): update all of the mx rooflines to just read once
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_bf16 -> x_mxfp8_dim1
+            if fuse_with_prev:
+                kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+            else:
+                kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        elif tensor_role == "weight":
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
+
+        else:
+            assert False, "unsupported"
+
+        res_bytes = [kernel_1_rw, kernel_2_rw]
+
     else:
         assert mx_recipe_name in (
             "mxfp8_emulated",
             "mxfp8_cublas",
             "mxfp8_cublas_rceil",
             "mxfp4_cutlass",
-        ), "unsupported"
+        ), f"unsupported {mx_recipe_name=}"
         # For now, assume that we can't profitably fuse kernel 1 and kernel 2
         # x_bf16 = ...
         # kernel 1: x_bf16 -> x_mxfp8_dim0
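For reference, below is a minimal standalone sketch of the traffic model the two new branches implement. It is not part of the diff: the constants BYTES_PER_EL_BF16 = 2 and BYTES_PER_EL_FLOAT8 = 1 match the byte widths used in the function, but the helper name modeled_rw_bytes is hypothetical, introduced only for illustration.

# Standalone sketch of the per-kernel memory traffic modeled in the diff.
# modeled_rw_bytes is a hypothetical helper, not torchao API.
BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8 = 1

def modeled_rw_bytes(mx_recipe_name, tensor_role, numel, fuse_with_prev=False):
    if mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
        # single cast kernel: read bf16 (skipped if fused with the previous
        # kernel), write fp8
        read = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
        return [read + BYTES_PER_EL_FLOAT8 * numel]
    elif mx_recipe_name == "mxfp8_32x32_weight":
        if tensor_role in ("input", "grad_output"):
            # two cast kernels, each reads bf16 and writes fp8
            read = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
            kernel_1 = read + BYTES_PER_EL_FLOAT8 * numel
            kernel_2 = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        else:
            assert tensor_role == "weight", "unsupported"
            # kernel 1 casts bf16 -> fp8; kernel 2 only touches fp8
            kernel_1 = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
            kernel_2 = BYTES_PER_EL_FLOAT8 * numel * 2
        return [kernel_1, kernel_2]
    raise AssertionError(f"unsupported {mx_recipe_name=}")

# Worked example: a 4096x4096 weight, numel = 16_777_216.
# kernel 1: 2 * numel + 1 * numel = 50_331_648 bytes (~48 MiB)
# kernel 2: 1 * numel * 2        = 33_554_432 bytes (~32 MiB)
print(modeled_rw_bytes("mxfp8_32x32_weight", "weight", 4096 * 4096))
# -> [50331648, 33554432]

Note the asymmetry this model captures: for the weight, kernel 2 reads the already-cast fp8 tensor (x_mxfp8_dim0 -> x_mxfp8_dim1), so it moves 2 bytes per element instead of the 3 bytes per element paid when re-reading bf16, which is what the input and grad_output branches must do.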