Skip to content

Commit 8763343

Browse files
[mlir][bufferization] Update empty_tensor_elimination transform op (#68497)
The empty tensor elimination pass semantics have changed recently: when applied to a module, the One-Shot Module Analysis is run. Otherwise, the regular One-Shot Analysis is run. The latter one is slightly different because it ignores function boundaries and treats function block arguments as "read-only". This commit updates the transform dialect op to behave in the same way.
1 parent 32f7197 commit 8763343

File tree

3 files changed

+21
-21
lines changed

3 files changed

+21
-21
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ struct OneShotBufferizationOptions;
3232
/// In the above example, the subset op is "tensor.insert_slice". When tracing
3333
/// back the reverse use-def chain of a the source, we end up at a
3434
/// "tensor.empty" op.
35+
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);
36+
37+
/// Try to eliminate "tensor.empty" ops inside `op`.
38+
///
39+
/// This function overload accepts an existing `OneShotAnalysisState`, which
40+
/// contains in-place bufferization decisions. This overload is useful if an
41+
/// existing analysis should be reused for empty tensor elimination.
3542
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
3643
OneShotAnalysisState &state);
3744

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,10 @@ void transform::EliminateEmptyTensorsOp::getEffects(
117117
DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
118118
transform::TransformRewriter &rewriter, TransformResults &transformResults,
119119
TransformState &state) {
120-
OneShotBufferizationOptions options;
121-
options.allowReturnAllocsFromLoops = true;
122-
123120
for (Operation *target : state.getPayloadOps(getTarget())) {
124-
OneShotAnalysisState state(target, options);
125-
if (failed(analyzeOp(target, state)))
126-
return mlir::emitSilenceableFailure(target->getLoc())
127-
<< "failed to analyze op";
128-
if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state)))
121+
if (failed(bufferization::eliminateEmptyTensors(rewriter, target)))
129122
return mlir::emitSilenceableFailure(target->getLoc())
130-
<< "failed to eliminate insert_slice anchored tensor.empty ops";
123+
<< "empty tensor elimination failed";
131124
}
132125
return DiagnosedSilenceableFailure::success();
133126
}

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ struct EmptyTensorElimination
183183
};
184184
} // namespace
185185

186-
void EmptyTensorElimination::runOnOperation() {
187-
Operation *op = getOperation();
186+
LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
187+
Operation *op) {
188188
auto moduleOp = dyn_cast<ModuleOp>(op);
189189
OneShotBufferizationOptions options;
190190
options.allowReturnAllocsFromLoops = true;
@@ -193,21 +193,21 @@ void EmptyTensorElimination::runOnOperation() {
193193
OneShotAnalysisState state(op, options);
194194
if (moduleOp) {
195195
// Module analysis takes into account function boundaries.
196-
if (failed(analyzeModuleOp(moduleOp, state))) {
197-
signalPassFailure();
198-
return;
199-
}
196+
if (failed(analyzeModuleOp(moduleOp, state)))
197+
return failure();
200198
} else {
201199
// Regular One-Shot Bufferize ignores func.func block arguments, func.call,
202200
// func.return.
203-
if (failed(analyzeOp(op, state))) {
204-
signalPassFailure();
205-
return;
206-
}
201+
if (failed(analyzeOp(op, state)))
202+
return failure();
207203
}
208204

209-
IRRewriter rewriter(op->getContext());
210-
if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state)))
205+
return bufferization::eliminateEmptyTensors(rewriter, op, state);
206+
}
207+
208+
void EmptyTensorElimination::runOnOperation() {
209+
IRRewriter rewriter(getOperation()->getContext());
210+
if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
211211
signalPassFailure();
212212
}
213213

0 commit comments

Comments
 (0)