diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 3257ecd9d91f1..0ee9e713724ea 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1114,6 +1114,17 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal", let results = (outs Tosa_Tensor:$output ); + + let extraClassDeclaration = [{ + /// Return the reciprocal result on the operand. + static inline APFloat calcOneElement(const APFloat &operand) { + APFloat recip = APFloat(operand.getSemantics(), 1); + recip.divide(operand, APFloat::rmNearestTiesToEven); + return recip; + } + }]; + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 26c39ff352343..3f683f701e0fc 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" @@ -25,6 +26,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" @@ -1036,3 +1038,21 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { getOperation()->setOperands(concatOperands); return getResult(); } + +OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) { + auto input = adaptor.getInput1(); + + auto inputAttr = llvm::dyn_cast_if_present(input); + // Fold splat inputs only. + if (!inputAttr || !inputAttr.isSplat()) + return {}; + + auto shapeType = llvm::cast(getType()); + if (auto floatType = llvm::dyn_cast(inputAttr.getElementType())) { + auto floatVal = inputAttr.getSplatValue(); + return DenseElementsAttr::get(shapeType, + ReciprocalOp::calcOneElement(floatVal)); + } + + return {}; +} diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 661126f4df997..729116da45e47 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index d35e911ebe63c..050f8ca3f32ae 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -30,10 +30,6 @@ using namespace mlir::tosa; namespace { -/// Rounding mode to be used on floating point operations that require rounding. -static constexpr llvm::RoundingMode tosaRoundingMode = - llvm::APFloat::rmNearestTiesToEven; - /// Apply the given transformation \p toApply to every element of the tensor to /// be transformed \p toTransform. /// @@ -44,14 +40,14 @@ static constexpr llvm::RoundingMode tosaRoundingMode = template DenseElementsAttr applyElementWise( const DenseElementsAttr &toTransform, - const std::function &toApply, + const std::function &toApply, TargetType targetType) { SmallVector transformedValues; // We already know the amount of values we will insert, reserve space for // all of them to avoid dynamic resizing transformedValues.reserve(toTransform.getNumElements()); for (auto val : toTransform.getValues()) { - auto transformedVal = toApply(val, targetType); + auto transformedVal = toApply(val); transformedValues.push_back(transformedVal); } @@ -64,7 +60,7 @@ DenseElementsAttr applyElementWise( template DenseElementsAttr applyElementWise( const DenseElementsAttr &toTransform, - const std::function &toApply, + const std::function &toApply, FloatType targetType); /// Function that checks if the type contained in \p toCheck is float. @@ -249,14 +245,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) { - auto recipAttr = FloatAttr::get(floatTy, 1.0); - APFloat recip = recipAttr.getValue(); - recip.divide(floatVal, tosaRoundingMode); - - return recip; - } - LogicalResult matchAndRewrite(ReciprocalOp recip, PatternRewriter &rewriter) const override { auto inputTensor = recip.getInput1(); @@ -281,7 +269,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern { // Create a new tensor with the updated values auto newTensor = applyElementWise( - inputValues, &computeReciprocal, + inputValues, &ReciprocalOp::calcOneElement, cast(inputValues.getElementType())); // Replace the use of the reciprocal with the transformed tensor diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index fd51d287bca05..e7ede2e0ccef9 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -613,3 +613,27 @@ func.func nested @fold_tile_rank_zero() -> tensor { %1 = tosa.tile %0 {multiples = array} : (tensor) -> tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: @fold_reciprocal +func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> { + // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32> + // CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32> + // CHECK: } + %0 = "tosa.const"(){ value = dense<116.0>: tensor }: () -> tensor + %1 = "tosa.cast"(%0) : (tensor) -> tensor<3x600x1200xf32> + %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32> + return %2 : tensor<3x600x1200xf32> +} + +// ----- + +// CHECK-LABEL: @do_not_fold_reciprocal_int +func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> { + // CHECK: tosa.reciprocal + %0 = "tosa.const"(){ value = dense<11>: tensor }: () -> tensor + %1 = "tosa.cast"(%0) : (tensor) -> tensor<3x600x1200xi32> + %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32> + return %2 : tensor<3x600x1200xi32> +}