Commit 5902317

[Example] int4_gemm kernel example and tritonbench integration (#613)
1 parent 0a33998 commit 5902317

File tree

  benchmarks/run.py
  examples/int4_gemm.py
  test/test_examples.expected
  test/test_examples.py

4 files changed: +306 −0

benchmarks/run.py

Lines changed: 13 additions & 0 deletions
@@ -166,6 +166,11 @@ class RunResult:
         "examples.welford",
         "welford",
     ),
+    "int4_gemm": (
+        "tritonbench.operators.int4_gemm.int4_gemm",
+        "examples.int4_gemm",
+        "int4_gemm_tritonbench",
+    ),
 }


@@ -266,6 +271,14 @@ class RunResult:
         "helion_kl_div_tritonbench-speedup": "helion_speedup",
         "helion_kl_div_tritonbench-accuracy": "helion_accuracy",
     },
+    "int4_gemm": {
+        "triton_int4_gemm-speedup": "triton_speedup",
+        "triton_int4_gemm-accuracy": "triton_accuracy",
+        "torch_compile_int4_gemm-speedup": "torch_compile_speedup",
+        "torch_compile_int4_gemm-accuracy": "torch_compile_accuracy",
+        "helion_int4_gemm_tritonbench-speedup": "helion_speedup",
+        "helion_int4_gemm_tritonbench-accuracy": "helion_accuracy",
+    },
 }
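
For reference, the second mapping above renames tritonbench metric column names to short aliases, presumably for use when aggregating results. A minimal sketch of how such a mapping could be applied to one result row (illustrative only; raw_row and the surrounding driver logic are hypothetical, not part of run.py):

# Hypothetical illustration: rename tritonbench metric columns for the
# "int4_gemm" operator using the mapping added above. Values are made up.
raw_row = {
    "triton_int4_gemm-speedup": 1.12,
    "triton_int4_gemm-accuracy": 1.0,
    "helion_int4_gemm_tritonbench-speedup": 1.31,
    "helion_int4_gemm_tritonbench-accuracy": 1.0,
}
mapping = {
    "triton_int4_gemm-speedup": "triton_speedup",
    "triton_int4_gemm-accuracy": "triton_accuracy",
    "helion_int4_gemm_tritonbench-speedup": "helion_speedup",
    "helion_int4_gemm_tritonbench-accuracy": "helion_accuracy",
}
renamed = {mapping[k]: v for k, v in raw_row.items() if k in mapping}
print(renamed)
# {'triton_speedup': 1.12, 'triton_accuracy': 1.0,
#  'helion_speedup': 1.31, 'helion_accuracy': 1.0}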

examples/int4_gemm.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
+"""
+INT4 General Matrix Multiplication (GEMM) with Helion
+=====================================================
+This example demonstrates an INT4 GEMM kernel implemented in Helion. The kernel performs
+matrix multiplication where the second matrix B is packed with two 4-bit values per byte.
+The kernel unpacks the int4 values, converts to bfloat16, and performs matmul with
+the bfloat16 matrix A.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+from torch import Tensor
+
+import helion
+import helion.language as hl
+
+
+# %%
+# INT4 GEMM Kernel
+# ----------------
+@helion.kernel(
+    use_default_config=True,
+    static_shapes=False,  # Allow dynamic shapes to handle different input sizes
+)
+def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
+    """
+    BFloat16 x INT4 General Matrix Multiplication (GEMM).
+
+    This kernel performs matrix multiplication where:
+    - A is a bfloat16 matrix of shape [M, K]
+    - B is an int8 matrix of shape [K//2, N] containing packed int4 values
+      (two 4-bit values packed into each int8)
+
+    Args:
+        A (Tensor): Input tensor of shape [M, K] in bfloat16 format.
+        B (Tensor): Packed int4 tensor of shape [K//2, N] in int8 format.
+
+    Returns:
+        Tensor: Output tensor of shape [M, N] in bfloat16 format.
+    """
+    M, K = A.shape
+    _, N = B.shape
+
+    C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
+    block_size_k_packed = hl.register_block_size(K // 2)
+
+    # Use Helion to tile the computation
+    for tile_m, tile_n in hl.tile([M, N]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+
+        for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
+            # Load packed int8 data from B
+            b_tile = B[tile_k_packed, tile_n]  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
+
+            # Extract low and high 4-bit values with sign extension
+            # Low nibble: sign-extend from 4-bit to 8-bit using left shift then arithmetic right shift
+            b_lo = ((b_tile << 4) >> 4).to(torch.int8)  # Sign-extend low 4 bits
+            b_hi = (b_tile >> 4).to(torch.int8)  # Sign-extend high 4 bits
+
+            # Convert to bfloat16
+            b_lo_bf16 = b_lo.to(torch.bfloat16)  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
+            b_hi_bf16 = b_hi.to(torch.bfloat16)  # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
+
+            # Stack and reshape to interleave low and high bits
+            # Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N]
+            b_stacked = torch.stack([b_lo_bf16, b_hi_bf16], dim=1)
+
+            # Reshape to interleave: [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] -> [BLOCK_SIZE_K, BLOCK_SIZE_N]
+            # This will place elements in the order: b_lo[0], b_hi[0], b_lo[1], b_hi[1], ...
+            b_unpacked = b_stacked.reshape(
+                tile_k_packed.block_size * 2, tile_n.block_size
+            )
+
+            # Load corresponding tiles from A (need to load twice the packed tile size)
+            # We need to map tile_k_packed to the corresponding range in A
+            a_tile_begin = tile_k_packed.begin * 2
+            a_tile_len = tile_k_packed.block_size * 2
+            a_tile = A[
+                tile_m, a_tile_begin : (a_tile_begin + a_tile_len)
+            ]  # [BLOCK_SIZE_M, BLOCK_SIZE_K]
+
+            acc = acc + hl.dot(a_tile, b_unpacked)  # [BLOCK_SIZE_M, BLOCK_SIZE_N]
+
+        C[tile_m, tile_n] = acc.to(torch.bfloat16)
+
+    return C
+
+
+# %%
+# TritonBench Wrapper
+# -------------------
+def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Callable:
+    """
+    Wrapper for TritonBench compatibility.
+
+    Args:
+        tb_op: TritonBench operator instance
+        x (torch.Tensor): Left input tensor in bfloat16 format.
+        w (torch.Tensor): Right input tensor of shape [K, N] containing int4 values.
+            Will be packed to int4 format.
+
+    Returns:
+        Callable: A function that performs the int4 gemm.
+    """
+
+    def run_kernel() -> torch.Tensor:
+        x_2d = x.reshape(-1, x.size(-1))
+
+        # Pack w to int4 format (two 4-bit values per int8 byte)
+        w_int8 = w.to(torch.int8)
+        w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
+        w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
+
+        return matmul_bf16_int4(x_2d, w_packed)
+
+    return run_kernel
+
+
+# %%
+# Verification Function
+# ---------------------
+def check(m: int, k: int, n: int) -> None:
+    """
+    Test the INT4 GEMM implementation.
+
+    Args:
+        m (int): Number of rows in the left input matrix.
+        k (int): Shared dimension (must be even).
+        n (int): Number of columns in the right input matrix.
+    """
+    # Create test matrices
+    A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+    # Create packed int4 matrix B (K//2 x N)
+    # Generate random int4 values in range [-8, 7] and pack them
+    B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device="cuda")
+
+    # Pack using the same format as tritonbench
+    B_reshaped = B_unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
+    B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)
+
+    # Convert unpacked values to bfloat16 for reference
+    B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)
+
+    # Compute reference result
+    expected = torch.matmul(A, B_unpacked_bf16)
+
+    # Run the kernel
+    result = matmul_bf16_int4(A, B_packed)
+
+    # Check accuracy with appropriate tolerance
+    torch.testing.assert_close(result, expected, rtol=2e-1, atol=1.0)
+    print(f"Test passed for shapes: M={m}, K={k}, N={n}")
+
+
+# %%
+# Main Function
+# -------------
+def main() -> None:
+    """
+    Main function to run tests with different matrix sizes.
+    """
+    check(256, 512, 256)
+    check(512, 512, 512)
+    check(1024, 1024, 1024)
+
+
+# %%
+# Run Example
+# -----------
+if __name__ == "__main__":
+    main()
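
The packing convention used throughout this example (low nibble holds even K rows, high nibble holds odd K rows) can be sanity-checked in plain eager PyTorch, independent of Helion. Below is a minimal sketch (not part of the commit) that packs a small int4 matrix the way the wrapper and the test do, unpacks it with the same shift-based sign extension and stack/reshape interleaving the kernel uses, and verifies the round trip:

import torch

K, N = 8, 4
# Random int4 values in [-8, 7], stored one per int8 element.
B = torch.randint(-8, 8, (K, N), dtype=torch.int8)

# Pack: even K-rows go into the low nibble, odd K-rows into the high nibble.
B_pairs = B.reshape(K // 2, 2, N).permute(1, 0, 2)                  # [2, K//2, N]
B_packed = ((B_pairs[0] & 0xF) | (B_pairs[1] << 4)).to(torch.int8)  # [K//2, N]

# Unpack with shift-based sign extension, mirroring the kernel.
b_lo = (B_packed << 4) >> 4   # low nibble  -> even K-rows
b_hi = B_packed >> 4          # high nibble -> odd K-rows

# Interleave back to [K, N]: lo[0], hi[0], lo[1], hi[1], ...
B_roundtrip = torch.stack([b_lo, b_hi], dim=1).reshape(K, N)

assert torch.equal(B_roundtrip, B)
print("int4 pack/unpack round trip OK")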

test/test_examples.expected

Lines changed: 80 additions & 0 deletions
@@ -900,6 +900,86 @@ def geglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_geglu, (triton.cdiv(total_elements, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, a_flat.stride(0), b_flat.stride(0), out_flat.stride(0), total_elements, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestExamples.test_int4_gemm)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr):
+    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < M
+    offset_2 = pid_1 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < N
+    acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
+    floordiv = triton_helpers.div_floor_integer(K, 2)
+    for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        mask_0 = indices_3 < floordiv
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        v_0 = tl.full([], 4, tl.int8)
+        v_1 = b_tile << v_0
+        v_2 = tl.full([], 4, tl.int8)
+        v_3 = v_1 >> v_2
+        v_4 = tl.full([], 4, tl.int8)
+        v_5 = b_tile >> v_4
+        v_6 = tl.cast(v_3, tl.bfloat16)
+        v_7 = tl.cast(v_5, tl.bfloat16)
+        stack_idx = tl.arange(0, 2)
+        broadcast_idx = stack_idx[None, :, None]
+        expanded_0 = tl.expand_dims(v_6, 1)
+        expanded_1 = tl.expand_dims(v_7, 1)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_4 = broadcast_idx == 0
+        stacked_result = tl.where(mask_4, expanded_0, stacked_result)
+        mask_5 = broadcast_idx == 1
+        stacked_result = tl.where(mask_5, expanded_1, stacked_result)
+        b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
+        mul_5 = 2 * offset_3
+        iota = mul_5 + tl.arange(0, mul)
+        a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
+        dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+        acc = acc_copy_0 + dot
+    v_9 = tl.cast(acc, tl.bfloat16)
+    tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_9, mask_1[:, None] & mask_2[None, :])
+
+def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
+    """
+    BFloat16 x INT4 General Matrix Multiplication (GEMM).
+
+    This kernel performs matrix multiplication where:
+    - A is a bfloat16 matrix of shape [M, K]
+    - B is an int8 matrix of shape [K//2, N] containing packed int4 values
+      (two 4-bit values packed into each int8)
+
+    Args:
+        A (Tensor): Input tensor of shape [M, K] in bfloat16 format.
+        B (Tensor): Packed int4 tensor of shape [K//2, N] in int8 format.
+
+    Returns:
+        Tensor: Output tensor of shape [M, N] in bfloat16 format.
+    """
+    M, K = A.shape
+    _, N = B.shape
+    C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
+    _BLOCK_SIZE_1 = 64
+    _BLOCK_SIZE_2 = 32
+    _RDIM_SIZE_3 = triton.next_power_of_2(K)
+    _BLOCK_SIZE_0 = 64
+    _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return C
+
 --- assertExpectedJournal(TestExamples.test_jagged_dense_add)
 from __future__ import annotations

test/test_examples.py

Lines changed: 35 additions & 0 deletions
@@ -1112,6 +1112,41 @@ def test_kl_div(self):
             )
         )

+    def test_int4_gemm(self):
+        # Matrix dimensions
+        M, K, N = 256, 512, 256
+
+        # Create bfloat16 matrix A
+        A = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)
+
+        # Create packed int4 matrix B
+        # Generate random int4 values in range [-8, 7]
+        B_unpacked = torch.randint(-8, 8, (K, N), dtype=torch.int8, device=DEVICE)
+
+        # Pack two int4 values per int8
+        B_reshaped = B_unpacked.reshape(K // 2, 2, N).permute(1, 0, 2)
+        B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)
+
+        # Convert unpacked to bfloat16 for expected result
+        B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)
+        expected = torch.matmul(A, B_unpacked_bf16)
+
+        args = (A, B_packed)
+
+        self.assertExpectedJournal(
+            check_example(
+                "int4_gemm",
+                args,
+                expected,
+                fn_name="matmul_bf16_int4",
+                block_sizes=[64, 64, 32],
+                num_warps=4,
+                num_stages=3,
+                rtol=2e-1,
+                atol=1.0,
+            )
+        )
+

 if __name__ == "__main__":
     unittest.main()
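
To exercise the new example end to end outside the test suite, a usage sketch (assuming the repository root is importable and a CUDA device is available; this invocation is not part of the commit):

import torch
from examples.int4_gemm import int4_gemm_tritonbench

M, K, N = 256, 512, 256
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randint(-8, 8, (K, N), dtype=torch.int8, device="cuda")  # unpacked int4 values

# The wrapper packs w and returns a zero-argument callable, following the
# tritonbench convention; run_kernel does not use tb_op, so None suffices here.
run = int4_gemm_tritonbench(None, x, w)
out = run()
print(out.shape, out.dtype)  # torch.Size([256, 256]) torch.bfloat16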
