Skip to content

[mlir][Transforms][NFC] Move ReconcileUnrealizedCasts implementation #104671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,29 @@ struct ConversionConfig {
RewriterBase::Listener *listener = nullptr;
};

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//

/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided).
///
/// This function processes cast ops in a worklist-driven fashion. For each
/// cast op, if the chain of input casts eventually reaches a cast op where the
/// input types match the output types of the matched op, replace the matched
/// op with the inputs.
///
/// Example:
/// %1 = unrealized_conversion_cast %0 : !A to !B
/// %2 = unrealized_conversion_cast %1 : !B to !C
/// %3 = unrealized_conversion_cast %2 : !C to !A
///
/// In the above example, %0 can be used instead of %3 and all cast ops are
/// folded away.
void reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);

//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
Expand Down Expand Up @@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts
ReconcileUnrealizedCasts() = default;

void runOnOperation() override {
// Gather all unrealized_conversion_cast ops.
SetVector<UnrealizedConversionCastOp> worklist;
SmallVector<UnrealizedConversionCastOp> ops;
getOperation()->walk(
[&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });

// Helper function that adds all operands to the worklist that are an
// unrealized_conversion_cast op result.
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
for (Value v : castOp.getInputs())
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
worklist.insert(inputCastOp);
};

// Helper function that return the unrealized_conversion_cast op that
// defines all inputs of the given op (in the same order). Return "nullptr"
// if there is no such op.
auto getInputCast =
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
if (castOp.getInputs().empty())
return {};
auto inputCastOp = castOp.getInputs()
.front()
.getDefiningOp<UnrealizedConversionCastOp>();
if (!inputCastOp)
return {};
if (inputCastOp.getOutputs() != castOp.getInputs())
return {};
return inputCastOp;
};

// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
if (castOp->use_empty()) {
// DCE: If the op has no users, erase it. Add the operands to the
// worklist to find additional DCE opportunities.
enqueueOperands(castOp);
castOp->erase();
continue;
}

// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
// Found a cast where the input types match the output types of the
// matched op. We can directly use those inputs and the matched op can
// be removed.
enqueueOperands(castOp);
castOp.replaceAllUsesWith(nextCast.getInputs());
castOp->erase();
break;
}
nextCast = getInputCast(nextCast);
}
}
[&](UnrealizedConversionCastOp castOp) { ops.push_back(castOp); });
reconcileUnrealizedCasts(ops);
}
};

Expand Down
74 changes: 74 additions & 0 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
return success();
}

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//

void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
castOps.end());
// This set is maintained only if `remainingCastOps` is provided.
DenseSet<Operation *> erasedOps;

// Helper function that adds all operands to the worklist that are an
// unrealized_conversion_cast op result.
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
for (Value v : castOp.getInputs())
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
worklist.insert(inputCastOp);
};

// Helper function that return the unrealized_conversion_cast op that
// defines all inputs of the given op (in the same order). Return "nullptr"
// if there is no such op.
auto getInputCast =
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
if (castOp.getInputs().empty())
return {};
auto inputCastOp =
castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
if (!inputCastOp)
return {};
if (inputCastOp.getOutputs() != castOp.getInputs())
return {};
return inputCastOp;
};

// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
if (castOp->use_empty()) {
// DCE: If the op has no users, erase it. Add the operands to the
// worklist to find additional DCE opportunities.
enqueueOperands(castOp);
if (remainingCastOps)
erasedOps.insert(castOp.getOperation());
castOp->erase();
continue;
}

// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
// Found a cast where the input types match the output types of the
// matched op. We can directly use those inputs and the matched op can
// be removed.
enqueueOperands(castOp);
castOp.replaceAllUsesWith(nextCast.getInputs());
if (remainingCastOps)
erasedOps.insert(castOp.getOperation());
castOp->erase();
break;
}
nextCast = getInputCast(nextCast);
}
}

if (remainingCastOps)
for (UnrealizedConversionCastOp op : castOps)
if (!erasedOps.contains(op.getOperation()))
remainingCastOps->push_back(op);
}

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
Expand Down
Loading