-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[mlir][Transforms] Legalize nested operations #172158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][Transforms] Legalize nested operations #172158
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit align the implementation of This function now legalizes the entire region, including nested ops. The implementation follows the same logic as the "main" traversal: pre-order, forward-dominance. Full diff: https://github.com/llvm/llvm-project/pull/172158.diff 2 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 09ad42364baaf..d4b1c8c7f0a74 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2257,37 +2257,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
return success();
}
-LogicalResult ConversionPatternRewriter::legalize(Region *r) {
- // Fast path: If the region is empty, there is nothing to legalize.
- if (r->empty())
- return success();
-
- // Gather a list of all operations to legalize. This is done before
- // converting the entry block signature because unrealized_conversion_cast
- // ops should not be included.
- SmallVector<Operation *> ops;
- for (Block &b : *r)
- for (Operation &op : b)
- ops.push_back(&op);
-
- // If the current pattern runs with a type converter, convert the entry block
- // signature.
- if (const TypeConverter *converter = impl->currentTypeConverter) {
- std::optional<TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(&r->front());
- if (!conversion)
- return failure();
- applySignatureConversion(&r->front(), *conversion, converter);
- }
-
- // Legalize all operations in the region.
- for (Operation *op : ops)
- if (failed(legalize(op)))
- return failure();
-
- return success();
-}
-
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
@@ -3287,8 +3256,20 @@ struct OperationConverter {
: rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
mode(mode) {}
- /// Converts the given operations to the conversion target.
- LogicalResult convertOperations(ArrayRef<Operation *> ops);
+ /// Applies the conversion to the given operations (and their nested
+ /// operations).
+ LogicalResult applyConversion(ArrayRef<Operation *> ops);
+
+ /// Legalizes the given operations (and their nested operations) to the
+ /// conversion target.
+ template <typename Fn>
+ LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
+ bool isRecursiveLegalization = false);
+ LogicalResult legalizeOperations(ArrayRef<Operation *> ops,
+ bool isRecursiveLegalization = false) {
+ return legalizeOperations(
+ ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
+ }
/// Converts a single operation. If `isRecursiveLegalization` is "true", the
/// conversion is a recursive legalization request, triggered from within a
@@ -3297,6 +3278,8 @@ struct OperationConverter {
/// legalization mechanism).
LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
+ const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
+
private:
/// The rewriter to use when converting operations.
ConversionPatternRewriter rewriter;
@@ -3309,10 +3292,6 @@ struct OperationConverter {
};
} // namespace mlir
-LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
- return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
-}
-
LogicalResult OperationConverter::convert(Operation *op,
bool isRecursiveLegalization) {
const ConversionConfig &config = rewriter.getConfig();
@@ -3398,12 +3377,15 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
return failure();
}
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+template <typename Fn>
+LogicalResult
+OperationConverter::legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
+ bool isRecursiveLegalization) {
const ConversionTarget &target = opLegalizer.getTarget();
// Compute the set of operations and blocks to convert.
SmallVector<Operation *> toConvert;
- for (auto *op : ops) {
+ for (Operation *op : ops) {
op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[&](Operation *op) {
toConvert.push_back(op);
@@ -3415,25 +3397,67 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return WalkResult::advance();
});
}
+ for (Operation *op : toConvert) {
+ if (failed(convert(op, isRecursiveLegalization))) {
+ // Failed to convert an operation.
+ onFailure();
+ return failure();
+ }
+ }
+ return success();
+}
- // Convert each operation and discard rewrites on failure.
- ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
+ return impl->opConverter.legalizeOperations(op,
+ /*isRecursiveLegalization=*/true);
+}
- for (auto *op : toConvert) {
- if (failed(convert(op))) {
- // Dialect conversion failed.
- if (rewriterImpl.config.allowPatternRollback) {
- // Rollback is allowed: restore the original IR.
- rewriterImpl.undoRewrites();
- } else {
- // Rollback is not allowed: apply all modifications that have been
- // performed so far.
- rewriterImpl.applyRewrites();
- }
+LogicalResult ConversionPatternRewriter::legalize(Region *r) {
+ // Fast path: If the region is empty, there is nothing to legalize.
+ if (r->empty())
+ return success();
+
+ // Gather a list of all operations to legalize. This is done before
+ // converting the entry block signature because unrealized_conversion_cast
+ // ops should not be included.
+ SmallVector<Operation *> ops;
+ for (Block &b : *r)
+ for (Operation &op : b)
+ ops.push_back(&op);
+
+ // If the current pattern runs with a type converter, convert the entry block
+ // signature.
+ if (const TypeConverter *converter = impl->currentTypeConverter) {
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(&r->front());
+ if (!conversion)
return failure();
- }
+ applySignatureConversion(&r->front(), *conversion, converter);
}
+ // Legalize all operations in the region. This includes all nested
+ // operations.
+ return impl->opConverter.legalizeOperations(ops,
+ /*isRecursiveLegalization=*/true);
+}
+
+LogicalResult OperationConverter::applyConversion(ArrayRef<Operation *> ops) {
+ // Convert each operation and discard rewrites on failure.
+ ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+ LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
+ // Dialect conversion failed.
+ if (rewriterImpl.config.allowPatternRollback) {
+ // Rollback is allowed: restore the original IR.
+ rewriterImpl.undoRewrites();
+ } else {
+ // Rollback is not allowed: apply all modifications that have been
+ // performed so far.
+ rewriterImpl.applyRewrites();
+ }
+ });
+ if (failed(status))
+ return failure();
+
// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();
@@ -4143,7 +4167,7 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
[&] {
OperationConverter opConverter(ops.front()->getContext(), target,
patterns, config, mode);
- status = opConverter.convertOperations(ops);
+ status = opConverter.applyConversion(ops);
},
irUnits);
return status;
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 88a71cc26ab0c..8d854aff1992f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -454,11 +454,16 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
// The region of "test.post_order_legalization" is converted before the op.
// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
+// CHECK: notifyOperationInserted: test.remaining_consumer
+// CHECK: notifyOperationInserted: test.legal_op
// CHECK: notifyOperationInserted: test.invalid
// CHECK: notifyBlockErased
// CHECK: notifyOperationInserted: test.valid, was unlinked
// CHECK: notifyOperationReplaced: test.invalid
// CHECK: notifyOperationErased: test.invalid
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.invalid
+// CHECK: notifyOperationErased: test.invalid
// CHECK: notifyOperationModified: test.post_order_legalization
// CHECK-LABEL: func @test_preorder_legalization
@@ -475,6 +480,9 @@ func.func @test_preorder_legalization() {
^bb0(%arg0: i64):
// expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
"test.remaining_consumer"(%arg0) : (i64) -> ()
+ "test.legal_op"() ({
+ "test.invalid"(%arg0) : (i64) -> ()
+ }) : () -> ()
"test.invalid"(%arg0) : (i64) -> ()
}) : () -> ()
// expected-remark @+1 {{'func.return' is not legalizable}}
|
This commit align the implementation of
ConversionPatternRewriter::legalizewith its documentation:This function now legalizes the entire region, including nested ops. The implementation follows the same logic as the "main" traversal: pre-order, forward-dominance.