@@ -1103,6 +1103,45 @@ def matmul_layernorm(x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bia
     _launcher(_matmul_layernorm_kernel, (triton.cdiv(128, _BLOCK_SIZE_1),), x, y, weight, bias, out, out.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_matmul_no_bias)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _matmul_no_bias_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(y + (indices_2[:, None] * 512 + indices_1[None, :] * 1), None)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    v_0 = acc.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_0, None)
+
+def matmul_no_bias(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_2 = 32
+    _launcher(_matmul_no_bias_kernel, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+
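[editor's note, not generated output] A hedged usage sketch for the entry above. The wrapper is specialized to the test shapes baked into the kernel (M=128, K=256, N=512); the kernel flattens the (M, N) tile grid into a single program id (pid_0 = pid % num_blocks_0, pid_1 = pid // num_blocks_0), accumulates in float32 via a tf32 tl.dot, and casts to float16 only at the store. Assuming the generated module above is in scope and a CUDA device is available:

    import torch

    # Shapes must match the constants baked into the kernel: 128x256 @ 256x512.
    x = torch.randn(128, 256, device='cuda', dtype=torch.float16)
    y = torch.randn(256, 512, device='cuda', dtype=torch.float16)
    out = matmul_no_bias(x, y)
    # fp32 accumulation keeps the kernel close to the eager reference.
    torch.testing.assert_close(out, x @ y, rtol=1e-2, atol=1e-2)
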
 --- assertExpectedJournal(TestExamples.test_matmul_split_k)
 from __future__ import annotations
 
@@ -1152,6 +1191,159 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launc
     _launcher(_matmul_split_k_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_matmul_split_k_no_bias)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import helion._testing.matmul_split_k as _source_module
+
+@triton.jit
+def _matmul_split_k_no_bias_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(64, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 1024)
+    for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < tile_end
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
+        load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
+
+def matmul_split_k_no_bias(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.zeros([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    split_k = 8
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = k_block
+    _BLOCK_SIZE_3 = 32
+    _launcher(_matmul_split_k_no_bias_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
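[editor's note, not generated output] A pure-PyTorch sketch of the split-K decomposition implemented above (an illustration, not Helion's API). K is partitioned across a third grid dimension (pid_2); each split computes a partial product and combines it with tl.atomic_add(..., sem='relaxed'), which is why the wrapper allocates out with torch.zeros rather than torch.empty. The += below plays the role of the atomic add; the generated wrapper additionally rounds each split's chunk up to a power of two via helion.next_power_of_2(helion.cdiv(k, split_k)).

    import torch

    def split_k_reference(x: torch.Tensor, y: torch.Tensor, split_k: int = 8) -> torch.Tensor:
        m, k = x.shape
        out = torch.zeros(m, y.shape[1], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
        k_block = -(-k // split_k)  # cdiv(k, split_k)
        for start in range(0, k, k_block):
            end = min(start + k_block, k)
            # One grid slice's contribution; in the kernel, tl.atomic_add makes
            # this += safe when the slices run concurrently on the GPU.
            out += x[:, start:end] @ y[start:end, :]
        return out
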
+--- assertExpectedJournal(TestExamples.test_matmul_split_k_with_bias)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import helion._testing.matmul_split_k as _source_module
+
+@triton.jit
+def _matmul_split_k_with_bias_kernel(x, y, bias, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(64, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 1024)
+    for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < tile_end
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
+        load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    eq = offset_2 == 0
+    if eq:
+        acc_copy_1 = acc
+        acc_copy_1_0 = acc_copy_1
+        load_2 = tl.load(tl.make_block_ptr(bias, [64], [1], [offset_1], [_BLOCK_SIZE_1], [0]), boundary_check=[0], padding_option='zero')
+        v_0 = load_2[None, :]
+        v_1 = v_0.to(tl.float32)
+        acc = acc_copy_1_0 + v_1
+    tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
+
+def matmul_split_k_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    bias_size = bias.size(0)
+    assert bias_size == n, f'bias size mismatch, expected {n}, got {bias_size}'
+    out = torch.zeros([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    split_k = 8
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = k_block
+    _BLOCK_SIZE_3 = 32
+    _launcher(_matmul_split_k_with_bias_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, bias, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
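[editor's note, not generated output] The only change from the no-bias variant is the `eq = offset_2 == 0` guard: since every K-split atomically adds its partial sum into out, the bias must be folded in by exactly one split (the one owning offset 0), or it would be counted split_k times. Extending the reference sketch above (an illustration, not Helion's API):

    import torch

    def split_k_bias_reference(x, y, bias, split_k: int = 8):
        m, k = x.shape
        out = torch.zeros(m, y.shape[1], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
        k_block = -(-k // split_k)  # cdiv(k, split_k)
        for start in range(0, k, k_block):
            end = min(start + k_block, k)
            partial = x[:, start:end] @ y[start:end, :]
            if start == 0:  # mirrors `eq = offset_2 == 0` in the kernel above
                partial = partial + bias  # bias applied by exactly one split
            out += partial
        return out
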
+--- assertExpectedJournal(TestExamples.test_matmul_with_bias)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _matmul_with_bias_kernel(x, y, bias, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(y + (indices_2[:, None] * 512 + indices_1[None, :] * 1), None)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    load_2 = tl.load(bias + indices_1 * 1, None)
+    v_0 = load_2[None, :]
+    v_1 = v_0.to(tl.float32)
+    v_2 = acc + v_1
+    v_3 = v_2.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_3, None)
+
+def matmul_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    bias_size = bias.size(0)
+    assert bias_size == n, f'bias size mismatch, expected {n}, got {bias_size}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_2 = 32
+    _launcher(_matmul_with_bias_kernel, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, y, bias, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
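[editor's note, not generated output] A hedged verification sketch for the fused-bias wrapper above, assuming the generated definitions are in scope, a CUDA device is available, and the shapes are the ones baked into the kernel. The bias add is fused into the epilogue: loaded once per output tile, broadcast across rows as load_2[None, :], added to the float32 accumulator, and only then cast to float16 for the store.

    import torch

    x = torch.randn(128, 256, device='cuda', dtype=torch.float16)
    y = torch.randn(256, 512, device='cuda', dtype=torch.float16)
    bias = torch.randn(512, device='cuda', dtype=torch.float16)
    out = matmul_with_bias(x, y, bias)
    # The fused epilogue should match the eager bias add.
    torch.testing.assert_close(out, x @ y + bias, rtol=1e-2, atol=1e-2)
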
 --- assertExpectedJournal(TestExamples.test_moe_matmul_ogs)
 from __future__ import annotations
 