Commit b3a12be

[Benchmark] bf16 x int16 helion kernel

1 parent e5320e8 commit b3a12be

File tree

4 files changed: +295 -0 lines changed

benchmarks/run.py

Lines changed: 15 additions & 0 deletions
@@ -280,6 +280,11 @@ class RunResult:
         "examples.jagged_sum",
         "jagged_sum_tritonbench",
     ),
+    "bf16xint16_gemm": (
+        "tritonbench.operators.bf16xint16_gemm.bf16xint16_gemm",
+        "examples.bf16xint16_gemm",
+        "bf16xint16_gemm_tritonbench",
+    ),
 }
@@ -538,6 +543,16 @@ class RunResult:
         "helion_fp8_gemm_tritonbench-speedup": "helion_speedup",
         "helion_fp8_gemm_tritonbench-accuracy": "helion_accuracy",
     },
+
+    "bf16xint16_gemm": {
+        "bf16xbf16": "baseline",
+        "bf16xint16-speedup": "triton_speedup",
+        "bf16xint16-accuracy": "triton_accuracy",
+        "torch_compile_bf16xbf16-speedup": "torch_compile_speedup",
+        "torch_compile_bf16xbf16-accuracy": "torch_compile_accuracy",
+        "helion_bf16xint16_gemm_tritonbench-speedup": "helion_speedup",
+        "helion_bf16xint16_gemm_tritonbench-accuracy": "helion_accuracy",
+    },
 }
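
Each entry in these mappings pairs a tritonbench operator module with the Helion example module and the wrapper function tritonbench should call, plus the metric columns to normalize. As a rough illustration only (not the actual benchmarks/run.py logic), the sketch below shows how the new three-string entry could be resolved into a callable with a plain dynamic import; everything here besides the three strings taken from the diff is an assumption.

import importlib

# Illustration only: resolve the new mapping entry to the wrapper callable.
# The three strings mirror the entry added above; the lookup itself is a
# simplified stand-in, not the code benchmarks/run.py actually runs.
op_module_path, example_module_path, wrapper_name = (
    "tritonbench.operators.bf16xint16_gemm.bf16xint16_gemm",
    "examples.bf16xint16_gemm",
    "bf16xint16_gemm_tritonbench",
)
example_module = importlib.import_module(example_module_path)
wrapper = getattr(example_module, wrapper_name)  # -> bf16xint16_gemm_tritonbench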

examples/bf16xint16_gemm.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
"""
BF16 x INT16 GEMM with Helion
============================================================
The kernel performs matrix multiplication where one matrix is in bfloat16 format and the other is in int16 format.
The int16 values are converted to bfloat16 before performing the matrix multiplication.
"""

# %%
from __future__ import annotations

import os
from typing import Callable

import torch
from torch import Tensor

import helion
import helion.language as hl

# %%
@helion.kernel(static_shapes=True)
def _bf16xint16_gemm(x: Tensor, w: Tensor) -> Tensor:
    """
    x is bf16, w is int16.
    """
    M, K = x.shape
    K2, N = w.shape
    assert K == K2, f"size mismatch {K} != {K2}"

    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)

    for tile_m, tile_n in hl.tile([M, N]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(K):
            x_tile = x[tile_m, tile_k]
            w_tile = w[tile_k, tile_n].to(torch.bfloat16)
            acc = hl.dot(x_tile, w_tile, acc=acc)
        out[tile_m, tile_n] = acc.to(torch.bfloat16)

    return out

# %%
@helion.kernel(static_shapes=True)
def _int16xbf16_gemm(x: Tensor, w: Tensor) -> Tensor:
    """
    x is int16, w is bf16.
    """
    M, K = x.shape
    K2, N = w.shape
    assert K == K2, f"size mismatch {K} != {K2}"

    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)

    for tile_m, tile_n in hl.tile([M, N]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(K):
            x_tile = x[tile_m, tile_k].to(torch.bfloat16)
            w_tile = w[tile_k, tile_n]
            acc = hl.dot(x_tile, w_tile, acc=acc)
        out[tile_m, tile_n] = acc.to(torch.bfloat16)

    return out

# %%
def bf16xint16_gemm(x: Tensor, w: Tensor, transpose: bool = False) -> Tensor:
    """
    This function dispatches to the appropriate kernel based on the transpose flag.

    Args:
        x (Tensor): Input tensor.
        w (Tensor): Weight tensor.
        transpose (bool): If True, assumes x is int16 and w is bf16. Default: False.

    Returns:
        Tensor: Output tensor in bfloat16 format.
    """
    if transpose:
        return _int16xbf16_gemm(x, w)
    else:
        return _bf16xint16_gemm(x, w)


# %%
def bf16xint16_gemm_tritonbench(
    tb_op: object, x: torch.Tensor, w: torch.Tensor
) -> Callable[[], torch.Tensor]:
    """
    Wrapper for TritonBench compatibility.

    Args:
        tb_op: TritonBench operator instance
        x (torch.Tensor): Input tensor in bfloat16 format.
        w (torch.Tensor): Weight tensor in int16 format.

    Returns:
        Callable that returns output tensor in bfloat16 format.
    """
    # Check if transpose mode based on tritonbench operator
    transpose = getattr(tb_op, 'transpose', False)

    def run_kernel() -> torch.Tensor:
        return bf16xint16_gemm(x, w, transpose=transpose)

    return run_kernel

# %%
def reference_bf16xint16_pytorch(x: torch.Tensor, w: torch.Tensor, transpose: bool = False) -> torch.Tensor:
    """
    Reference implementation using PyTorch operations.

    Args:
        x (torch.Tensor): Input tensor.
        w (torch.Tensor): Weight tensor.
        transpose (bool): Transpose mode flag.

    Returns:
        torch.Tensor: Output tensor in bfloat16 format.
    """
    if transpose:
        x_bf16 = x.to(torch.bfloat16)
        return torch.matmul(x_bf16, w)
    else:
        w_bf16 = w.to(torch.bfloat16)
        return torch.matmul(x, w_bf16)


# %%
def check(m: int, k: int, n: int) -> None:
    """
    Test the bf16 x int16 GEMM implementation against the PyTorch reference.

    Args:
        m (int): Number of rows.
        k (int): Shared dimension.
        n (int): Number of cols.
    """
    x = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
    w = torch.randint(-(2**15), 2**15 - 1, (k, n), device="cuda", dtype=torch.int16)

    result = bf16xint16_gemm(x, w, transpose=False)
    expected = reference_bf16xint16_pytorch(x, w, transpose=False)
    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

    x_int16 = torch.randint(-(2**15), 2**15 - 1, (m, k), device="cuda", dtype=torch.int16)
    w_bf16 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)

    result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
    expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)
    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)


# %%
def main() -> None:
    """
    Main entry point that runs the bf16xint16 kernel verification with different tensor sizes.
    """
    check(256, 256, 256)
    check(512, 512, 512)
    check(65536, 1024, 1280)


# %%
if __name__ == "__main__":
    main()
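
For a quick sanity check outside tritonbench, the dispatcher above can also be called directly. The snippet below is a minimal sketch under assumptions: it presumes a CUDA device and that the repository's examples package is importable, and the shapes are illustrative rather than the benchmark sizes used in check().

import torch

from examples.bf16xint16_gemm import bf16xint16_gemm, reference_bf16xint16_pytorch

# Default path: bf16 activations x int16 weights (transpose=False)
x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
w = torch.randint(-(2**15), 2**15 - 1, (512, 128), device="cuda", dtype=torch.int16)
out = bf16xint16_gemm(x, w)  # bfloat16 output of shape (256, 128)

# Compare against the eager PyTorch reference with the same tolerances as check()
expected = reference_bf16xint16_pytorch(x, w, transpose=False)
torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)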

test/test_examples.expected

Lines changed: 86 additions & 0 deletions
@@ -402,6 +402,92 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     _launcher(_helion_attention, (32 * triton.cdiv(512, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())

+--- assertExpectedJournal(TestExamples.test_bf16xint16)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion__bf16xint16_gemm(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(65536, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        x_tile = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(w + (indices_2[:, None] * 1280 + indices_1[None, :] * 1), None)
+        v_0 = tl.cast(load_1, tl.bfloat16)
+        acc = tl.dot(tl.cast(x_tile, tl.bfloat16), tl.cast(v_0, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    v_1 = tl.cast(acc, tl.bfloat16)
+    tl.store(out + (indices_0[:, None] * 1280 + indices_1[None, :] * 1), v_1, None)
+
+def _bf16xint16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher):
+    """
+    x is bf16, w is int16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f'size mismatch {K} != {K2}'
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion__bf16xint16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestExamples.test_bf16xint16)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion__int16xbf16_gemm(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(65536, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
+        v_0 = tl.cast(load, tl.bfloat16)
+        w_tile = tl.load(w + (indices_2[:, None] * 1280 + indices_1[None, :] * 1), None)
+        acc = tl.dot(tl.cast(v_0, tl.bfloat16), tl.cast(w_tile, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    v_1 = tl.cast(acc, tl.bfloat16)
+    tl.store(out + (indices_0[:, None] * 1280 + indices_1[None, :] * 1), v_1, None)
+
+def _int16xbf16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher):
+    """
+    x is int16, w is bf16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f'size mismatch {K} != {K2}'
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion__int16xbf16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_bmm)
 from __future__ import annotations

test/test_examples.py

Lines changed: 29 additions & 0 deletions
@@ -308,6 +308,35 @@ def test_welford(self):
             )
         )

+    def test_bf16xint16(self):
+        from examples.bf16xint16_gemm import reference_bf16xint16_pytorch
+
+        m, k, n = 65536, 1024, 1280
+
+        x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16)
+        w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16)
+
+        self.assertExpectedJournal(
+            check_example(
+                "bf16xint16_gemm",
+                (x, w),
+                reference_bf16xint16_pytorch(x, w, False),
+                fn_name="_bf16xint16_gemm",
+            )
+        )
+
+        x_int16 = torch.randint(-(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16)
+        w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16)
+
+        self.assertExpectedJournal(
+            check_example(
+                "bf16xint16_gemm",
+                (x_int16, w_bf16),
+                reference_bf16xint16_pytorch(x_int16, w_bf16, True),
+                fn_name="_int16xbf16_gemm",
+            )
+        )
+
     def test_rms_norm_fwd(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
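
To exercise just the new expected-journal test, one option is to drive it through unittest from Python. This is a sketch under assumptions: it presumes the repository root is on sys.path and that the module is importable as test.test_examples; adjust the import to however the suite is normally invoked.

import unittest

# Assumed import path for the test module; adapt if the suite is laid out differently.
from test.test_examples import TestExamples

suite = unittest.TestSuite([TestExamples("test_bf16xint16")])
unittest.TextTestRunner(verbosity=2).run(suite)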
