Skip to content

Commit ca832fa

Browse files
authored
Add hl.split and hl.join (#791)
1 parent e5320e8 commit ca832fa

File tree

6 files changed

+173
-2
lines changed

6 files changed

+173
-2
lines changed

docs/api/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ runtime
8888
full
8989
arange
9090
subscript
91+
split
92+
join
9193
reduce
9294
associative_scan
9395
cumsum

docs/api/language.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,18 @@ The `Tile` class represents a portion of an iteration space with the following k
194194
.. autofunction:: subscript
195195
```
196196

197+
### split()
198+
199+
```{eval-rst}
200+
.. autofunction:: split
201+
```
202+
203+
### join()
204+
205+
```{eval-rst}
206+
.. autofunction:: join
207+
```
208+
197209
## StackTensor
198210
### StackTensor class
199211
```{eval-rst}

helion/_compiler/roll_reduction.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from ..language.matmul_ops import dot as hl_dot
2323
from ..language.memory_ops import store
2424
from ..language.reduce_ops import _reduce
25+
from ..language.view_ops import join as hl_join
26+
from ..language.view_ops import split as hl_split
2527
from .compile_environment import CompileEnvironment
2628
from .inductor_lowering import APIFuncLowering
2729
from .inductor_lowering import ReductionLowering
@@ -119,6 +121,28 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
119121
return self.should_go_in_inner_graph(arg)
120122
return False
121123

124+
if node.target is hl_split:
125+
base = node.args[0]
126+
if isinstance(base, torch.fx.Node):
127+
return self.should_go_in_inner_graph(base)
128+
return False
129+
130+
if node.target is operator.getitem:
131+
base = node.args[0]
132+
if isinstance(base, torch.fx.Node) and base.target is hl_split:
133+
return self.should_go_in_inner_graph(base)
134+
135+
if node.target is hl_join:
136+
left = node.args[0]
137+
right = node.args[1]
138+
left_inner = isinstance(
139+
left, torch.fx.Node
140+
) and self.should_go_in_inner_graph(left)
141+
right_inner = isinstance(
142+
right, torch.fx.Node
143+
) and self.should_go_in_inner_graph(right)
144+
return left_inner or right_inner
145+
122146
if self.is_reduction(node):
123147
return True
124148

@@ -178,8 +202,13 @@ def start_new_graph(self) -> None:
178202

179203
inner_nodes: dict[torch.fx.Node, torch.fx.Node] = self.inner_nodes
180204
outputs = {}
205+
inner_node_set = set(inner_nodes)
181206
for orig_node, inner_node in inner_nodes.items():
182-
if self.is_reduction(orig_node) and orig_node not in self.outer_nodes:
207+
needs_output = orig_node not in self.outer_nodes and (
208+
self.is_reduction(orig_node)
209+
or any(user not in inner_node_set for user in orig_node.users)
210+
)
211+
if needs_output:
183212
outputs[orig_node] = inner_node
184213
self.available.add(orig_node)
185214
graph = self.inner_graph

helion/language/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from .tunable_ops import register_block_size as register_block_size
3939
from .tunable_ops import register_reduction_dim as register_reduction_dim
4040
from .tunable_ops import register_tunable as register_tunable
41+
from .view_ops import join as join
42+
from .view_ops import split as split
4143
from .view_ops import subscript as subscript
4244

4345
_MEMORY_OPS = (

helion/language/view_ops.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections
44
from typing import TYPE_CHECKING
5+
from typing import cast
56

67
import torch
78

@@ -15,7 +16,7 @@
1516

1617
from .._compiler.inductor_lowering import CodegenState
1718

18-
__all__ = ["subscript"]
19+
__all__ = ["join", "split", "subscript"]
1920

2021

2122
@_decorators.api(tiles_as_sizes=True)
@@ -114,3 +115,93 @@ def _(node: torch.fx.Node) -> float | bool | None:
114115
other = node.args[0]
115116
assert isinstance(other, torch.fx.Node)
116117
return cached_masked_value(other)
118+
119+
120+
@_decorators.api(is_device_only=True)
121+
def split(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
122+
"""
123+
Split the last dimension of a tensor with size two into two separate tensors.
124+
125+
Args:
126+
tensor: The input tensor whose last dimension has length two.
127+
128+
Returns:
129+
A tuple ``(lo, hi)`` where each tensor has the same shape as ``tensor``
130+
without its last dimension.
131+
132+
See Also:
133+
- :func:`~helion.language.join`
134+
"""
135+
raise NotInsideKernel
136+
137+
138+
@_decorators.register_fake(split)
139+
def _(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
140+
out_shape = tensor.shape[:-1]
141+
return (
142+
tensor.new_empty(out_shape),
143+
tensor.new_empty(out_shape),
144+
)
145+
146+
147+
@_decorators.codegen(split)
148+
def _(state: CodegenState) -> list[ast.AST]:
149+
split_call = expr_from_string("tl.split({tensor})", tensor=state.ast_arg(0))
150+
return [
151+
expr_from_string("{value}[0]", value=split_call),
152+
expr_from_string("{value}[1]", value=split_call),
153+
]
154+
155+
156+
@_decorators.ref(split)
157+
def _(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
158+
return cast("tuple[torch.Tensor, torch.Tensor]", torch.unbind(tensor, dim=-1))
159+
160+
161+
@_decorators.api(is_device_only=True)
162+
def join(
163+
tensor0: torch.Tensor,
164+
tensor1: torch.Tensor,
165+
) -> torch.Tensor:
166+
"""
167+
Join two tensors along a new minor dimension.
168+
169+
Args:
170+
tensor0: First tensor to join.
171+
tensor1: Second tensor to join. Must be broadcast-compatible with
172+
``tensor0``.
173+
174+
Returns:
175+
torch.Tensor: A tensor with shape ``broadcast_shape + (2,)`` where
176+
``broadcast_shape`` is the broadcast of the input shapes.
177+
178+
See Also:
179+
- :func:`~helion.language.split`
180+
"""
181+
raise NotInsideKernel
182+
183+
184+
@_decorators.register_fake(join)
185+
def _(tensor0: torch.Tensor, tensor1: torch.Tensor) -> torch.Tensor:
186+
if tensor0.dtype != tensor1.dtype:
187+
raise TypeError("join() requires both tensors to have the same dtype")
188+
if tensor0.device != tensor1.device:
189+
raise ValueError("join() requires both tensors to be on the same device")
190+
191+
broadcast_shape = torch.broadcast_shapes(tensor0.shape, tensor1.shape)
192+
return tensor0.new_empty([*broadcast_shape, 2])
193+
194+
195+
@_decorators.codegen(join)
196+
def _(state: CodegenState) -> ast.AST:
197+
return expr_from_string(
198+
"tl.join({tensor0}, {tensor1})",
199+
tensor0=state.ast_arg(0),
200+
tensor1=state.ast_arg(1),
201+
)
202+
203+
204+
@_decorators.ref(join)
205+
def _(tensor0: torch.Tensor, tensor1: torch.Tensor) -> torch.Tensor:
206+
left, right = torch.broadcast_tensors(tensor0, tensor1)
207+
return torch.stack((left, right), dim=-1)

test/test_views.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,41 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
156156
_code, result = code_and_output(fn, args)
157157
torch.testing.assert_close(result, args[0] + args[1])
158158

159+
def test_split_join_roundtrip(self):
160+
@helion.kernel(config={"block_size": 64})
161+
def fn(x: torch.Tensor) -> torch.Tensor:
162+
n = x.size(0)
163+
out = torch.empty_like(x)
164+
for tile in hl.tile(n):
165+
lo, hi = hl.split(x[tile, :])
166+
out[tile, :] = hl.join(hi, lo)
167+
return out
168+
169+
x = torch.randn([256, 2], device=DEVICE)
170+
code, result = code_and_output(fn, (x,))
171+
expected = torch.stack((x[:, 1], x[:, 0]), dim=-1)
172+
torch.testing.assert_close(result, expected)
173+
self.assertIn("tl.split", code)
174+
self.assertIn("tl.join", code)
175+
176+
def test_join_broadcast_scalar(self):
177+
@helion.kernel(config={"block_size": 64})
178+
def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
179+
n = x.size(0)
180+
out = torch.empty([n, 2], dtype=x.dtype, device=x.device)
181+
for tile in hl.tile(n):
182+
scalar = hl.load(y, [0])
183+
out[tile, :] = hl.join(x[tile], scalar)
184+
return out
185+
186+
x = torch.randn([128], device=DEVICE)
187+
y = torch.randn([1], device=DEVICE)
188+
code, result = code_and_output(fn, (x, y))
189+
broadcast_y = torch.broadcast_to(y, x.shape)
190+
expected = torch.stack((x, broadcast_y), dim=-1)
191+
torch.testing.assert_close(result, expected)
192+
self.assertIn("tl.join", code)
193+
159194
def test_reshape_input_types(self):
160195
@helion.kernel(static_shapes=True)
161196
def reshape_reduction_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)