@@ -56,14 +56,27 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if (
             hasattr(self, "external_shapes") and self.external_shapes
@@ -90,62 +103,49 @@ def args(m, n, k):
                 yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
@@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
@@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k