@@ -187,13 +187,52 @@ def get_tensor_memory_traffic_ovhd_s(
         else:
             assert False, "unsupported"
 
+    elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense
+        #    across dim0 and dim1
+        # 2. mxfp8 gemm with TN, NT, TT, NN layouts supported (not in
+        #    PyTorch right now)
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name == "mxfp8_32x32_weight":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 for the weight, so the format makes sense
+        #    across dim0 and dim1; input and grad_output still use 1x32.
+
+        if tensor_role in ("input", "grad_output"):
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_bf16 -> x_mxfp8_dim1
+            if fuse_with_prev:
+                kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+            else:
+                kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        elif tensor_role == "weight":
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
+
+        else:
+            assert False, "unsupported"
+
+        res_bytes = [kernel_1_rw, kernel_2_rw]
+
     else:
         assert mx_recipe_name in (
             "mxfp8_emulated",
             "mxfp8_cublas",
             "mxfp8_cublas_rceil",
             "mxfp4_cutlass",
-        ), "unsupported"
+        ), f"unsupported {mx_recipe_name=}"
         # For now, assume that we can't profitably fuse kernel 1 and kernel 2
         # x_bf16 = ...
         # kernel 1: x_bf16 -> x_mxfp8_dim0
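
The new branches boil down to simple byte counting per cast kernel. The standalone sketch below (not part of the diff) reproduces that arithmetic for a quick sanity check; the helper name `modeled_rw_bytes` and the example shape are made up for illustration, and the element sizes assume the usual 2-byte bf16 / 1-byte fp8 values behind `BYTES_PER_EL_BF16` / `BYTES_PER_EL_FLOAT8`. Scale traffic is ignored, mirroring the expressions in the diff.

```python
# Illustrative only: mirrors the read/write byte model of the two new recipes.
BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8 = 1


def modeled_rw_bytes(numel, recipe, tensor_role, fuse_with_prev):
    """Return per-kernel read+write bytes, following the branches in the diff."""
    if recipe == "mxfp8_32x32_flexible_gemm_layout":
        # single cast kernel: read bf16 (unless fused with the producer), write fp8
        read = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
        return [read + BYTES_PER_EL_FLOAT8 * numel]
    elif recipe == "mxfp8_32x32_weight":
        if tensor_role in ("input", "grad_output"):
            # two cast kernels, both reading the bf16 source
            k1_read = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
            k1 = k1_read + BYTES_PER_EL_FLOAT8 * numel
            k2 = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        else:  # weight
            # kernel 2 re-lays-out the already-cast fp8 tensor: fp8 in, fp8 out
            k1 = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
            k2 = BYTES_PER_EL_FLOAT8 * numel * 2
        return [k1, k2]
    raise ValueError("unsupported")


# e.g. a 4096x4096 weight under mxfp8_32x32_weight:
# kernel 1 moves (2 + 1) * 16.8M ~= 50 MB, kernel 2 moves 2 * 16.8M ~= 34 MB
print(modeled_rw_bytes(4096 * 4096, "mxfp8_32x32_weight", "weight", False))
```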