Merged
21 changes: 17 additions & 4 deletions helion/_compiler/inductor_lowering.py
@@ -245,10 +245,16 @@ def convert_arg(arg: Node) -> TensorBox:
if len(result) > 1 and nodes:
last_node = nodes[-1] # The last node is the main node
output_nodes = {}
extra_deps = []
for n in nodes:
if "output_index" in n.meta:
output_nodes[n.meta["output_index"]] = n.name
if n is not last_node and n not in last_node._input_nodes:
extra_deps.append(n)
last_node.meta["output_nodes"] = output_nodes
if extra_deps:
# Need to ensure that the last node depends on all output nodes to prevent DCE issues
last_node.kwargs = {**last_node.kwargs, "_extra_deps": extra_deps}
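A minimal sketch of the DCE hazard this guards against, using a plain torch.fx graph (the helper and node names are illustrative, not from the codebase): a producer node with no users is dropped by eliminate_dead_code() unless the surviving node still references it, e.g. through a kwarg.

import torch
import torch.fx as fx

def build(keep_alive: bool) -> bool:
    g = fx.Graph()
    x = g.placeholder("x")
    side = g.call_function(torch.sin, (x,))  # extra output with no direct users
    main = g.call_function(torch.cos, (x,))  # the node that survives lowering
    if keep_alive:
        # Mirrors the change above: a kwarg reference makes `side` a dependency of `main`.
        main.kwargs = {**main.kwargs, "_extra_deps": [side]}
    g.output(main)
    g.eliminate_dead_code()
    return side in list(g.nodes)

print(build(keep_alive=False))  # False: the side output was dead-code-eliminated
print(build(keep_alive=True))   # True: the extra dependency keeps it alive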


def strip_unused_inputs(
@@ -371,7 +377,8 @@ def visit(n: torch.fx.Node) -> None:
device_function: DeviceFunction = ctx.cg.device_function
ndim: int = max([x.ndim for x in self.input_fake_tensors(node)] or (0,))
input_asts: list[ast.AST] = []
map_arg((node.args, node.kwargs), visit)
# _extra_deps should not be included in the inductor node inputs
map_arg((node.args, {**node.kwargs, "_extra_deps": None}), visit)
assert len(input_asts) == len(self.input_names)
return input_asts
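Complementing the sketch above, a small illustration (again plain torch.fx, invented names) of why overriding the key with None works: map_arg only invokes its callback on fx.Node values, so once "_extra_deps" is masked with None those nodes never surface as inductor inputs.

import torch
import torch.fx as fx
from torch.fx.node import map_arg

g = fx.Graph()
x = g.placeholder("x")
dep = g.call_function(torch.sin, (x,))  # DCE guard only, not a real input
node = g.call_function(torch.cos, (x,), {"_extra_deps": [dep]})

seen = []
map_arg((node.args, {**node.kwargs, "_extra_deps": None}), seen.append)
print([n.name for n in seen])  # ['x'] -- the extra dependency is not visited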

@@ -411,9 +418,7 @@ def install_inductor_kernel_handlers(
"split_reductions": False,
}
),
V.set_graph_handler(
GraphLowering(dummy_gm(), shape_env=CompileEnvironment.current().shape_env)
),
V.set_graph_handler(FakeGraphLowering()),
V.set_ops_handler(
GenerateASTFromInductor(
cg,
@@ -432,6 +437,14 @@ def dummy_gm() -> torch.fx.GraphModule:
return torch.fx.symbolic_trace(lambda: None)


class FakeGraphLowering(GraphLowering):
def __init__(self) -> None:
env = CompileEnvironment.current()
super().__init__(dummy_gm(), shape_env=env.shape_env)
# Set the device directly on the graph_lowering to ensure get_current_device_or_throw() works
self._current_device = env.device
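A hedged sketch of what the pre-seeded device enables, assuming an active CompileEnvironment (get_current_device_or_throw is the Inductor GraphLowering accessor named in the comment above): ops handlers installed under this graph can query the current device instead of having it threaded through every call site.

from torch._inductor.virtualized import V

with V.set_graph_handler(FakeGraphLowering()):
    device = V.graph.get_current_device_or_throw()  # the CompileEnvironment's device, no throw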


class PointwiseLowering(InductorLowering):
def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
with self.install_kernel_handlers(ctx, node):
30 changes: 16 additions & 14 deletions helion/_compiler/reduction_strategy.py
@@ -11,6 +11,7 @@
from torch._inductor.ir import get_reduction_combine_fn
from torch._inductor.runtime.runtime_utils import next_power_of_2
from torch._inductor.utils import triton_type
from torch._prims_common import get_computation_dtype

from ..autotuner.config_fragment import integer_power_of_two
from .ast_extension import create
@@ -292,22 +293,23 @@ def codegen_reduction(
fake_input: torch.Tensor,
fake_output: torch.Tensor,
) -> ast.AST:
device_loop = state.codegen.active_device_loops[self.block_index][-1]
assert isinstance(device_loop, DeviceLoopState)
shape = self.fn.tile_strategy.shape_str([*fake_input.size()])
default = ir.Reduction.default_accumulator(reduction_type, fake_input.dtype)
assert isinstance(default, (float, int, bool))
assert state.fx_node is not None
acc = self.fn.new_var(f"{state.fx_node.name}_acc", dce=True)
device_loop.outer_prefix.append(
statement_from_string(
f"{acc} = tl.full({shape}, {constant_repr(default)}, {triton_acc_type(fake_input.dtype)})"
)
)
result = self.fn.new_var(state.fx_node.name, dce=True)
with install_inductor_kernel_handlers(state.codegen, {}):
device_loop = state.codegen.active_device_loops[self.block_index][-1]
assert isinstance(device_loop, DeviceLoopState)
shape = self.fn.tile_strategy.shape_str([*fake_input.size()])
acc_dtype = get_computation_dtype(fake_input.dtype) # promote fp16 to fp32
default = ir.Reduction.default_accumulator(reduction_type, acc_dtype)
assert isinstance(default, (float, int, bool))
assert state.fx_node is not None
acc = self.fn.new_var(f"{state.fx_node.name}_acc", dce=True)
device_loop.outer_prefix.append(
statement_from_string(
f"{acc} = tl.full({shape}, {constant_repr(default)}, {triton_acc_type(acc_dtype)})"
)
)
result = self.fn.new_var(state.fx_node.name, dce=True)
if reduction_type not in {"argmin", "argmax"}:
combine_fn = get_reduction_combine_fn(reduction_type, fake_input.dtype)
combine_fn = get_reduction_combine_fn(reduction_type, acc_dtype)
state.add_statement(f"{acc} = {combine_fn(acc, input_name)}")
expr = self.call_reduction_function(
acc, reduction_type, dim, fake_input, fake_output
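The accumulator-dtype change above relies on get_computation_dtype's promotion rule; a quick sketch of the behavior being assumed:

import torch
from torch._prims_common import get_computation_dtype

print(get_computation_dtype(torch.float16))   # torch.float32 -- fp16 reductions accumulate in fp32
print(get_computation_dtype(torch.bfloat16))  # torch.float32
print(get_computation_dtype(torch.float32))   # torch.float32 -- already full precision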
122 changes: 122 additions & 0 deletions test/test_reductions.expected
@@ -62,6 +62,128 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
_launcher(_reduce_kernel_kernel, (n,), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _layer_norm_fwd_repro_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
x_part = tl.load(x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
v_0 = x_part.to(tl.float32)
var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
v_1 = 64
v_2 = var_mean_extra / v_1.to(tl.float32)
v_3 = x_part.to(tl.float32)
v_4 = v_3 - v_2
v_5 = v_4 * v_4
var_mean_extra_2 = tl.reshape(tl.sum(v_5, 1), [_BLOCK_SIZE_0, 1])
v_6 = 64
v_7 = var_mean_extra_2 / v_6.to(tl.float32)
v_8 = v_7.to(tl.float16)
v_9 = v_2.to(tl.float16)
v_10 = x_part - v_9
v_11 = v_8.to(tl.float32)
v_12 = v_11 + eps
v_13 = libdevice.rsqrt(v_12)
v_14 = v_10.to(tl.float32)
v_15 = v_14 * v_13
load_1 = tl.load(weight + indices_1 * 1, None)
v_16 = load_1.to(tl.float32)
v_17 = v_16[None, :]
v_18 = v_15 * v_17
load_2 = tl.load(bias + indices_1 * 1, None)
v_19 = load_2.to(tl.float32)
v_20 = v_19[None, :]
v_21 = v_18 + v_20
v_22 = v_21.to(tl.float16)
tl.store(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), v_22, None)

def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
m, n = x.size()
out = torch.empty([m, n], dtype=torch.float16, device=x.device)
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_1 = 64
_launcher(_layer_norm_fwd_repro_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _layer_norm_fwd_repro_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
var_mean_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
x_part = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
v_0 = x_part.to(tl.float32)
v_1 = var_mean_extra_acc + v_0
var_mean_extra_acc = v_1
var_mean_extra = tl.reshape(tl.sum(var_mean_extra_acc, 1), [_BLOCK_SIZE_0, 1])
v_2 = 64
v_3 = var_mean_extra / v_2.to(tl.float32)
var_mean_extra_2_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
v_3_copy = v_3
x_part_1 = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
v_4 = x_part_1.to(tl.float32)
v_5 = v_4 - v_3_copy
v_6 = v_5 * v_5
v_7 = var_mean_extra_2_acc + v_6
var_mean_extra_2_acc = v_7
var_mean_extra_2 = tl.reshape(tl.sum(var_mean_extra_2_acc, 1), [_BLOCK_SIZE_0, 1])
v_8 = 64
v_9 = var_mean_extra_2 / v_8.to(tl.float32)
v_10 = v_9.to(tl.float16)
v_11 = v_3.to(tl.float16)
v_12 = v_10.to(tl.float32)
v_13 = v_12 + eps
v_14 = libdevice.rsqrt(v_13)
for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
v_11_copy = v_11
v_14_copy = v_14
x_part_2 = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
v_15 = x_part_2 - v_11_copy
v_16 = v_15.to(tl.float32)
v_17 = v_16 * v_14_copy
load_1 = tl.load(weight + rindex_1 * 1, None)
v_18 = load_1.to(tl.float32)
v_19 = v_18[None, :]
v_20 = v_17 * v_19
load_2 = tl.load(bias + rindex_1 * 1, None)
v_21 = load_2.to(tl.float32)
v_22 = v_21[None, :]
v_23 = v_20 + v_22
v_24 = v_23.to(tl.float16)
tl.store(out + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), v_24, None)

def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
m, n = x.size()
out = torch.empty([m, n], dtype=torch.float16, device=x.device)
_BLOCK_SIZE_0 = 32
_REDUCTION_BLOCK_1 = 8
_launcher(_layer_norm_fwd_repro_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestReductions.test_mean)
def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32):
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
42 changes: 42 additions & 0 deletions test/test_reductions.py
@@ -188,6 +188,48 @@ def layer_norm_reduction(
)
self.assertExpectedJournal(code)

def test_fp16_var_mean(self):
@helion.kernel(static_shapes=True)
def layer_norm_fwd_repro(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-5,
) -> torch.Tensor:
m, n = x.size()
out = torch.empty([m, n], dtype=torch.float16, device=x.device)
for tile_m in hl.tile(m):
x_part = x[tile_m, :]
var, mean = torch.var_mean(x_part, dim=-1, keepdim=True, correction=0)
normalized = (x_part - mean) * torch.rsqrt(var.to(torch.float32) + eps)
out[tile_m, :] = normalized * (weight[:].to(torch.float32)) + (
bias[:].to(torch.float32)
)
return out

batch_size = 32
dim = 64
x = torch.randn([batch_size, dim], device=DEVICE, dtype=torch.float16)
weight = torch.randn([dim], device=DEVICE, dtype=torch.float16)
bias = torch.randn([dim], device=DEVICE, dtype=torch.float16)
eps = 1e-4
code1, result1 = code_and_output(
layer_norm_fwd_repro,
(x, weight, bias, eps),
block_sizes=[32],
reduction_loops=[None],
)
self.assertExpectedJournal(code1)

code2, result2 = code_and_output(
layer_norm_fwd_repro,
(x, weight, bias, eps),
block_sizes=[32],
reduction_loops=[8],
)
self.assertExpectedJournal(code2)
torch.testing.assert_close(result1, result2, rtol=1e-3, atol=1e-3)
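Not part of the test, but a hedged eager-mode reference mirroring the kernel body above (the function name is invented), useful for sanity-checking both journals locally:

def layer_norm_fwd_eager(x, weight, bias, eps=1e-5):
    var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
    normalized = (x - mean) * torch.rsqrt(var.to(torch.float32) + eps)
    return (normalized * weight.to(torch.float32) + bias.to(torch.float32)).to(x.dtype)

# expected = layer_norm_fwd_eager(x, weight, bias, eps)
# torch.testing.assert_close(result1, expected, rtol=1e-3, atol=1e-3)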


if __name__ == "__main__":
unittest.main()