Skip to content

Commit a9f6224

Browse files
[mlir][Transforms][NFC] Move ReconcileUnrealizedCasts implementation (#104671)
Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as `ReconcileUnrealizedCasts` will perform these kind of foldings on fully materialized IR. --------- Co-authored-by: Markus Böck <[email protected]>
1 parent 5a25854 commit a9f6224

File tree

3 files changed

+101
-56
lines changed

3 files changed

+101
-56
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,29 @@ struct ConversionConfig {
11261126
RewriterBase::Listener *listener = nullptr;
11271127
};
11281128

1129+
//===----------------------------------------------------------------------===//
1130+
// Reconcile Unrealized Casts
1131+
//===----------------------------------------------------------------------===//
1132+
1133+
/// Try to reconcile all given UnrealizedConversionCastOps and store the
1134+
/// left-over ops in `remainingCastOps` (if provided).
1135+
///
1136+
/// This function processes cast ops in a worklist-driven fashion. For each
1137+
/// cast op, if the chain of input casts eventually reaches a cast op where the
1138+
/// input types match the output types of the matched op, replace the matched
1139+
/// op with the inputs.
1140+
///
1141+
/// Example:
1142+
/// %1 = unrealized_conversion_cast %0 : !A to !B
1143+
/// %2 = unrealized_conversion_cast %1 : !B to !C
1144+
/// %3 = unrealized_conversion_cast %2 : !C to !A
1145+
///
1146+
/// In the above example, %0 can be used instead of %3 and all cast ops are
1147+
/// folded away.
1148+
void reconcileUnrealizedCasts(
1149+
ArrayRef<UnrealizedConversionCastOp> castOps,
1150+
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1151+
11291152
//===----------------------------------------------------------------------===//
11301153
// Op Conversion Entry Points
11311154
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp

Lines changed: 4 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/IR/BuiltinOps.h"
1212
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Transforms/DialectConversion.h"
1314

1415
namespace mlir {
1516
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts
3940
ReconcileUnrealizedCasts() = default;
4041

4142
void runOnOperation() override {
42-
// Gather all unrealized_conversion_cast ops.
43-
SetVector<UnrealizedConversionCastOp> worklist;
43+
SmallVector<UnrealizedConversionCastOp> ops;
4444
getOperation()->walk(
45-
[&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
46-
47-
// Helper function that adds all operands to the worklist that are an
48-
// unrealized_conversion_cast op result.
49-
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
50-
for (Value v : castOp.getInputs())
51-
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
52-
worklist.insert(inputCastOp);
53-
};
54-
55-
// Helper function that return the unrealized_conversion_cast op that
56-
// defines all inputs of the given op (in the same order). Return "nullptr"
57-
// if there is no such op.
58-
auto getInputCast =
59-
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
60-
if (castOp.getInputs().empty())
61-
return {};
62-
auto inputCastOp = castOp.getInputs()
63-
.front()
64-
.getDefiningOp<UnrealizedConversionCastOp>();
65-
if (!inputCastOp)
66-
return {};
67-
if (inputCastOp.getOutputs() != castOp.getInputs())
68-
return {};
69-
return inputCastOp;
70-
};
71-
72-
// Process ops in the worklist bottom-to-top.
73-
while (!worklist.empty()) {
74-
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
75-
if (castOp->use_empty()) {
76-
// DCE: If the op has no users, erase it. Add the operands to the
77-
// worklist to find additional DCE opportunities.
78-
enqueueOperands(castOp);
79-
castOp->erase();
80-
continue;
81-
}
82-
83-
// Traverse the chain of input cast ops to see if an op with the same
84-
// input types can be found.
85-
UnrealizedConversionCastOp nextCast = castOp;
86-
while (nextCast) {
87-
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
88-
// Found a cast where the input types match the output types of the
89-
// matched op. We can directly use those inputs and the matched op can
90-
// be removed.
91-
enqueueOperands(castOp);
92-
castOp.replaceAllUsesWith(nextCast.getInputs());
93-
castOp->erase();
94-
break;
95-
}
96-
nextCast = getInputCast(nextCast);
97-
}
98-
}
45+
[&](UnrealizedConversionCastOp castOp) { ops.push_back(castOp); });
46+
reconcileUnrealizedCasts(ops);
9947
}
10048
};
10149

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,6 +2870,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
28702870
return success();
28712871
}
28722872

2873+
//===----------------------------------------------------------------------===//
2874+
// Reconcile Unrealized Casts
2875+
//===----------------------------------------------------------------------===//
2876+
2877+
void mlir::reconcileUnrealizedCasts(
2878+
ArrayRef<UnrealizedConversionCastOp> castOps,
2879+
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
2880+
SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2881+
castOps.end());
2882+
// This set is maintained only if `remainingCastOps` is provided.
2883+
DenseSet<Operation *> erasedOps;
2884+
2885+
// Helper function that adds all operands to the worklist that are an
2886+
// unrealized_conversion_cast op result.
2887+
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2888+
for (Value v : castOp.getInputs())
2889+
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2890+
worklist.insert(inputCastOp);
2891+
};
2892+
2893+
// Helper function that return the unrealized_conversion_cast op that
2894+
// defines all inputs of the given op (in the same order). Return "nullptr"
2895+
// if there is no such op.
2896+
auto getInputCast =
2897+
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2898+
if (castOp.getInputs().empty())
2899+
return {};
2900+
auto inputCastOp =
2901+
castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2902+
if (!inputCastOp)
2903+
return {};
2904+
if (inputCastOp.getOutputs() != castOp.getInputs())
2905+
return {};
2906+
return inputCastOp;
2907+
};
2908+
2909+
// Process ops in the worklist bottom-to-top.
2910+
while (!worklist.empty()) {
2911+
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2912+
if (castOp->use_empty()) {
2913+
// DCE: If the op has no users, erase it. Add the operands to the
2914+
// worklist to find additional DCE opportunities.
2915+
enqueueOperands(castOp);
2916+
if (remainingCastOps)
2917+
erasedOps.insert(castOp.getOperation());
2918+
castOp->erase();
2919+
continue;
2920+
}
2921+
2922+
// Traverse the chain of input cast ops to see if an op with the same
2923+
// input types can be found.
2924+
UnrealizedConversionCastOp nextCast = castOp;
2925+
while (nextCast) {
2926+
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2927+
// Found a cast where the input types match the output types of the
2928+
// matched op. We can directly use those inputs and the matched op can
2929+
// be removed.
2930+
enqueueOperands(castOp);
2931+
castOp.replaceAllUsesWith(nextCast.getInputs());
2932+
if (remainingCastOps)
2933+
erasedOps.insert(castOp.getOperation());
2934+
castOp->erase();
2935+
break;
2936+
}
2937+
nextCast = getInputCast(nextCast);
2938+
}
2939+
}
2940+
2941+
if (remainingCastOps)
2942+
for (UnrealizedConversionCastOp op : castOps)
2943+
if (!erasedOps.contains(op.getOperation()))
2944+
remainingCastOps->push_back(op);
2945+
}
2946+
28732947
//===----------------------------------------------------------------------===//
28742948
// Type Conversion
28752949
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)