@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
688688 UnresolvedMaterializationRewrite (
689689 ConversionPatternRewriterImpl &rewriterImpl,
690690 UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr ,
691- MaterializationKind kind = MaterializationKind::Target)
692- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
693- converterAndKind (converter, kind) {}
691+ MaterializationKind kind = MaterializationKind::Target);
694692
695693 static bool classof (const IRRewrite *rewrite) {
696694 return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
730728 });
731729}
732730
733- // / Find the single rewrite object of the specified type and block among the
734- // / given rewrites. In debug mode, asserts that there is mo more than one such
735- // / object. Return "nullptr" if no object was found.
736- template <typename RewriteTy, typename R>
737- static RewriteTy *findSingleRewrite (R &&rewrites, Block *block) {
738- RewriteTy *result = nullptr ;
739- for (auto &rewrite : rewrites) {
740- auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get ());
741- if (rewriteTy && rewriteTy->getBlock () == block) {
742- #ifndef NDEBUG
743- assert (!result && " expected single matching rewrite" );
744- result = rewriteTy;
745- #else
746- return rewriteTy;
747- #endif // NDEBUG
748- }
749- }
750- return result;
751- }
752-
753731// ===----------------------------------------------------------------------===//
754732// ConversionPatternRewriterImpl
755733// ===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
892870
893871 bool wasErased (void *ptr) const { return erased.contains (ptr); }
894872
895- bool wasErased (OperationRewrite *rewrite) const {
896- return wasErased (rewrite->getOperation ());
897- }
898-
899873 void notifyOperationErased (Operation *op) override { erased.insert (op); }
900874
901875 void notifyBlockErased (Block *block) override { erased.insert (block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
935909 // / to modify/access them is invalid rewriter API usage.
936910 SetVector<Operation *> replacedOps;
937911
938- // / A set of all unresolved materializations.
939- DenseSet<Operation *> unresolvedMaterializations;
912+ // / A mapping of all unresolved materializations (UnrealizedConversionCastOp)
913+ // / to the corresponding rewrite objects.
914+ DenseMap<Operation *, UnresolvedMaterializationRewrite *>
915+ unresolvedMaterializations;
940916
941917 // / The current type converter, or nullptr if no type converter is currently
942918 // / active.
@@ -1058,6 +1034,14 @@ void CreateOperationRewrite::rollback() {
10581034 op->erase ();
10591035}
10601036
1037+ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite (
1038+ ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1039+ const TypeConverter *converter, MaterializationKind kind)
1040+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1041+ converterAndKind(converter, kind) {
1042+ rewriterImpl.unresolvedMaterializations [op] = this ;
1043+ }
1044+
10611045void UnresolvedMaterializationRewrite::rollback () {
10621046 if (getMaterializationKind () == MaterializationKind::Target) {
10631047 for (Value input : op->getOperands ())
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13451329 builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
13461330 auto convertOp =
13471331 builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1348- unresolvedMaterializations.insert (convertOp);
13491332 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13501333 return convertOp.getResult (0 );
13511334}
@@ -2499,15 +2482,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24992482
25002483 // Gather all unresolved materializations.
25012484 SmallVector<UnrealizedConversionCastOp> allCastOps;
2502- DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
2503- for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites ) {
2504- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get ());
2505- if (!mat)
2506- continue ;
2507- if (rewriterImpl.eraseRewriter .wasErased (mat))
2485+ const DenseMap<Operation *, UnresolvedMaterializationRewrite *>
2486+ &materializations = rewriterImpl.unresolvedMaterializations ;
2487+ for (auto it : materializations) {
2488+ if (rewriterImpl.eraseRewriter .wasErased (it.first ))
25082489 continue ;
2509- allCastOps.push_back (mat->getOperation ());
2510- rewriteMap[mat->getOperation ()] = mat;
2490+ allCastOps.push_back (cast<UnrealizedConversionCastOp>(it.first ));
25112491 }
25122492
25132493 // Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2500,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25202500 if (config.buildMaterializations ) {
25212501 IRRewriter rewriter (rewriterImpl.context , config.listener );
25222502 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2523- auto it = rewriteMap .find (castOp.getOperation ());
2524- assert (it != rewriteMap .end () && " inconsistent state" );
2503+ auto it = materializations .find (castOp.getOperation ());
2504+ assert (it != materializations .end () && " inconsistent state" );
25252505 if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
25262506 return failure ();
25272507 }
0 commit comments