16 changes: 3 additions & 13 deletions helion/_compiler/inductor_lowering.py
@@ -41,10 +41,10 @@
from torch.fx.node import map_arg

from .. import exc
from .._compat import min_dot_size
from ..exc import InductorLoweringError
from ..language._decorators import APIFunc
from ..language._decorators import is_api_func
from ..language.matmul_ops import enforce_dot_requirements
from .ast_extension import ExtendedAST
from .ast_extension import create
from .ast_extension import expr_from_string
@@ -870,18 +870,8 @@ def apply_dot_requirements(
lproxy, rproxy = map_arg(node.args[-2:], lambda arg: arg.meta["val"])
assert isinstance(lproxy, torch.Tensor)
assert isinstance(rproxy, torch.Tensor)
lshape = lproxy.size()
rshape = rproxy.size()
# use last two dimensions for dot (supports 2D and batched 3D tensors)
m, k = lshape[-2], lshape[-1]
k2, n = rshape[-2], rshape[-1]
assert k == k2, f"Mismatched k dimensions for dot: {k} vs {k2}"
a, b, c = min_dot_size(lproxy.device, lproxy.dtype, rproxy.dtype)
env = CompileEnvironment.current()
for shape, min_size in [(m, a), (n, b), (k, c)]:
block_idx = CompileEnvironment.current().get_block_id(shape)
if block_idx is not None:
env.block_sizes[block_idx].update_min_block(min_size, allow_flattened=True)
# Update config spec min sizes for M, N, K
enforce_dot_requirements(lproxy, rproxy)
# inputs to the dot operation must be zero-masked
*maybe_acc, lnode, rnode = node.args
assert isinstance(lnode, torch.fx.Node)
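Note (not part of the diff): `apply_dot_requirements` still pulls the operand values out of the FX graph via `map_arg` before delegating to the new helper. A hedged sketch of that pattern, assuming a graph in which every operand node stores its FakeTensor under `node.meta["val"]` (which is what the asserts above rely on); `dot_operand_vals` is a hypothetical name used only for illustration:

```python
import torch.fx
from torch.fx.node import map_arg


def dot_operand_vals(node: torch.fx.Node) -> tuple[torch.Tensor, torch.Tensor]:
    # node.args[-2:] holds the lhs/rhs of the dot (an optional accumulator may
    # precede them); map_arg swaps each fx.Node in that tuple for whatever the
    # lambda returns, here the fake tensor recorded during tracing.
    lhs, rhs = map_arg(node.args[-2:], lambda arg: arg.meta["val"])
    return lhs, rhs
```

Those fake tensors carry device, dtype, and (possibly symbolic) sizes, which is all `enforce_dot_requirements` needs.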
34 changes: 28 additions & 6 deletions helion/language/matmul_ops.py
@@ -8,6 +8,8 @@
from torch._subclasses.fake_tensor import FakeTensor

from .. import exc
from .._compat import min_dot_size
from .._compiler.compile_environment import CompileEnvironment
from . import _decorators

if TYPE_CHECKING:
@@ -126,9 +128,34 @@ def _(
f"hl.dot: acc shape {list(acc.shape)} incompatible with result shape {expected_shape}"
)

# Apply min-dot-size constraints so autotuner won't pick invalid block_size
enforce_dot_requirements(mat1, mat2)

return (mat1, mat2, acc)


def enforce_dot_requirements(lhs: torch.Tensor, rhs: torch.Tensor) -> None:
"""Update config-spec min sizes for M, N, K of a dot/matmul.

This ensures the autotuner does not select block sizes below the hardware
minimums for the current device and dtypes.
"""

# Last two dims are used for matmul
lshape = lhs.size()
rshape = rhs.size()
m, k = lshape[-2], lshape[-1]
k2, n = rshape[-2], rshape[-1]
assert k == k2, f"Mismatched K dimensions for dot: {k} vs {k2}"

a, b, c = min_dot_size(lhs.device, lhs.dtype, rhs.dtype)
env = CompileEnvironment.current()
for shape, min_size in ((m, a), (n, b), (k, c)):
block_idx = env.get_block_id(shape)
if block_idx is not None:
env.block_sizes[block_idx].update_min_block(min_size, allow_flattened=True)


def _compute_out_dtype(
mat1_dtype: torch.dtype,
mat2_dtype: torch.dtype,
@@ -167,7 +194,6 @@ def _(
def _(state: CodegenState) -> object:
# Import here to avoid circular imports
from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import CompileEnvironment

# Get the AST representations of our arguments
lhs_ast = state.ast_arg(0)
@@ -182,15 +208,11 @@ def _(state: CodegenState) -> object:
acc_proxy = state.proxy_args[2] if len(state.proxy_args) > 2 else None

# Access dtype - proxy_args can be FakeTensor objects
lhs_dtype = None
rhs_dtype = None
acc_dtype = None

# For FakeTensor objects, dtype is directly accessible
lhs_dtype = lhs_proxy.dtype
rhs_dtype = rhs_proxy.dtype

# Get accumulator dtype if available
acc_dtype: torch.dtype | None = None
if acc_proxy is not None:
assert isinstance(acc_proxy, FakeTensor), "acc_proxy must be a FakeTensor"
acc_dtype = acc_proxy.dtype
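Note (not part of the diff): the new helper reads M and K from the last two dimensions of the left operand and K and N from the right, with any leading dimensions treated as batch dimensions, as in `torch.matmul`; each of the three extents is then clamped to the device/dtype minimum returned by `min_dot_size`. A standalone sketch of just the shape bookkeeping, using a hypothetical `dot_dims` function that is not part of the Helion API:

```python
import torch


def dot_dims(lhs: torch.Tensor, rhs: torch.Tensor) -> tuple[int, int, int]:
    # Only the last two dimensions participate in the dot; leading dimensions
    # are batch dimensions, so 2D and batched 3D operands are handled alike.
    m, k = lhs.shape[-2], lhs.shape[-1]
    k2, n = rhs.shape[-2], rhs.shape[-1]
    assert k == k2, f"mismatched K dimensions: {k} vs {k2}"
    return m, n, k


print(dot_dims(torch.empty(32, 4), torch.empty(4, 16)))        # (32, 16, 4)
print(dot_dims(torch.empty(8, 32, 4), torch.empty(8, 4, 16)))  # (32, 16, 4)
```

The returned order matches how the `(a, b, c)` triple from `min_dot_size` is applied to M, N, and K in the loop above.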
65 changes: 65 additions & 0 deletions test/test_dot_requirements.py
@@ -0,0 +1,65 @@
from __future__ import annotations

import unittest
from unittest.mock import patch

import torch

import helion
from helion import _compat
from helion._testing import DEVICE
from helion._testing import RefEagerTestDisabled
from helion._testing import TestCase
import helion.language as hl


class TestDotRequirements(RefEagerTestDisabled, TestCase):
@patch.object(_compat, "_min_dot_size", lambda *args: (2, 8, 16))
def test_hl_dot_sets_min_size(self) -> None:
@helion.kernel
def k_small(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
m, k = x.size()
k2, n = y.size()
assert k == k2
out = torch.empty([m, n], dtype=torch.float32, device=x.device)
for tile_m, tile_n in hl.tile([m, n]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc += hl.dot(x[tile_m, tile_k], y[tile_k, tile_n])
out[tile_m, tile_n] = acc
return out

m, k, n = 32, 4, 16
args = (
torch.randn([m, k], device=DEVICE, dtype=torch.float16),
torch.randn([k, n], device=DEVICE, dtype=torch.float16),
)
spec = k_small.bind(args).config_spec
self.assertEqual([x.min_size for x in spec.block_sizes], [2, 8, 16])

@patch.object(_compat, "_min_dot_size", lambda *args: (2, 8, 16))
def test_matmul_sets_min_size(self) -> None:
@helion.kernel
def k_small(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
m, k = x.size()
k2, n = y.size()
assert k == k2
out = torch.empty([m, n], dtype=torch.float32, device=x.device)
for tile_m, tile_n in hl.tile([m, n]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc += torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n])
out[tile_m, tile_n] = acc
return out

m, k, n = 32, 4, 16
args = (
torch.randn([m, k], device=DEVICE, dtype=torch.float16),
torch.randn([k, n], device=DEVICE, dtype=torch.float16),
)
spec = k_small.bind(args).config_spec
self.assertEqual([x.min_size for x in spec.block_sizes], [2, 8, 16])


if __name__ == "__main__":
unittest.main()
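Note (not part of the diff): both tests patch `_compat._min_dot_size` to return `(2, 8, 16)` and then read `min_size` back from the bound kernel's config spec; the block sizes are listed in the order the tiles are created (`tile_m`, `tile_n`, `tile_k`), which is consistent with the expected minimums `[2, 8, 16]`. As a rough mental model of what `update_min_block` does per tiled dimension, here is a hedged sketch with a hypothetical `BlockSpec` class; Helion's real block-size spec may implement this differently:

```python
import dataclasses


@dataclasses.dataclass
class BlockSpec:
    min_size: int = 1

    def update_min_block(self, min_size: int) -> None:
        # Raising the floor is monotonic: a stricter requirement from one dot
        # cannot be relaxed by a looser requirement elsewhere in the kernel.
        self.min_size = max(self.min_size, min_size)


specs = {"m": BlockSpec(), "n": BlockSpec(), "k": BlockSpec()}
for name, required in zip(("m", "n", "k"), (2, 8, 16)):
    specs[name].update_min_block(required)
assert [s.min_size for s in specs.values()] == [2, 8, 16]
```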