Commit d21a7e7

wconstab authored and pytorchmergebot committed
Assert TensorBox produced by lowering and add [Note: Inductor IR] (pytorch#94361)
Pull Request resolved: pytorch#94361
Approved by: https://github.com/jansel
1 parent 01de5dd commit d21a7e7

2 files changed: +62 -1 lines changed


torch/_inductor/ir.py

Lines changed: 57 additions & 0 deletions
@@ -46,6 +46,63 @@
 indent = functools.partial(textwrap.indent, prefix=" ")
 aten = torch.ops.aten
 
+""" [Note: Inductor IR]
+
+Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
+lowering is registered to a particular aten operator, and expects inputs that
+correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
+expect Inductor TensorBox inputs.
+
+TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
+storage, and sometimes views of another Tensor's storage. Mutating tensor operations
+(such as add_()) affect the underlying storage and any associated views. Other operations
+(such as .t_()) update metadata about the current view but don't modify the underlying storage.
+
+To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
+
+TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
+output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
+reference View IR or directly reference StorageBox IRs.
+
+Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
+may take an existing TensorBox and point it to a new underlying View IR.
+
+Tensors that directly own storage are represented as a chain of:
+TensorBox -> StorageBox -> Buffer
+where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
+
+If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
+(leaving the old buffer unmodified and functionalizing the operation).
+
+Tensors backed by views add one more indirection to the IR.
+TensorBox -> View -> StorageBox -> Buffer
+In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
+
+For metadata mutation (e.g. as_strided_) we swing the TensorBox pointer.
+"""
+
+
+def validate_ir(node_or_nodes):
+    def _check_tensorbox(node):
+        # Could expand this to check deeper properties
+        # (e.g. TensorBox points to View or StorageBox)
+        assert isinstance(
+            node,
+            (
+                TensorBox,
+                RandSeedBuffer,
+                torch.fx.experimental.symbolic_shapes.Symbol,
+                sympy.core.numbers.Expr,
+            ),
+        ), f"Found {type(node)}, which is not a supported top level IR node. See [Note: Inductor IR]"
+
+    # Be picky about the accepted data structure (don't use pytree here)
+    if isinstance(node_or_nodes, (List, Tuple)):
+        for node in node_or_nodes:
+            _check_tensorbox(node)
+    else:
+        _check_tensorbox(node_or_nodes)
+
 
 def inverse_reorder(order):
     inv_order = dict(zip(order, range(len(order))))
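
The pointer-swinging rules in [Note: Inductor IR] may be easier to follow with a small model. The sketch below is not Inductor code: the `Toy*` classes and the three helper functions are hypothetical stand-ins that only mimic the TensorBox -> View -> StorageBox -> Buffer chain described in the note, with a plain size tuple assumed in place of a real Layout.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class ToyBuffer:
    """Hypothetical flat (1D) allocation; never modified after creation."""
    name: str
    numel: int


@dataclass
class ToyStorageBox:
    """Points at the current buffer and carries layout (here: only a size)."""
    data: ToyBuffer
    size: Tuple[int, ...]


@dataclass
class ToyView:
    """Reinterprets an existing storage box with different metadata."""
    data: ToyStorageBox
    size: Tuple[int, ...]


@dataclass
class ToyTensorBox:
    """Top-level node; points at either a view or a storage box."""
    data: Union[ToyView, ToyStorageBox]


def storage_of(box: ToyTensorBox) -> ToyStorageBox:
    # Skip the optional View indirection to reach the shared storage box.
    node = box.data
    return node.data if isinstance(node, ToyView) else node


def transpose_view(box: ToyTensorBox) -> ToyTensorBox:
    # View op (.t()-style): new TensorBox and View, same StorageBox.
    return ToyTensorBox(ToyView(storage_of(box), tuple(reversed(box.data.size))))


def mutate_data(box: ToyTensorBox, new_name: str) -> None:
    # Data mutation (add_()-style): swing the StorageBox pointer to a new
    # buffer and leave the old buffer object untouched.
    storage = storage_of(box)
    storage.data = ToyBuffer(new_name, storage.data.numel)


def mutate_metadata(box: ToyTensorBox, size: Tuple[int, ...]) -> None:
    # Metadata mutation (as_strided_()-style): swing the TensorBox pointer
    # to a new View over the same storage.
    box.data = ToyView(storage_of(box), size)


buf0 = ToyBuffer("buf0", 6)
base = ToyTensorBox(ToyStorageBox(buf0, (2, 3)))
viewed = transpose_view(base)   # shares base's storage box
mutate_data(base, "buf1")       # visible through both base and viewed
mutate_metadata(base, (6,))     # only base's own metadata changes
```

Because `viewed` and `base` share one storage box, swinging that box's buffer pointer updates what both see, while the original `buf0` object is left as it was.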

torch/_inductor/lowering.py

Lines changed: 5 additions & 1 deletion
@@ -33,6 +33,7 @@
     Reduction,
     SqueezeView,
     TensorBox,
+    validate_ir,
     View,
 )
 from .utils import ceildiv, sympy_product
@@ -221,7 +222,10 @@ def wrapped(*args, **kwargs):
                         args[i], list(args[indices[0]].get_size())
                     )
 
-        return decomp_fn(*args, **kwargs)
+        out = decomp_fn(*args, **kwargs)
+        validate_ir(out)
+
+        return out
 
     if not isinstance(aten_fn, (list, tuple)):
         aten_fn = [aten_fn]
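
The change above wraps each lowering so that its result is checked before being returned. A minimal standalone sketch of that pattern follows. It assumes a hypothetical `StubTensorBox` in place of the real IR classes, and `stub_validate_ir`, `validated`, and the two example lowerings are illustrative names, not part of Inductor.

```python
import functools


class StubTensorBox:
    """Hypothetical stand-in for a top-level IR node."""


def stub_validate_ir(node_or_nodes):
    def _check(node):
        assert isinstance(
            node, StubTensorBox
        ), f"Found {type(node)}, which is not a supported top level IR node."

    # Only a flat list or tuple is unpacked; dicts and nested containers
    # are checked as single nodes and therefore rejected.
    if isinstance(node_or_nodes, (list, tuple)):
        for node in node_or_nodes:
            _check(node)
    else:
        _check(node_or_nodes)


def validated(lowering_fn):
    """Wrap a lowering so its output is validated before it is returned."""

    @functools.wraps(lowering_fn)
    def wrapped(*args, **kwargs):
        out = lowering_fn(*args, **kwargs)
        stub_validate_ir(out)
        return out

    return wrapped


@validated
def good_lowering():
    return StubTensorBox(), StubTensorBox()


@validated
def bad_lowering():
    # A dict is not unpacked, so validation fails on the dict itself.
    return {"out": StubTensorBox()}


def fails(fn):
    """Report whether calling fn trips the validation assertion."""
    try:
        fn()
    except AssertionError:
        return True
    return False
```

Since the check is a bare `assert`, it disappears when Python runs with `-O`; in this sketch that makes it a debugging aid rather than a hard guarantee.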
