@@ -38,6 +38,14 @@
 )
 
 import torchao
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
@@ -80,40 +88,67 @@ def get_gemm_times(
     fast_accum: bool,
     recipe_name: Optional[str],
 ):
-    assert recipe_name in {"rowwise"}, (
-        "Only support real benchmarks for 'rowwise' recipe for now"
-    )
     device = torch.device("cuda")
 
     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
     w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
 
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
-    e4m3_dtype = torch.float8_e4m3fn
-    if torch.version.hip and torch.cuda.is_available() and is_MI300():
-        e4m3_dtype = torch.float8_e4m3fnuz
-    d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
-    A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
-    B = (
-        torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
-        .view(d2)
-        .t()
-        .contiguous()
-        .t()
-    )
+    if recipe_name in ("mxfp4_cutlass", "nvfp4"):
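+        # note: torch.float4_e2m1fn_x2 packs two fp4 values per byte, so the
+        # raw uint8 buffers below only need K // 2 bytes along the K dim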
+        d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
+        A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
+            d1
+        )
+        B = (
+            torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8)
+            .t()
+            .contiguous()
+            .t()
+            .view(d2)
+        )
+    else:
+        e4m3_dtype = torch.float8_e4m3fn
+        if torch.version.hip and torch.cuda.is_available() and is_MI300():
+            e4m3_dtype = torch.float8_e4m3fnuz
+        d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16
+        A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1)
+        B = (
+            torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8)
+            .view(d2)
+            .t()
+            .contiguous()
+            .t()
+        )
+
     if recipe_name == "rowwise":
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
+    elif recipe_name in ("mxfp8_cublas", "mxfp4_cutlass"):
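+        # mx recipes use one e8m0 scale per 32-element block along K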
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe_name == "nvfp4":
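+        # nvfp4 uses one e4m3 scale per 16-element block along K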
+        scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
     else:
         assert False, "unsupported"
 
     def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+        if recipe_name == "mxfp4_cutlass":
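+            # mxfp4 dispatches to torchao's CUTLASS fp4 gemm kernel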
+            return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b)
+        elif recipe_name == "nvfp4":
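+            # nvfp4 always runs without fast accumulation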
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+            )
+        else:
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )
 
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
@@ -259,12 +294,33 @@ def run(
     # get the float8 dynamic scaling gpu kernel time
     torch._dynamo.reset()
 
-    config = Float8DynamicActivationFloat8WeightConfig(
-        granularity=PerRow(),
-        # for now, use TORCH. In the future might be interesting
-        # to benchmark AUTO and FBGEMM.
-        kernel_preference=KernelPreference.TORCH,
-    )
+    if recipe_name == "rowwise":
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+            # for now, use TORCH; in the future it might be interesting
+            # to benchmark AUTO and FBGEMM as well
+            kernel_preference=KernelPreference.TORCH,
+        )
+    elif recipe_name == "mxfp8_cublas":
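+        # mxfp8: e4m3 activations and weights, cuBLAS block-scaled gemm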
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+            gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+        )
+    elif recipe_name == "mxfp4_cutlass":
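+        # mxfp4: packed fp4 activations and weights, CUTLASS gemm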
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float4_e2m1fn_x2,
+            weight_dtype=torch.float4_e2m1fn_x2,
+            gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+        )
+    elif recipe_name == "nvfp4":
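+        # nvfp4: quantize activations dynamically, skip the per-tensor scale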
+        config = NVFP4InferenceConfig(
+            mm_config=NVFP4MMConfig.DYNAMIC,
+            use_dynamic_per_tensor_scale=False,
+        )
+    else:
+        assert False, "unsupported"
+
     m_fp8_dyn = copy.deepcopy(m_orig)
     quantize_(m_fp8_dyn, config)
270 | 326 |
|
|
0 commit comments