|
2 | 2 |
|
3 | 3 | import collections |
4 | 4 | from typing import TYPE_CHECKING |
| 5 | +from typing import cast |
5 | 6 |
|
6 | 7 | import torch |
7 | 8 |
|
|
15 | 16 |
|
16 | 17 | from .._compiler.inductor_lowering import CodegenState |
17 | 18 |
|
18 | | -__all__ = ["subscript"] |
| 19 | +__all__ = ["join", "split", "subscript"] |
19 | 20 |
|
20 | 21 |
|
21 | 22 | @_decorators.api(tiles_as_sizes=True) |
@@ -114,3 +115,93 @@ def _(node: torch.fx.Node) -> float | bool | None: |
114 | 115 | other = node.args[0] |
115 | 116 | assert isinstance(other, torch.fx.Node) |
116 | 117 | return cached_masked_value(other) |
| 118 | + |
| 119 | + |
@_decorators.api(is_device_only=True)
def split(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Separate a tensor whose final dimension has size two into a pair of tensors.

    Args:
        tensor: Input tensor; its last dimension must have length two.

    Returns:
        A pair ``(lo, hi)`` of tensors, each shaped like ``tensor`` with the
        final dimension removed.

    See Also:
        - :func:`~helion.language.join`
    """
    raise NotInsideKernel
| 136 | + |
| 137 | + |
@_decorators.register_fake(split)
def _(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Each output drops the trailing (size-two) dimension of the input.
    result_shape = tensor.shape[:-1]
    lo = tensor.new_empty(result_shape)
    hi = tensor.new_empty(result_shape)
    return lo, hi
| 145 | + |
| 146 | + |
@_decorators.codegen(split)
def _(state: CodegenState) -> list[ast.AST]:
    # Build the tl.split(...) call once, then subscript element 0 and 1 of
    # the resulting pair.
    pair = expr_from_string("tl.split({tensor})", tensor=state.ast_arg(0))
    return [
        expr_from_string(f"{{value}}[{index}]", value=pair) for index in (0, 1)
    ]
| 154 | + |
| 155 | + |
@_decorators.ref(split)
def _(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Reference semantics: unbind the trailing dimension into its two halves.
    halves = torch.unbind(tensor, dim=-1)
    return cast("tuple[torch.Tensor, torch.Tensor]", halves)
| 159 | + |
| 160 | + |
@_decorators.api(is_device_only=True)
def join(
    tensor0: torch.Tensor,
    tensor1: torch.Tensor,
) -> torch.Tensor:
    """
    Stack two tensors along a fresh trailing dimension of size two.

    Args:
        tensor0: First tensor to join.
        tensor1: Second tensor to join. Must be broadcast-compatible with
            ``tensor0``.

    Returns:
        torch.Tensor: A tensor of shape ``broadcast_shape + (2,)``, where
        ``broadcast_shape`` is the broadcast of the two input shapes.

    See Also:
        - :func:`~helion.language.split`
    """
    raise NotInsideKernel
| 182 | + |
| 183 | + |
@_decorators.register_fake(join)
def _(tensor0: torch.Tensor, tensor1: torch.Tensor) -> torch.Tensor:
    # Inputs must agree on dtype and device before we can join them.
    if tensor0.dtype != tensor1.dtype:
        raise TypeError("join() requires both tensors to have the same dtype")
    if tensor0.device != tensor1.device:
        raise ValueError("join() requires both tensors to be on the same device")
    # Output shape is the broadcast shape plus a trailing dimension of 2.
    out_shape = [*torch.broadcast_shapes(tensor0.shape, tensor1.shape), 2]
    return tensor0.new_empty(out_shape)
| 193 | + |
| 194 | + |
@_decorators.codegen(join)
def _(state: CodegenState) -> ast.AST:
    # Translate directly to Triton's tl.join on the two argument expressions.
    lhs = state.ast_arg(0)
    rhs = state.ast_arg(1)
    return expr_from_string("tl.join({tensor0}, {tensor1})", tensor0=lhs, tensor1=rhs)
| 202 | + |
| 203 | + |
@_decorators.ref(join)
def _(tensor0: torch.Tensor, tensor1: torch.Tensor) -> torch.Tensor:
    # Reference semantics: broadcast to a common shape, then stack the pair
    # along a new trailing axis.
    return torch.stack(torch.broadcast_tensors(tensor0, tensor1), dim=-1)
0 commit comments