Skip to content

[mlir][Tosa]: Add folder to ReciprocalOp of splat constant inputs #78137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,17 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
let results = (outs
Tosa_Tensor:$output
);

let extraClassDeclaration = [{
/// Return the reciprocal result on the operand.
static inline APFloat calcOneElement(const APFloat &operand) {
APFloat recip = APFloat(operand.getSemantics(), 1);
recip.divide(operand, APFloat::rmNearestTiesToEven);
return recip;
}
}];

let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
Expand All @@ -25,6 +26,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

Expand Down Expand Up @@ -1036,3 +1038,21 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
getOperation()->setOperands(concatOperands);
return getResult();
}

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
auto input = adaptor.getInput1();

auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
// Fold splat inputs only.
if (!inputAttr || !inputAttr.isSplat())
return {};

auto shapeType = llvm::cast<ShapedType>(getType());
if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
auto floatVal = inputAttr.getSplatValue<APFloat>();
return DenseElementsAttr::get(shapeType,
ReciprocalOp::calcOneElement(floatVal));
}

return {};
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

Expand Down
20 changes: 4 additions & 16 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ using namespace mlir::tosa;

namespace {

/// Rounding mode to be used on floating point operations that require rounding.
static constexpr llvm::RoundingMode tosaRoundingMode =
llvm::APFloat::rmNearestTiesToEven;

/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
Expand All @@ -44,14 +40,14 @@ static constexpr llvm::RoundingMode tosaRoundingMode =
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
const DenseElementsAttr &toTransform,
const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
const std::function<TargetValType(const SrcValType &)> &toApply,
TargetType targetType) {
SmallVector<TargetValType> transformedValues;
// We already know the amount of values we will insert, reserve space for
// all of them to avoid dynamic resizing
transformedValues.reserve(toTransform.getNumElements());
for (auto val : toTransform.getValues<SrcValType>()) {
auto transformedVal = toApply(val, targetType);
auto transformedVal = toApply(val);
transformedValues.push_back(transformedVal);
}

Expand All @@ -64,7 +60,7 @@ DenseElementsAttr applyElementWise(

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APFloat &, FloatType)> &toApply,
const std::function<APFloat(const APFloat &)> &toApply,
FloatType targetType);

/// Function that checks if the type contained in \p toCheck is float.
Expand Down Expand Up @@ -249,14 +245,6 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

using OpRewritePattern::OpRewritePattern;

static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) {
auto recipAttr = FloatAttr::get(floatTy, 1.0);
APFloat recip = recipAttr.getValue();
recip.divide(floatVal, tosaRoundingMode);

return recip;
}

LogicalResult matchAndRewrite(ReciprocalOp recip,
PatternRewriter &rewriter) const override {
auto inputTensor = recip.getInput1();
Expand All @@ -281,7 +269,7 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

// Create a new tensor with the updated values
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
inputValues, &computeReciprocal,
inputValues, &ReciprocalOp::calcOneElement,
cast<FloatType>(inputValues.getElementType()));

// Replace the use of the reciprocal with the transformed tensor
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,27 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}

// -----

// CHECK-LABEL: @fold_reciprocal
func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
// CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
// CHECK: }
%0 = "tosa.const"(){ value = dense<116.0>: tensor<f32> }: () -> tensor<f32>
%1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
return %2 : tensor<3x600x1200xf32>
}

// -----

// CHECK-LABEL: @do_not_fold_reciprocal_int
func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
// CHECK: tosa.reciprocal
%0 = "tosa.const"(){ value = dense<11>: tensor<i32> }: () -> tensor<i32>
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
return %2 : tensor<3x600x1200xi32>
}