Skip to content

Commit 0a94d35

Browse files
sabaumaSpenser Bauman
and
Spenser Bauman
authored
[mlir][tosa] Fix tosa-infer-shapes crash (#87234)
The tosa-infer-shapes pass inserts tensor.cast operations to mediate refined result types with consumers whose types cannot be refined. This process interferes with how types are refined in tosa.while_loop body regions, where types are propagated speculatively (to determine the types of the tosa.yield terminator) and then reverted. The new tosa.cast operations result in a crash due to not having types associated to them for the reversion process. This change modifies the shape propagation behavior so that the introduction to tensor.cast operations behaves better with this type reversion process. The new behavior is to only introduce tensor.cast operations once we wish to commit the newly computed types to the IR. This is an example causing the crash: ```mlir func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) { %0 = tosa.add %arg0, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<*xi32> %1 = tosa.while_loop (%arg1 = %0) : (tensor<*xi32>) -> tensor<*xi32> { %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32> %3 = tosa.greater_equal %2, %arg1 : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1> tosa.yield %3 : tensor<*xi1> } do { ^bb0(%arg1: tensor<*xi32>): // Inferrable operation whose type will refine to tensor<i32> %3 = tosa.add %arg1, %arg1 : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> // Non-inferrable use site, will require the cast: // tensor.cast %3 : tensor<i32> to tensor<*xi32> // // The new cast operation will result in accessing undefined memory through // originalTypeMap in the C++ code. "use"(%3) : (tensor<*xi32>) -> () tosa.yield %3 : tensor<*xi32> } return %1 : tensor<*xi32> } ``` The `tensor.cast` operation inserted in the loop body causes a failure in the code which resets the types after propagation through the loop body: ```c++ // The types inferred in the block assume the operand types specified for // this iteration. We need to restore the original types to ensure that // future iterations only use the already specified types, not possible // types from previous iterations. for (auto &block : bodyRegion) { for (auto arg : block.getArguments()) arg.setType(originalTypeMap[arg]); for (auto &op : block) for (auto result : op.getResults()) result.setType(originalTypeMap[result]); // problematic access } ``` --------- Co-authored-by: Spenser Bauman <sabauma@fastmail>
1 parent e61d6b7 commit 0a94d35

File tree

2 files changed

+195
-96
lines changed

2 files changed

+195
-96
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 103 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,9 @@
1818
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1919
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
2020
#include "mlir/IR/Builders.h"
21-
#include "mlir/IR/BuiltinOps.h"
22-
#include "mlir/IR/IRMapping.h"
23-
#include "mlir/IR/Matchers.h"
2421
#include "mlir/Interfaces/InferTypeOpInterface.h"
2522
#include "mlir/Pass/Pass.h"
2623
#include "mlir/Transforms/DialectConversion.h"
27-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28-
#include "llvm/Support/FormatVariadic.h"
2924

3025
namespace mlir {
3126
namespace tosa {
@@ -39,9 +34,87 @@ using namespace mlir::tosa;
3934

4035
namespace {
4136

42-
void propagateShapesInRegion(Region &region);
37+
// Check whether this use case is replaceable. We define an op as
38+
// being replaceable if it is used by a TosaOp, or an op with a
39+
// type-inference related interface.
40+
// When a non-replaceable use is encountered, the value is wrapped in a
41+
// cast back to the original type after inference.
42+
bool isReplaceableUser(Operation *user) {
43+
// Handle unregistered dialects.
44+
if (!user->getDialect())
45+
return false;
46+
47+
return user->getDialect()->getNamespace() ==
48+
TosaDialect::getDialectNamespace() ||
49+
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
50+
}
51+
52+
// During type propagation, the types of values in the operator graph are
53+
// updated. For the tosa.while_loop operation, types are speculatively updated
54+
// within the body region to determine the output type of the while_loop. This
55+
// process is performed until a fixed point is reached, then the types are
56+
// reverted.
57+
//
58+
// This class encapsulates the state information needed to perform the reversion
59+
// process or to commit to the final changes.
60+
class TypeModificationState {
61+
public:
62+
TypeModificationState() = default;
63+
64+
~TypeModificationState() {
65+
// Ensure the recorded modifications are either committed or reverted.
66+
assert(oldTypes.empty() && "unhandled type modifications");
67+
}
68+
69+
// Update the state of the value and record the old type.
70+
void setType(Value value, Type type) {
71+
if (value.getType() != type) {
72+
oldTypes.emplace_back(value, value.getType());
73+
value.setType(type);
74+
}
75+
}
4376

44-
void propagateShapesToTosaIf(Operation &op) {
77+
// Revert changes made to the types in the IR by setting all the affected
78+
// values to their old types.
79+
void revert() {
80+
// Otherwise revert the changes.
81+
for (auto [value, type] : oldTypes)
82+
value.setType(type);
83+
84+
oldTypes.clear();
85+
}
86+
87+
// Commit the changes to the types in the IR.
88+
// This requires inserting tensor.cast operations to mediate the newly
89+
// inferred result types with users that do not support type inference.
90+
void commit() {
91+
// For each use whose type changed, cast the value with the new type back to
92+
// the old type.
93+
for (auto [value, oldType] : oldTypes) {
94+
for (auto &use : value.getUses()) {
95+
if (isReplaceableUser(use.getOwner()))
96+
continue;
97+
98+
OpBuilder builder(value.getContext());
99+
builder.setInsertionPoint(use.getOwner());
100+
101+
Location loc = value.getLoc();
102+
use.set(builder.create<tensor::CastOp>(loc, oldType, value));
103+
}
104+
}
105+
106+
oldTypes.clear();
107+
}
108+
109+
private:
110+
// A record of each value whose type was updated along with that value's
111+
// previous type.
112+
llvm::SmallVector<std::pair<Value, Type>> oldTypes;
113+
};
114+
115+
void propagateShapesInRegion(Region &region, TypeModificationState &state);
116+
117+
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
45118
IfOp ifOp = dyn_cast<IfOp>(op);
46119
if (!ifOp)
47120
return;
@@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) {
58131

59132
if (inferredTy.hasRank()) {
60133
Type newType = oldType.clone(inferredTy.getShape());
61-
blockArg.setType(newType);
134+
state.setType(blockArg, newType);
62135
}
63136
}
64137

@@ -71,64 +144,44 @@ void propagateShapesToTosaIf(Operation &op) {
71144
ValueKnowledge::join(operandKnowledge, blockKnowledge);
72145
if (!joinedKnowledge)
73146
continue;
74-
frontBlock.getArgument(i).setType(joinedKnowledge.getType());
147+
state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
75148
}
76149

77-
propagateShapesInRegion(region);
150+
propagateShapesInRegion(region, state);
78151
}
79152
}
80153

81-
void propagateShapesToTosaWhile(Operation &op) {
154+
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
82155
WhileOp whileOp = dyn_cast<WhileOp>(op);
83156
if (!whileOp)
84157
return;
85158

86159
// Determine what the expected argument types are to the cond/body blocks.
87160
// The expected arguments should be compatible with ever iteration of the
88161
// loop body / condition for tosa.while.
89-
llvm::SmallVector<Type> argTypes;
90-
for (auto operand : op.getOperands()) {
91-
auto operandTy = cast<ShapedType>(operand.getType());
92-
if (operandTy.hasRank()) {
93-
auto newTy = operandTy.clone(operandTy.getShape());
94-
argTypes.push_back(newTy);
95-
} else {
96-
argTypes.push_back(operand.getType());
97-
}
98-
}
99-
100-
// Save out the type information so we can restore at the end.
101-
llvm::DenseMap<Value, Type> originalTypeMap;
102-
for (auto &block : op.getRegion(1)) {
103-
for (auto arg : block.getArguments())
104-
originalTypeMap[arg] = arg.getType();
105-
for (auto &op : block)
106-
for (auto result : op.getResults())
107-
originalTypeMap[result] = result.getType();
108-
}
162+
SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
109163

110164
bool hasNewTypes = true;
111165
while (hasNewTypes) {
166+
TypeModificationState localState;
112167

113168
// Set types on the block args.
114169
Region &bodyRegion = op.getRegion(1);
115170
Block &block = bodyRegion.front();
116171
for (int i = 0, s = argTypes.size(); i < s; i++) {
117-
block.getArgument(i).setType(argTypes[i]);
172+
localState.setType(block.getArgument(i), argTypes[i]);
118173
}
119174

120175
// Propagate to the end.
121-
propagateShapesInRegion(bodyRegion);
176+
propagateShapesInRegion(bodyRegion, localState);
122177

123-
// Find all the tosa yield types and verify there is atleast one.
178+
// Find all the tosa yield types and verify there is a single one.
124179
llvm::SmallVector<YieldOp> yieldOps;
125180
for (auto &block : bodyRegion)
126181
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
127182
yieldOps.push_back(yieldOp);
128183

129-
if (yieldOps.empty())
130-
return;
131-
184+
assert(yieldOps.size() == 1 && "missing or non-unique yield op");
132185
// Using the new tosa.yield operand types, infer the new subtypes.
133186
llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
134187
for (auto ty : argTypes) {
@@ -158,59 +211,31 @@ void propagateShapesToTosaWhile(Operation &op) {
158211
argTypes[i] = newType;
159212
}
160213

161-
// The types inferred in the block assume the operand types specified for
162-
// this iteration. We need to restore the original types to ensure that
163-
// future iterations only use the already specified types, not possible
164-
// types from previous iterations.
165-
for (auto &block : bodyRegion) {
166-
for (auto arg : block.getArguments())
167-
arg.setType(originalTypeMap[arg]);
168-
for (auto &op : block)
169-
for (auto result : op.getResults())
170-
result.setType(originalTypeMap[result]);
171-
}
214+
// Revert all changes made during the speculative part of the algorithm.
215+
localState.revert();
172216
}
173217

174218
// We now set the block arguments according to the most recent shape
175219
// inference results. This gives us the block arg types for the next
176220
// iteration.
177221
for (auto &region : op.getRegions()) {
178222
for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
179-
region.front().getArgument(i).setType(argTypes[i]);
223+
state.setType(region.front().getArgument(i), argTypes[i]);
180224
}
181225

182-
propagateShapesInRegion(region);
226+
propagateShapesInRegion(region, state);
183227
}
184228
}
185229

186-
// Track the old type for each operand whose type was updated
187-
// during inference. This information is used to introduce casts
188-
// back to the type expected by the operand after inference.
189-
struct TypeRewriteInfo {
190-
OpOperand *operand;
191-
Type oldType;
192-
};
193-
194-
void propagateShapesInRegion(Region &region) {
195-
// Check whether this use case is replaceable. We define an op as
196-
// being replaceable if it is used by a TosaOp, or an op with a
197-
// type-inference related interface.
198-
// When a non-replaceable use is encountered, the value is wrapped in a
199-
// cast back to the original type after inference.
200-
auto isReplaceableUser = [](Operation *user) -> bool {
201-
return user->getDialect()->getNamespace() ==
202-
TosaDialect::getDialectNamespace() ||
203-
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
204-
};
205-
206-
llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
230+
void propagateShapesInRegion(Region &region, TypeModificationState &state) {
207231
for (auto &block : region) {
208232
for (Operation &op : block) {
209-
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
233+
if (!op.getDialect() ||
234+
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
210235
continue;
211236

212-
propagateShapesToTosaIf(op);
213-
propagateShapesToTosaWhile(op);
237+
propagateShapesToTosaIf(op, state);
238+
propagateShapesToTosaWhile(op, state);
214239

215240
InferShapedTypeOpInterface shapeInterface =
216241
dyn_cast<InferShapedTypeOpInterface>(op);
@@ -252,30 +277,11 @@ void propagateShapesInRegion(Region &region) {
252277
continue;
253278

254279
// Set new type
255-
result.setType(newKnowledge.getType());
256-
257-
// Collect all uses of the operation which require update.
258-
for (auto &user : result.getUses()) {
259-
if (!isReplaceableUser(user.getOwner()))
260-
requiresUpdate.push_back({&user, resultTy});
261-
}
280+
state.setType(result, newKnowledge.getType());
262281
}
263282
}
264283
}
265284
}
266-
267-
// For each use whose type changed, cast the value with the new type back to
268-
// the old type.
269-
IRRewriter rewriter(region.getContext());
270-
for (auto [operand, oldType] : requiresUpdate) {
271-
rewriter.setInsertionPoint(operand->getOwner());
272-
273-
auto oldValue = operand->get();
274-
275-
auto loc = oldValue.getLoc();
276-
auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
277-
operand->set(castOp);
278-
}
279285
}
280286

281287
/// Pass that performs shape propagation across TOSA operations. This includes
@@ -285,7 +291,9 @@ struct TosaInferShapes
285291
public:
286292
void runOnOperation() override {
287293
func::FuncOp func = getOperation();
288-
propagateShapesInRegion(func.getBody());
294+
TypeModificationState state;
295+
propagateShapesInRegion(func.getBody(), state);
296+
state.commit();
289297
}
290298
};
291299
} // namespace

0 commit comments

Comments
 (0)