Commit 3c0348a

Fix missing min block size for hl.dot (#522)

1 parent c221cab · commit 3c0348a

3 files changed: +96 −19 lines
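Summary: the min-dot-size logic previously ran only in the Inductor lowering pass (apply_dot_requirements in inductor_lowering.py), so hl.dot did not register its minimum block sizes with the autotuner, per the commit title. This commit moves that logic into a shared helper, enforce_dot_requirements in helion/language/matmul_ops.py, and calls it from both the lowering pass and hl.dot's fake function, so the minimums are recorded in the config spec at bind time.

A minimal sketch of how the effect can be observed, mirroring the new test (the kernel name small_matmul is illustrative; DEVICE comes from helion._testing):

import torch
import helion
import helion.language as hl
from helion._testing import DEVICE

@helion.kernel
def small_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    assert k == k2
    out = torch.empty([m, n], dtype=torch.float32, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc += hl.dot(x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc
    return out

args = (
    torch.randn([32, 4], device=DEVICE, dtype=torch.float16),
    torch.randn([4, 16], device=DEVICE, dtype=torch.float16),
)
# After this change, each tile's minimum block size reflects the hardware
# minimum reported by min_dot_size for this device/dtype combination.
spec = small_matmul.bind(args).config_spec
print([bs.min_size for bs in spec.block_sizes])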

helion/_compiler/inductor_lowering.py

Lines changed: 3 additions & 13 deletions
@@ -41,10 +41,10 @@
 from torch.fx.node import map_arg
 
 from .. import exc
-from .._compat import min_dot_size
 from ..exc import InductorLoweringError
 from ..language._decorators import APIFunc
 from ..language._decorators import is_api_func
+from ..language.matmul_ops import enforce_dot_requirements
 from .ast_extension import ExtendedAST
 from .ast_extension import create
 from .ast_extension import expr_from_string
@@ -872,18 +872,8 @@ def apply_dot_requirements(
     lproxy, rproxy = map_arg(node.args[-2:], lambda arg: arg.meta["val"])
     assert isinstance(lproxy, torch.Tensor)
     assert isinstance(rproxy, torch.Tensor)
-    lshape = lproxy.size()
-    rshape = rproxy.size()
-    # use last two dimensions for dot (supports 2D and batched 3D tensors)
-    m, k = lshape[-2], lshape[-1]
-    k2, n = rshape[-2], rshape[-1]
-    assert k == k2, f"Mismatched k dimensions for dot: {k} vs {k2}"
-    a, b, c = min_dot_size(lproxy.device, lproxy.dtype, rproxy.dtype)
-    env = CompileEnvironment.current()
-    for shape, min_size in [(m, a), (n, b), (k, c)]:
-        block_idx = CompileEnvironment.current().get_block_id(shape)
-        if block_idx is not None:
-            env.block_sizes[block_idx].update_min_block(min_size, allow_flattened=True)
+    # Update config spec min sizes for M, N, K
+    enforce_dot_requirements(lproxy, rproxy)
     # inputs to the dot operation must be zero-masked
     *maybe_acc, lnode, rnode = node.args
     assert isinstance(lnode, torch.fx.Node)

helion/language/matmul_ops.py

Lines changed: 28 additions & 6 deletions
@@ -8,6 +8,8 @@
 from torch._subclasses.fake_tensor import FakeTensor
 
 from .. import exc
+from .._compat import min_dot_size
+from .._compiler.compile_environment import CompileEnvironment
 from . import _decorators
 
 if TYPE_CHECKING:
@@ -126,9 +128,34 @@ def _(
             f"hl.dot: acc shape {list(acc.shape)} incompatible with result shape {expected_shape}"
         )
 
+    # Apply min-dot-size constraints so autotuner won't pick invalid block_size
+    enforce_dot_requirements(mat1, mat2)
+
     return (mat1, mat2, acc)
 
 
+def enforce_dot_requirements(lhs: torch.Tensor, rhs: torch.Tensor) -> None:
+    """Update config-spec min sizes for M, N, K of a dot/matmul.
+
+    This ensures the autotuner does not select block sizes below the hardware
+    minimums for the current device and dtypes.
+    """
+
+    # Last two dims are used for matmul
+    lshape = lhs.size()
+    rshape = rhs.size()
+    m, k = lshape[-2], lshape[-1]
+    k2, n = rshape[-2], rshape[-1]
+    assert k == k2, f"Mismatched K dimensions for dot: {k} vs {k2}"
+
+    a, b, c = min_dot_size(lhs.device, lhs.dtype, rhs.dtype)
+    env = CompileEnvironment.current()
+    for shape, min_size in ((m, a), (n, b), (k, c)):
+        block_idx = env.get_block_id(shape)
+        if block_idx is not None:
+            env.block_sizes[block_idx].update_min_block(min_size, allow_flattened=True)
+
+
 def _compute_out_dtype(
     mat1_dtype: torch.dtype,
     mat2_dtype: torch.dtype,
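The helper builds on two existing pieces: min_dot_size, which reports the per-dimension hardware minimums (mapped here to M, N, and K), and CompileEnvironment, whose per-block-size entries accept a raised minimum via update_min_block. A hedged, standalone illustration of the query half (assumes a CUDA device and fp16 operands; the concrete values depend on the backend):

import torch
from helion._compat import min_dot_size

lhs = torch.empty(64, 32, device="cuda", dtype=torch.float16)
rhs = torch.empty(32, 128, device="cuda", dtype=torch.float16)
# Returns the minimum usable block sizes for the M, N, and K dimensions.
m_min, n_min, k_min = min_dot_size(lhs.device, lhs.dtype, rhs.dtype)

enforce_dot_requirements then looks up the block id backing each of the M, N, and K sizes (dims where get_block_id returns None are skipped) and raises that block size's minimum accordingly.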
@@ -167,7 +194,6 @@ def _(
 def _(state: CodegenState) -> object:
     # Import here to avoid circular imports
     from .._compiler.ast_extension import expr_from_string
-    from .._compiler.compile_environment import CompileEnvironment
 
     # Get the AST representations of our arguments
     lhs_ast = state.ast_arg(0)
@@ -182,15 +208,11 @@ def _(state: CodegenState) -> object:
     acc_proxy = state.proxy_args[2] if len(state.proxy_args) > 2 else None
 
     # Access dtype - proxy_args can be FakeTensor objects
-    lhs_dtype = None
-    rhs_dtype = None
-    acc_dtype = None
-
-    # For FakeTensor objects, dtype is directly accessible
     lhs_dtype = lhs_proxy.dtype
     rhs_dtype = rhs_proxy.dtype
 
     # Get accumulator dtype if available
+    acc_dtype: torch.dtype | None = None
     if acc_proxy is not None:
         assert isinstance(acc_proxy, FakeTensor), "acc_proxy must be a FakeTensor"
         acc_dtype = acc_proxy.dtype

test/test_dot_requirements.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+import unittest
+from unittest.mock import patch
+
+import torch
+
+import helion
+from helion import _compat
+from helion._testing import DEVICE
+from helion._testing import RefEagerTestDisabled
+from helion._testing import TestCase
+import helion.language as hl
+
+
+class TestDotRequirements(RefEagerTestDisabled, TestCase):
+    @patch.object(_compat, "_min_dot_size", lambda *args: (2, 8, 16))
+    def test_hl_dot_sets_min_size(self) -> None:
+        @helion.kernel
+        def k_small(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            assert k == k2
+            out = torch.empty([m, n], dtype=torch.float32, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc += hl.dot(x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc
+            return out
+
+        m, k, n = 32, 4, 16
+        args = (
+            torch.randn([m, k], device=DEVICE, dtype=torch.float16),
+            torch.randn([k, n], device=DEVICE, dtype=torch.float16),
+        )
+        spec = k_small.bind(args).config_spec
+        self.assertEqual([x.min_size for x in spec.block_sizes], [2, 8, 16])
+
+    @patch.object(_compat, "_min_dot_size", lambda *args: (2, 8, 16))
+    def test_matmul_sets_min_size(self) -> None:
+        @helion.kernel
+        def k_small(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            m, k = x.size()
+            k2, n = y.size()
+            assert k == k2
+            out = torch.empty([m, n], dtype=torch.float32, device=x.device)
+            for tile_m, tile_n in hl.tile([m, n]):
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+                for tile_k in hl.tile(k):
+                    acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n])
+                out[tile_m, tile_n] = acc
+            return out
+
+        m, k, n = 32, 4, 16
+        args = (
+            torch.randn([m, k], device=DEVICE, dtype=torch.float16),
+            torch.randn([k, n], device=DEVICE, dtype=torch.float16),
+        )
+        spec = k_small.bind(args).config_spec
+        self.assertEqual([x.min_size for x in spec.block_sizes], [2, 8, 16])
+
+
+if __name__ == "__main__":
+    unittest.main()
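The tests pin the minimums by patching _compat._min_dot_size, presumably the backend hook behind helion._compat.min_dot_size, so the expected values (2, 8, 16) are deterministic across devices, then check that exactly those values surface in config_spec.block_sizes at bind time. The same idiom can be reused outside the test class, sketched here under that assumption:

from unittest.mock import patch
from helion import _compat

# Force the reported minimums regardless of the local GPU, then bind a kernel
# and inspect [bs.min_size for bs in kernel.bind(args).config_spec.block_sizes].
with patch.object(_compat, "_min_dot_size", lambda *args: (2, 8, 16)):
    ...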
