
Commit 026ab0d

[Examples] Add matmul variants with bias support and tests
- Add wrapper functions for tritonbench dispatch in matmul.py and matmul_split_k.py
- Implement bias handling in both matmul and matmul_split_k
- Add comprehensive tests in test_examples.py for all matmul variants

stack-info: PR: #379, branch: yf225/stack/41
1 parent 6c5c4ca commit 026ab0d

4 files changed (+324, -13 lines)

examples/matmul.py

Lines changed: 40 additions & 2 deletions

@@ -9,7 +9,7 @@
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul_no_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -24,10 +24,48 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel(static_shapes=True)
+def matmul_with_bias(
+    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
+) -> torch.Tensor:
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f"size mismatch {k} != {k2}"
+    bias_size = bias.size(0)
+    assert bias_size == n, f"bias size mismatch, expected {n}, got {bias_size}"
+    out = torch.empty(
+        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+    )
+    for tile_m, tile_n in hl.tile([m, n]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(k):
+            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+        # Add bias
+        acc = acc + bias[tile_n]
+        out[tile_m, tile_n] = acc
+    return out
+
+
+def matmul(
+    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
+) -> torch.Tensor:
+    """Wrapper function for tritonbench that dispatches based on bias presence."""
+    if bias is None:
+        return matmul_no_bias(x, y)
+    return matmul_with_bias(x, y, bias)
+
+
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul, torch.matmul, (x, y))
+
+    # Test without bias
+    run_example(matmul_no_bias, torch.matmul, (x, y))
+
+    # Test with bias
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    expected_with_bias = lambda x, y, bias: torch.matmul(x, y) + bias  # noqa: E731
+    run_example(matmul_with_bias, expected_with_bias, (x, y, bias))
 
 
 def main() -> None:
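
Editor's note: a minimal sketch of how the new dispatch wrapper might be exercised directly, outside of run_example. The import path and shapes are assumptions for illustration, not part of this commit.

    # Illustrative usage only; import path and shapes are assumptions.
    import torch
    from examples.matmul import matmul  # dispatch wrapper added above; adjust path to your checkout

    x = torch.randn([128, 256], device="cuda", dtype=torch.float16)
    y = torch.randn([256, 512], device="cuda", dtype=torch.float16)
    bias = torch.randn([512], device="cuda", dtype=torch.float16)

    out = matmul(x, y)            # bias is None -> routes to matmul_no_bias
    out_b = matmul(x, y, bias)    # bias given   -> routes to matmul_with_bias

    torch.testing.assert_close(out, x @ y, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(out_b, x @ y + bias, rtol=1e-2, atol=1e-2)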

examples/matmul_split_k.py

Lines changed: 46 additions & 2 deletions

@@ -10,7 +10,7 @@
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul_split_k_no_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -27,10 +27,54 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel(static_shapes=True)
+def matmul_split_k_with_bias(
+    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
+) -> torch.Tensor:
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f"size mismatch {k} != {k2}"
+    bias_size = bias.size(0)
+    assert bias_size == n, f"bias size mismatch, expected {n}, got {bias_size}"
+
+    # Initialize output with zeros
+    out = torch.zeros(
+        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+    )
+
+    split_k = hl.register_tunable("split_k", PowerOfTwoFragment(1, 256))
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    for tile_m, tile_n, outer_k in hl.tile([m, n, k], block_size=[None, None, k_block]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for inner_k in hl.tile(outer_k.begin, outer_k.end):
+            acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
+        # Add bias only on the first k-split iteration
+        if outer_k.begin == 0:
+            acc = acc + bias[tile_n]
+        hl.atomic_add(out, [tile_m, tile_n], acc)
+    return out
+
+
+def matmul_split_k(
+    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
+) -> torch.Tensor:
+    """Wrapper function for tritonbench that dispatches based on bias presence."""
+    if bias is None:
+        return matmul_split_k_no_bias(x, y)
+    return matmul_split_k_with_bias(x, y, bias)
+
+
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul_split_k, torch.matmul, (x, y), atol=1)
+
+    # Test without bias
+    run_example(matmul_split_k_no_bias, torch.matmul, (x, y), atol=1)
+
+    # Test with bias
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    expected_with_bias = lambda x, y, bias: torch.matmul(x, y) + bias  # noqa: E731
+    run_example(matmul_split_k_with_bias, expected_with_bias, (x, y, bias), atol=1)
 
 
 def main() -> None:
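
Editor's note: to make the split-K logic easier to follow, here is a plain-PyTorch sketch (not Helion/Triton code) of the same reduction. Each k-split computes a partial product, the partials are summed into the output (the role hl.atomic_add plays across program instances in the kernel), and the bias is folded in exactly once, on the split whose begin is 0, so it is not added split_k times. Shapes and the split count are illustrative.

    # CPU-only reference sketch of the split-K + bias-once behavior above.
    import torch

    def split_k_reference(x, y, bias, split_k=8):
        m, k = x.shape
        out = torch.zeros(m, y.shape[1], dtype=torch.float32)
        k_block = -(-k // split_k)  # ceiling division, mirroring helion.cdiv(k, split_k)
        for begin in range(0, k, k_block):
            end = min(begin + k_block, k)
            partial = x[:, begin:end].float() @ y[begin:end, :].float()
            if begin == 0:
                partial = partial + bias.float()  # bias contributed only by the first split
            out += partial  # stands in for hl.atomic_add across k-splits
        return out.to(torch.promote_types(x.dtype, y.dtype))

    x = torch.randn(64, 1024, dtype=torch.float16)
    y = torch.randn(1024, 64, dtype=torch.float16)
    bias = torch.randn(64, dtype=torch.float16)
    expected = (x.float() @ y.float() + bias.float()).to(torch.float16)
    torch.testing.assert_close(split_k_reference(x, y, bias), expected, rtol=1e-2, atol=1e-1)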

test/test_examples.expected

Lines changed: 192 additions & 0 deletions

@@ -1103,6 +1103,45 @@ def matmul_layernorm(x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bia
     _launcher(_matmul_layernorm_kernel, (triton.cdiv(128, _BLOCK_SIZE_1),), x, y, weight, bias, out, out.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_matmul_no_bias)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _matmul_no_bias_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(y + (indices_2[:, None] * 512 + indices_1[None, :] * 1), None)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    v_0 = acc.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_0, None)
+
+def matmul_no_bias(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_2 = 32
+    _launcher(_matmul_no_bias_kernel, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_matmul_split_k)
 from __future__ import annotations
 
@@ -1152,6 +1191,159 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launc
     _launcher(_matmul_split_k_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_matmul_split_k_no_bias)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import helion._testing.matmul_split_k as _source_module
+
+@triton.jit
+def _matmul_split_k_no_bias_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(64, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 1024)
+    for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < tile_end
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
+        load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
+
+def matmul_split_k_no_bias(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.zeros([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    split_k = 8
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = k_block
+    _BLOCK_SIZE_3 = 32
+    _launcher(_matmul_split_k_no_bias_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestExamples.test_matmul_split_k_with_bias)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import helion._testing.matmul_split_k as _source_module
+
+@triton.jit
+def _matmul_split_k_with_bias_kernel(x, y, bias, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(64, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 1024)
+    for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < tile_end
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
+        load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    eq = offset_2 == 0
+    if eq:
+        acc_copy_1 = acc
+        acc_copy_1_0 = acc_copy_1
+        load_2 = tl.load(tl.make_block_ptr(bias, [64], [1], [offset_1], [_BLOCK_SIZE_1], [0]), boundary_check=[0], padding_option='zero')
+        v_0 = load_2[None, :]
+        v_1 = v_0.to(tl.float32)
+        acc = acc_copy_1_0 + v_1
+    tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
+
+def matmul_split_k_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    bias_size = bias.size(0)
+    assert bias_size == n, f'bias size mismatch, expected {n}, got {bias_size}'
+    out = torch.zeros([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    split_k = 8
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = k_block
+    _BLOCK_SIZE_3 = 32
+    _launcher(_matmul_split_k_with_bias_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, bias, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestExamples.test_matmul_with_bias)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _matmul_with_bias_kernel(x, y, bias, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(y + (indices_2[:, None] * 512 + indices_1[None, :] * 1), None)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    load_2 = tl.load(bias + indices_1 * 1, None)
+    v_0 = load_2[None, :]
+    v_1 = v_0.to(tl.float32)
+    v_2 = acc + v_1
+    v_3 = v_2.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_3, None)
+
+def matmul_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    bias_size = bias.size(0)
+    assert bias_size == n, f'bias size mismatch, expected {n}, got {bias_size}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_2 = 32
+    _launcher(_matmul_with_bias_kernel, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, y, bias, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_moe_matmul_ogs)
 from __future__ import annotations
 

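Editor's note: in the generated split-K kernels above, fp32 partial tiles are combined with tl.atomic_add into an fp16 output buffer, so each of the split_k contributions is effectively rounded to fp16 as it lands, whereas the non-split kernels keep the accumulator in fp32 and round once at the final store. That extra rounding is the likely reason check() keeps atol=1 for the split-K variants. A rough, CPU-only illustration of the effect (shapes assumed, not from this commit):

    # Rounding after every partial accumulation (split-K into fp16) vs. one final rounding.
    import torch

    torch.manual_seed(0)
    k, split_k = 1024, 8
    contributions = torch.randn(k)  # stand-in for the k products feeding one output element
    partials = [c.sum() for c in contributions.split(k // split_k)]

    round_once = torch.tensor(sum(p.item() for p in partials)).to(torch.float16)
    running_fp16 = torch.tensor(0.0, dtype=torch.float16)
    for p in partials:  # mimics repeated atomic adds into an fp16 buffer
        running_fp16 = (running_fp16.float() + p).to(torch.float16)

    exact = contributions.double().sum().item()
    print(abs(round_once.item() - exact), abs(running_fp16.item() - exact))
    # The running fp16 sum generally carries more rounding error, motivating the looser atol.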