Commit 5b19983

[Examples] Add matmul variants with bias support and tests
- Add wrapper functions for tritonbench dispatch in matmul.py and matmul_split_k.py
- Implement bias handling in both matmul and matmul_split_k
- Add comprehensive tests in test_examples.py for all matmul variants

stack-info: PR: #379, branch: yf225/stack/41
1 parent 6c5c4ca commit 5b19983
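
For orientation, a minimal usage sketch of the two dispatch wrappers this commit adds (assumptions: a CUDA device, fp16 inputs, and that the example modules are importable as `examples.matmul` / `examples.matmul_split_k`; the shapes are illustrative):

    import torch
    from examples.matmul import matmul
    from examples.matmul_split_k import matmul_split_k

    x = torch.randn([128, 256], device="cuda", dtype=torch.float16)
    y = torch.randn([256, 512], device="cuda", dtype=torch.float16)
    bias = torch.randn([512], device="cuda", dtype=torch.float16)

    out = matmul(x, y)                        # identity epilogue
    out_bias = matmul(x, y, bias)             # epilogue closure adds bias per output tile
    out_sk = matmul_split_k(x, y)             # dispatches to matmul_split_k_no_bias
    out_sk_bias = matmul_split_k(x, y, bias)  # dispatches to matmul_split_k_with_bias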

10 files changed: +414 −67 lines changed

10 files changed

+414
-67
lines changed

examples/matmul.py

Lines changed: 78 additions & 2 deletions
@@ -1,15 +1,56 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import torch
 
 import helion
 from helion._testing import run_example
 import helion.language as hl
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul_with_epilogue(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    epilogue: Callable[[torch.Tensor, list[torch.Tensor]], torch.Tensor],
+) -> torch.Tensor:
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f"size mismatch {k} != {k2}"
+    out = torch.empty(
+        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+    )
+    for tile_m, tile_n in hl.tile([m, n]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(k):
+            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+        out[tile_m, tile_n] = epilogue(acc, [tile_m, tile_n])
+    return out
+
+
+def matmul(
+    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
+) -> torch.Tensor:
+    """Wrapper function for tritonbench that dispatches based on bias presence."""
+    if bias is None:
+        # No epilogue, just return the accumulated value
+        return matmul_with_epilogue(x, y, lambda acc, tile: acc)
+    # Create a closure that captures the bias
+
+    def epilogue_with_bias(acc: torch.Tensor, tile: list[torch.Tensor]) -> torch.Tensor:
+        # Use tile_n to index into the bias
+        return acc + bias[tile[1]]
+
+    return matmul_with_epilogue(x, y, epilogue_with_bias)
+
+
+@helion.kernel(static_shapes=True)
+def matmul_no_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -24,10 +65,45 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel(static_shapes=True)
+def matmul_with_bias(
+    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
+) -> torch.Tensor:
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f"size mismatch {k} != {k2}"
+    out = torch.empty(
+        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+    )
+    for tile_m, tile_n in hl.tile([m, n]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(k):
+            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
+        out[tile_m, tile_n] = acc + bias[tile_n]
+    return out
+
+
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul, torch.matmul, (x, y))
+
+    # Test without bias using closure approach
+    def kernel_no_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return matmul_with_epilogue(x, y, lambda acc, tile: acc)
+
+    run_example(kernel_no_bias, torch.matmul, (x, y))
+
+    # Test with bias using closure approach
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+
+    def kernel_with_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        def epilogue(acc: torch.Tensor, tile: list[torch.Tensor]) -> torch.Tensor:
+            return acc + bias[tile[1]]
+
+        return matmul_with_epilogue(x, y, epilogue)
+
+    expected_with_bias = lambda x, y: torch.matmul(x, y) + bias  # noqa: E731
+    run_example(kernel_with_bias, expected_with_bias, (x, y))
 
 
 def main() -> None:
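
The epilogue hook is not limited to bias addition: the closure is evaluated inside the kernel at the point of the store (the `out[tile_m, tile_n] = epilogue(acc, [tile_m, tile_n])` line above), so it receives the fp32 accumulator plus the output-tile indices. A hypothetical sketch of another fusion (the ReLU variant below is illustrative and not part of this commit; it assumes the elementwise op is traceable by Helion):

    def matmul_bias_relu(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        def epilogue(acc: torch.Tensor, tile: list[torch.Tensor]) -> torch.Tensor:
            # tile == [tile_m, tile_n]; index the 1-D bias with the n-tile,
            # then apply an elementwise activation before the cast and store.
            return torch.relu(acc + bias[tile[1]])

        return matmul_with_epilogue(x, y, epilogue)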

examples/matmul_split_k.py

Lines changed: 46 additions & 2 deletions
@@ -10,7 +10,7 @@
 
 # static_shapes=True gives a performance boost for matmuls
 @helion.kernel(static_shapes=True)
-def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def matmul_split_k_no_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     m, k = x.size()
     k2, n = y.size()
     assert k == k2, f"size mismatch {k} != {k2}"
@@ -27,10 +27,54 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return out
 
 
+@helion.kernel(static_shapes=True)
+def matmul_split_k_with_bias(
+    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
+) -> torch.Tensor:
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f"size mismatch {k} != {k2}"
+    bias_size = bias.size(0)
+    assert bias_size == n, f"bias size mismatch, expected {n}, got {bias_size}"
+
+    # Initialize output with zeros
+    out = torch.zeros(
+        [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device
+    )
+
+    split_k = hl.register_tunable("split_k", PowerOfTwoFragment(1, 256))
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    for tile_m, tile_n, outer_k in hl.tile([m, n, k], block_size=[None, None, k_block]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for inner_k in hl.tile(outer_k.begin, outer_k.end):
+            acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
+        # Add bias only on the first k-split iteration
+        if outer_k.begin == 0:
+            acc = acc + bias[tile_n]
+        hl.atomic_add(out, [tile_m, tile_n], acc)
+    return out
+
+
+def matmul_split_k(
+    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
+) -> torch.Tensor:
+    """Wrapper function for tritonbench that dispatches based on bias presence."""
+    if bias is None:
+        return matmul_split_k_no_bias(x, y)
+    return matmul_split_k_with_bias(x, y, bias)
+
+
 def check(m: int, k: int, n: int) -> None:
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    run_example(matmul_split_k, torch.matmul, (x, y), atol=1)
+
+    # Test without bias
+    run_example(matmul_split_k_no_bias, torch.matmul, (x, y), atol=1)
+
+    # Test with bias
+    bias = torch.randn([n], device="cuda", dtype=torch.float16)
+    expected_with_bias = lambda x, y, bias: torch.matmul(x, y) + bias  # noqa: E731
+    run_example(matmul_split_k_with_bias, expected_with_bias, (x, y, bias), atol=1)
 
 
 def main() -> None:
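
A note on the `outer_k.begin == 0` guard above: each split-k program accumulates its partial product into `out` via `hl.atomic_add`, so the bias must be folded into exactly one partial sum or it would be added `split_k` times. A plain-PyTorch sketch of the same accumulation scheme (illustrative shapes and split count, not part of the commit):

    import torch

    m, k, n, split_k = 64, 1024, 64, 8
    x = torch.randn(m, k)
    y = torch.randn(k, n)
    bias = torch.randn(n)

    out = torch.zeros(m, n)
    k_block = k // split_k
    for s in range(split_k):
        lo, hi = s * k_block, (s + 1) * k_block
        partial = x[:, lo:hi] @ y[lo:hi, :]
        if s == 0:  # mirror the outer_k.begin == 0 check: add bias exactly once
            partial = partial + bias
        out += partial  # stands in for hl.atomic_add across programs

    print(torch.allclose(out, x @ y + bias, rtol=1e-4, atol=1e-3))  # True up to fp32 accumulation order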

test/test_autotuner.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 datadir = Path(__file__).parent / "data"
 basic_kernels = import_path(datadir / "basic_kernels.py")
 examples_dir = Path(__file__).parent.parent / "examples"
-examples_matmul = import_path(examples_dir / "matmul.py").matmul
+examples_matmul = import_path(examples_dir / "matmul.py").matmul_no_bias
 
 
 class TestAutotuner(TestCase):

test/test_examples.expected

Lines changed: 190 additions & 0 deletions
@@ -1103,6 +1103,45 @@ def matmul_layernorm(x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bia
     _launcher(_matmul_layernorm_kernel, (triton.cdiv(128, _BLOCK_SIZE_1),), x, y, weight, bias, out, out.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_matmul_no_bias)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _matmul_no_bias_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(y + (indices_2[:, None] * 512 + indices_1[None, :] * 1), None)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    v_0 = acc.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_0, None)
+
+def matmul_no_bias(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_2 = 32
+    _launcher(_matmul_no_bias_kernel, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_matmul_split_k)
 from __future__ import annotations
 
@@ -1152,6 +1191,157 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launc
     _launcher(_matmul_split_k_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_matmul_split_k_no_bias)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import helion._testing.matmul_split_k as _source_module
+
+@triton.jit
+def _matmul_split_k_no_bias_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(64, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 1024)
+    for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < tile_end
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
+        load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
+
+def matmul_split_k_no_bias(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.zeros([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    split_k = 8
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = k_block
+    _BLOCK_SIZE_3 = 32
+    _launcher(_matmul_split_k_no_bias_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestExamples.test_matmul_split_k_with_bias)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+import helion._testing.matmul_split_k as _source_module
+
+@triton.jit
+def _matmul_split_k_with_bias_kernel(x, y, bias, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    num_blocks_0 = tl.cdiv(64, _BLOCK_SIZE_0)
+    num_blocks_1 = tl.cdiv(64, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, 1024)
+    for offset_3 in tl.range(offset_2.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < tile_end
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), mask_3[None, :], other=0)
+        load_1 = tl.load(y + (indices_3[:, None] * 64 + indices_1[None, :] * 1), mask_3[:, None], other=0)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    eq = offset_2 == 0
+    if eq:
+        acc_copy_1 = acc
+        acc_copy_1_0 = acc_copy_1
+        load_2 = tl.load(tl.make_block_ptr(bias, [64], [1], [offset_1], [_BLOCK_SIZE_1], [0]), boundary_check=[0], padding_option='zero')
+        v_0 = load_2[None, :]
+        v_1 = v_0.to(tl.float32)
+        acc = acc_copy_1_0 + v_1
+    tl.atomic_add(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), acc, mask=None, sem='relaxed')
+
+def matmul_split_k_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    bias_size = bias.size(0)
+    assert bias_size == n, f'bias size mismatch, expected {n}, got {bias_size}'
+    out = torch.zeros([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    split_k = 8
+    k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = k_block
+    _BLOCK_SIZE_3 = 32
+    _launcher(_matmul_split_k_with_bias_kernel, (triton.cdiv(64, _BLOCK_SIZE_0) * triton.cdiv(64, _BLOCK_SIZE_1) * triton.cdiv(1024, _BLOCK_SIZE_2),), x, y, bias, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestExamples.test_matmul_with_bias)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _matmul_with_bias_kernel(x, y, bias, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(128, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 256, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(y + (indices_2[:, None] * 512 + indices_1[None, :] * 1), None)
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
+    load_2 = tl.load(bias + indices_1 * 1, None)
+    v_0 = load_2[None, :]
+    v_1 = v_0.to(tl.float32)
+    v_2 = acc + v_1
+    v_3 = v_2.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 512 + indices_1[None, :] * 1), v_3, None)
+
+def matmul_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor, *, _launcher=_default_launcher):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_2 = 32
+    _launcher(_matmul_with_bias_kernel, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(512, _BLOCK_SIZE_1),), x, y, bias, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_moe_matmul_ogs)
 from __future__ import annotations
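
As a sanity check on the split-k journals above (a hand computation against the 64x1024 @ 1024x64 test shapes they encode): with split_k = 8 the host function picks k_block = next_power_of_2(cdiv(1024, 8)) = 128, so the launch grid is cdiv(64, 16) * cdiv(64, 16) * cdiv(1024, 128) = 4 * 4 * 8 = 128 programs, one per (tile_m, tile_n, outer_k) slice. The same arithmetic in plain Python (the local helpers below stand in for helion.cdiv and helion.next_power_of_2):

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    def next_power_of_2(v: int) -> int:
        return 1 << (v - 1).bit_length()

    split_k = 8
    k_block = next_power_of_2(cdiv(1024, split_k))            # 128
    grid = cdiv(64, 16) * cdiv(64, 16) * cdiv(1024, k_block)  # 4 * 4 * 8 = 128
    print(k_block, grid)  # matches the _launcher grid expression in the journal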
