1818#include < llvm/ADT/FloatingPointMode.h>
1919#include < llvm/ADT/SmallVector.h>
2020#include < mlir/IR/BuiltinAttributes.h>
21+ #include < mlir/Support/LogicalResult.h>
2122
2223using namespace mlir ;
2324using namespace mlir ::tosa;
@@ -38,10 +39,9 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
3839 return recip;
3940 }
4041
41- ConstOp replaceTensorWithReciprocal (ConstOp tensorToReplace,
42- const DenseElementsAttr &inputValues,
43- PatternRewriter &rewriter) const {
44-
42+ DenseElementsAttr
43+ replaceTensorWithReciprocal (ConstOp tensorToReplace,
44+ const DenseElementsAttr &inputValues) const {
4545 // TODO it would be nicer to do this in-place
4646
4747 // Compute the reciprocal for each tensor element
@@ -57,9 +57,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
5757 // Replace the current tensor with one containing the computed reciprocals
5858 auto newTensor =
5959 DenseElementsAttr::get (inputValues.getType (), transformedValues);
60- auto newOp = rewriter.replaceOpWithNewOp <ConstOp>(
61- tensorToReplace, newTensor.getType (), newTensor);
62- return newOp;
60+ return newTensor;
6361 }
6462
6563 LogicalResult matchAndRewrite (ReciprocalOp recip,
@@ -116,14 +114,10 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
116114 }
117115
118116 // Create a new tensor with the updated values
119- auto newOp =
120- replaceTensorWithReciprocal (definingConstOp, inputValues, rewriter);
117+ auto newTensor = replaceTensorWithReciprocal (definingConstOp, inputValues);
121118
122119 // Replace the use of the reciprocal with the transformed tensor
123- auto updateUse = [&recip, &newOp]() { recip->replaceAllUsesWith (newOp); };
124- rewriter.updateRootInPlace (*(recip->getUsers ().begin ()), updateUse);
125- // Remove the reciprocal operation
126- rewriter.eraseOp (recip);
120+ rewriter.replaceOpWithNewOp <ConstOp>(recip, newTensor.getType (), newTensor);
127121 return success ();
128122 }
129123};
0 commit comments