Skip to content

[mlir][Tosa]: Add folder to ReciprocalOp of splat constant inputs #78137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2024

Conversation

AviadCo
Copy link
Contributor

@AviadCo AviadCo commented Jan 15, 2024

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jan 15, 2024

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/78137.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+7)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+19)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+9)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp (+2-13)
  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+13)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3257ecd9d91f11..d8fc960563bf29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1114,6 +1114,13 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let extraClassDeclaration = [{
+    /// Computes reciprocal on a float element (input must be from float type).
+    static llvm::APFloat computeFloatElemOne(const llvm::APFloat &floatVal, FloatType floatTy);
+  }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 26c39ff3523434..fb3cd378f2c84b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -25,6 +26,7 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -1036,3 +1038,20 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
   getOperation()->setOperands(concatOperands);
   return getResult();
 }
+
+OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
+  auto input = adaptor.getInput1();
+
+  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
+  // Fold splat inputs only.
+  if (!inputAttr || !inputAttr.isSplat())
+    return {};
+
+  auto shapeType = llvm::cast<ShapedType>(getType());
+  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
+    auto floatVal = inputAttr.getSplatValue<APFloat>();
+    return DenseElementsAttr::get(shapeType, computeFloatElemOne(floatVal, floatType));
+  }
+
+  return {};
+}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 661126f4df9976..a2af9ef0c069f2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -1778,6 +1779,14 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
   return std::nullopt;
 }
 
+APFloat tosa::ReciprocalOp::computeFloatElemOne(const APFloat &floatVal, FloatType floatTy) {
+  auto recipAttr = FloatAttr::get(floatTy, 1.0);
+  APFloat recip = recipAttr.getValue();
+  recip.divide(floatVal, llvm::APFloat::rmNearestTiesToEven);
+
+  return recip;
+}
+
 // parse and print of IfOp refer to the implementation of SCF dialect.
 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
   // Create the regions for 'then'.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index d35e911ebe63c4..6208b38900ebad 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -30,10 +31,6 @@ using namespace mlir::tosa;
 
 namespace {
 
-/// Rounding mode to be used on floating point operations that require rounding.
-static constexpr llvm::RoundingMode tosaRoundingMode =
-    llvm::APFloat::rmNearestTiesToEven;
-
 /// Apply the given transformation \p toApply to every element of the tensor to
 /// be transformed \p toTransform.
 ///
@@ -249,14 +246,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
 
   using OpRewritePattern::OpRewritePattern;
 
-  static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
-    auto recipAttr = FloatAttr::get(floatTy, 1.0);
-    APFloat recip = recipAttr.getValue();
-    recip.divide(floatVal, tosaRoundingMode);
-
-    return recip;
-  }
-
   LogicalResult matchAndRewrite(ReciprocalOp recip,
                                 PatternRewriter &rewriter) const override {
     auto inputTensor = recip.getInput1();
@@ -281,7 +270,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
 
     // Create a new tensor with the updated values
     auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
-        inputValues, &computeReciprocal,
+        inputValues, &ReciprocalOp::computeFloatElemOne,
         cast<FloatType>(inputValues.getElementType()));
 
     // Replace the use of the reciprocal with the transformed tensor
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ee428b201d0073..9fc864463d95bf 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fd51d287bca058..de9d13b1453232 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -613,3 +613,16 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
   %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
   return %1 : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_reciprocal
+func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
+  // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
+  // CHECK:           return %[[VAL_0]] : tensor<3x600x1200xf32>
+  // CHECK:         }
+  %0 = "tosa.const"(){ value = dense<116.0>: tensor<f32> }: () -> tensor<f32>
+  %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
+  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+  return %2 : tensor<3x600x1200xf32>
+}

@AviadCo AviadCo requested review from amrami and amirBish January 15, 2024 09:54
Copy link

github-actions bot commented Jan 15, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@AviadCo AviadCo requested review from Lewuathe and jpienaar January 15, 2024 09:56
@AviadCo AviadCo force-pushed the tosa/reciprocal-fold branch from 488ef72 to a2ed9b7 Compare January 15, 2024 10:00
@AviadCo AviadCo force-pushed the tosa/reciprocal-fold branch from a2ed9b7 to 6a20e5a Compare January 16, 2024 08:10
@AviadCo AviadCo force-pushed the tosa/reciprocal-fold branch from 6a20e5a to 215f7f8 Compare January 16, 2024 08:14
Copy link
Contributor Author

@AviadCo AviadCo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review, answered your comments.

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me.

@AviadCo AviadCo merged commit d89a0a6 into llvm:main Jan 17, 2024
@AviadCo AviadCo deleted the tosa/reciprocal-fold branch January 17, 2024 07:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants