Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 78 additions & 54 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2257,37 +2257,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
return success();
}

LogicalResult ConversionPatternRewriter::legalize(Region *r) {
// Fast path: If the region is empty, there is nothing to legalize.
if (r->empty())
return success();

// Gather a list of all operations to legalize. This is done before
// converting the entry block signature because unrealized_conversion_cast
// ops should not be included.
SmallVector<Operation *> ops;
for (Block &b : *r)
for (Operation &op : b)
ops.push_back(&op);

// If the current pattern runs with a type converter, convert the entry block
// signature.
if (const TypeConverter *converter = impl->currentTypeConverter) {
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(&r->front());
if (!conversion)
return failure();
applySignatureConversion(&r->front(), *conversion, converter);
}

// Legalize all operations in the region.
for (Operation *op : ops)
if (failed(legalize(op)))
return failure();

return success();
}

void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
Expand Down Expand Up @@ -3287,8 +3256,20 @@ struct OperationConverter {
: rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
mode(mode) {}

/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
/// Applies the conversion to the given operations (and their nested
/// operations).
LogicalResult applyConversion(ArrayRef<Operation *> ops);

/// Legalizes the given operations (and their nested operations) to the
/// conversion target.
template <typename Fn>
LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
bool isRecursiveLegalization = false);
LogicalResult legalizeOperations(ArrayRef<Operation *> ops,
bool isRecursiveLegalization = false) {
return legalizeOperations(
ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
}

/// Converts a single operation. If `isRecursiveLegalization` is "true", the
/// conversion is a recursive legalization request, triggered from within a
Expand All @@ -3297,6 +3278,8 @@ struct OperationConverter {
/// legalization mechanism).
LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);

const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }

private:
/// The rewriter to use when converting operations.
ConversionPatternRewriter rewriter;
Expand All @@ -3309,10 +3292,6 @@ struct OperationConverter {
};
} // namespace mlir

LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
}

LogicalResult OperationConverter::convert(Operation *op,
bool isRecursiveLegalization) {
const ConversionConfig &config = rewriter.getConfig();
Expand Down Expand Up @@ -3398,12 +3377,15 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
return failure();
}

LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
template <typename Fn>
LogicalResult
OperationConverter::legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
bool isRecursiveLegalization) {
const ConversionTarget &target = opLegalizer.getTarget();

// Compute the set of operations and blocks to convert.
SmallVector<Operation *> toConvert;
for (auto *op : ops) {
for (Operation *op : ops) {
op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[&](Operation *op) {
toConvert.push_back(op);
Expand All @@ -3415,25 +3397,67 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return WalkResult::advance();
});
}
for (Operation *op : toConvert) {
if (failed(convert(op, isRecursiveLegalization))) {
// Failed to convert an operation.
onFailure();
return failure();
}
}
return success();
}

// Convert each operation and discard rewrites on failure.
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
return impl->opConverter.legalizeOperations(op,
/*isRecursiveLegalization=*/true);
}

for (auto *op : toConvert) {
if (failed(convert(op))) {
// Dialect conversion failed.
if (rewriterImpl.config.allowPatternRollback) {
// Rollback is allowed: restore the original IR.
rewriterImpl.undoRewrites();
} else {
// Rollback is not allowed: apply all modifications that have been
// performed so far.
rewriterImpl.applyRewrites();
}
LogicalResult ConversionPatternRewriter::legalize(Region *r) {
// Fast path: If the region is empty, there is nothing to legalize.
if (r->empty())
return success();

// Gather a list of all operations to legalize. This is done before
// converting the entry block signature because unrealized_conversion_cast
// ops should not be included.
SmallVector<Operation *> ops;
for (Block &b : *r)
for (Operation &op : b)
ops.push_back(&op);

// If the current pattern runs with a type converter, convert the entry block
// signature.
if (const TypeConverter *converter = impl->currentTypeConverter) {
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(&r->front());
if (!conversion)
return failure();
}
applySignatureConversion(&r->front(), *conversion, converter);
}

// Legalize all operations in the region. This includes all nested
// operations.
return impl->opConverter.legalizeOperations(ops,
/*isRecursiveLegalization=*/true);
}

LogicalResult OperationConverter::applyConversion(ArrayRef<Operation *> ops) {
// Convert each operation and discard rewrites on failure.
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
// Dialect conversion failed.
if (rewriterImpl.config.allowPatternRollback) {
// Rollback is allowed: restore the original IR.
rewriterImpl.undoRewrites();
} else {
// Rollback is not allowed: apply all modifications that have been
// performed so far.
rewriterImpl.applyRewrites();
}
});
if (failed(status))
return failure();

// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();

Expand Down Expand Up @@ -4143,7 +4167,7 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
[&] {
OperationConverter opConverter(ops.front()->getContext(), target,
patterns, config, mode);
status = opConverter.convertOperations(ops);
status = opConverter.applyConversion(ops);
},
irUnits);
return status;
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,16 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
// The region of "test.post_order_legalization" is converted before the op.

// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
// CHECK: notifyOperationInserted: test.remaining_consumer
// CHECK: notifyOperationInserted: test.legal_op
// CHECK: notifyOperationInserted: test.invalid
// CHECK: notifyBlockErased
// CHECK: notifyOperationInserted: test.valid, was unlinked
// CHECK: notifyOperationReplaced: test.invalid
// CHECK: notifyOperationErased: test.invalid
// CHECK: notifyOperationInserted: test.valid, was unlinked
// CHECK: notifyOperationReplaced: test.invalid
// CHECK: notifyOperationErased: test.invalid
// CHECK: notifyOperationModified: test.post_order_legalization

// CHECK-LABEL: func @test_preorder_legalization
Expand All @@ -475,6 +480,9 @@ func.func @test_preorder_legalization() {
^bb0(%arg0: i64):
// expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
"test.remaining_consumer"(%arg0) : (i64) -> ()
"test.legal_op"() ({
"test.invalid"(%arg0) : (i64) -> ()
}) : () -> ()
"test.invalid"(%arg0) : (i64) -> ()
}) : () -> ()
// expected-remark @+1 {{'func.return' is not legalizable}}
Expand Down