diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7b70b3ab8afc9..607667fcc6945 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -32,10 +32,47 @@
 #include "llvm/ADT/Sequence.h"
 
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::tosa;
 
+// Helper function to materialize the semantically correct compare and select
+// operations given a binary operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and select is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operands to the op which will return true for any NaN
+// argument and then selects between the non-NaN operation argument and the
+// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
+// code:
+//
+// binary(lhs, rhs):
+//   result = op(lhs, rhs)
+//   if lhs == NaN return rhs
+//   if rhs == NaN return lhs
+//   return result
+template <typename OpTy>
+static Value
+materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
+                                    Value lhs, Value rhs, Value result) {
+  auto nanMode = op.getNanMode();
+  if (nanMode == "PROPAGATE")
+    return result;
+
+  // Unordered comparison of NaN against itself will always return true.
+  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+  Value rhsOrResult =
+      rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
+  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
+                                          rhsOrResult);
+}
+
 template <typename T>
 static arith::ConstantOp
 createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -367,7 +404,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
+                                               rewriter, args[0], args[1], max);
   }
 
   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -376,7 +415,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MinimumOp
   if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
+                                               rewriter, args[0], args[1], min);
  }
 
   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
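With nan_mode = "IGNORE", the maximum/minimum lowering above therefore produces a linalg.generic whose body looks roughly like the sketch below (an f32 element type is assumed and the value names %lhs, %rhs and %out are illustrative, not taken from the patch):

  ^bb0(%lhs: f32, %rhs: f32, %out: f32):
    %max  = arith.maximumf %lhs, %rhs : f32
    %lnan = arith.cmpf uno, %lhs, %lhs : f32
    %rnan = arith.cmpf uno, %rhs, %rhs : f32
    %sel0 = arith.select %lnan, %rhs, %max : f32
    %sel1 = arith.select %rnan, %lhs, %sel0 : f32
    linalg.yield %sel1 : f32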
@@ -404,7 +445,31 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
     auto max = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
-    return clampFloatHelper(loc, args[0], min, max, rewriter);
+    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+
+    auto clampOp = llvm::cast<tosa::ClampOp>(op);
+    const auto nanMode = clampOp.getNanMode();
+    // In the case of "PROPAGATE" semantics no compare and select is
+    // required.
+    if (nanMode == "PROPAGATE")
+      return result;
+
+    // In the case of "IGNORE" semantics materialize a comparison of the
+    // current operand which will return true if it is NaN and then select
+    // between the clamp minimum and the clamped result based on whether
+    // the operand is NaN or not. In pseudo
+    // code:
+    //
+    // clamp(x, min, max):
+    //   result = clampFloatHelper(x, min, max)
+    //   return min if x == NaN else result
+
+    // Unordered comparison of NaN against itself will always return true.
+    Value isNaN = rewriter.create<arith::CmpFOp>(
+        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+    // TOSA specifies that in "IGNORE" NaN mode the result is "min" if the
+    // input is NaN.
+    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
   }
 
   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1078,7 +1143,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
 // Performs the match and rewrite for reduction operations. This includes
 // declaring a correctly sized initial value, and the linalg.generic operation
 // that reduces across the specified axis.
-static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
+template <typename OpTy>
+static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                  PatternRewriter &rewriter) {
   auto loc = op->getLoc();
   auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
@@ -1096,6 +1162,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
     }
   }
 
+  SmallVector<Value> inputs, outputs;
+  inputs.push_back(input);
+
   // First fill the output buffer with the init value.
   auto emptyTensor =
       rewriter
@@ -1113,26 +1182,127 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
           .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                   ValueRange{emptyTensor})
           .result();
+  outputs.push_back(filledTensor);
+
+  bool isNanIgnoreMode = false;
+  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
+                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
+    if (op.getNanMode() == "IGNORE") {
+      isNanIgnoreMode = true;
+      // Because the TOSA spec requires the result be NaN iff all elements in
+      // the reduction are NaN we can't simply perform a compare and select.
+      // Instead we have to keep track of whether we've seen any non-NaN
+      // values and then do a final select based on this predicate.
+      auto trueAttr = rewriter.getBoolAttr(true);
+      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+      auto emptyBoolTensor =
+          rewriter
+              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
+                                       dynDims)
+              .getResult();
+      auto allResultsNaNTensor =
+          rewriter
+              .create<linalg::FillOp>(loc, ValueRange{trueValue},
+                                      ValueRange{emptyBoolTensor})
+              .result();
+      // Note that because the linalg::ReduceOp has two variadic arguments
+      // (inputs and outputs) and it has the SameVariadicOperandSize trait we
+      // need to have the same number of inputs and outputs.
+      //
+      // The second input isn't actually used anywhere since the value used to
+      // update the NaN flag is calculated inside the body of the reduction
+      // and then used to update an out value.
+      // In order to satisfy type constraints we just pass another copy of the
+      // input here.
+      inputs.push_back(input);
+      outputs.push_back(allResultsNaNTensor);
+    }
+  }
 
   bool didEncounterError = false;
-  auto linalgOp = rewriter.create<linalg::ReduceOp>(
-      loc, input, filledTensor, axis,
+  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
+      loc, inputs, outputs, axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+        std::array<Value, 2> binaryArgs{
+            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
         auto result = createLinalgBodyCalculationForReduceOp(
-            op, blockArgs, elementTy, rewriter);
+            op, binaryArgs, elementTy, rewriter);
         if (result)
           didEncounterError = true;
 
-        nestedBuilder.create<linalg::YieldOp>(loc, result);
+        SmallVector<Value> resultsToYield;
+        if (isNanIgnoreMode) {
+          auto inputValue = blockArgs[0];
+          auto initialValue = blockArgs[2];
+          auto oldAllResultsNanFlagValue = blockArgs[3];
+
+          // Unordered comparison of NaN against itself will always return
+          // true.
+          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
+              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+          // If we've encountered a NaN, take the non-NaN value.
+          auto selectOp = nestedBuilder.create<arith::SelectOp>(
+              op->getLoc(), isNaN, initialValue, result);
+          // Update the flag which keeps track of whether every value seen so
+          // far has been NaN.
+          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
+              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+          resultsToYield.push_back(selectOp);
+          resultsToYield.push_back(newAllResultsNanFlagValue);
+        } else {
+          resultsToYield.push_back(result);
+        }
+        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
       });
 
   if (!didEncounterError)
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");
 
+  if (isNanIgnoreMode) {
+    // Materialize a check to see whether we encountered any non-NaN values.
+    // If we didn't, we need to select a tensor of NaNs since the result will
+    // just be the initial identity value propagated through all the compares
+    // and selects inside the reduction.
+
+    // Create a tensor full of NaNs.
+    auto nanValueAttr = rewriter.getFloatAttr(
+        elementTy,
+        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
+    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+    auto emptyNanTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+    auto nanFilledTensor =
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{nanValue},
+                                    ValueRange{emptyNanTensor})
+            .result();
+
+    // Create an empty tensor, no need to fill this since it will be
+    // overwritten by the select.
+    auto finalEmptyTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+
+    // Do a selection between the tensors akin to:
+    //   result = NaN if "all results NaN" else result.
+    SmallVector<Value> ins, outs;
+    ins.push_back(linalgOp->getOpResult(1));
+    ins.push_back(nanFilledTensor);
+    ins.push_back(linalgOp->getResult(0));
+    outs.push_back(finalEmptyTensor);
+    auto linalgSelect =
+        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+    linalgOp = linalgSelect;
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
   uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
+      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1151,7 +1321,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp.getResults()[0], reassociationMap);
+      op, resultTy, linalgOp->getResults()[0], reassociationMap);
 
   return success();
 }
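For orientation, the reduce "IGNORE" path above yields IR of roughly the following shape, sketched here for tosa.reduce_max over axis 0 of a tensor<5x4xf32>; the SSA names and the %nan_splat/%empty tensors are illustrative placeholders rather than names produced by the lowering:

  %red:2 = linalg.reduce
      ins(%in, %in : tensor<5x4xf32>, tensor<5x4xf32>)
      outs(%acc_init, %all_nan_init : tensor<4xf32>, tensor<4xi1>)
      dimensions = [0]
      (%x: f32, %x_copy: f32, %acc: f32, %all_nan: i1) {
        %max  = arith.maximumf %x, %acc : f32
        %nan  = arith.cmpf uno, %x, %x : f32
        %sel  = arith.select %nan, %acc, %max : f32
        %flag = arith.andi %all_nan, %nan : i1
        linalg.yield %sel, %flag : f32, i1
      }
  // Substitute NaN wherever every reduced element was NaN.
  %res = linalg.select
      ins(%red#1, %nan_splat, %red#0 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>)
      outs(%empty : tensor<4xf32>) -> tensor<4xf32>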
+ + // "PROPAGATE" matches the default NaN propagation mode of the arith + // dialect so no compare and select is required. + // + // In the case "IGNORE" we check if the current argument is NaN and + // select the old index and value otherwise take the updated index and + // value. + if (const auto nanMode = argmaxOp.getNanMode(); nanMode == "IGNORE") { + // Unordered comparison of NaN against itself will always return + // true. + Value isNaN = rewriter.create( + argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue, + newValue); + resultMax = rewriter.create(nestedLoc, isNaN, + oldValue, resultMax); + resultIndex = rewriter.create( + nestedLoc, isNaN, oldIndex, resultIndex); + } nestedBuilder.create( nestedLoc, ValueRange({resultIndex, resultMax})); }); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index a8fd536dd2548..620b5f95825f6 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -724,11 +724,44 @@ class MaxPool2dConverter : public OpConversionPattern { rewriter.replaceOpWithNewOp( op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr, dilationAttr); - } else { - rewriter.replaceOpWithNewOp( - op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, - filledEmptyTensor, strideAttr, dilationAttr); + return llvm::success(); } + + auto resultOp = rewriter.create( + op->getLoc(), ArrayRef{resultTy}, + ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr, + dilationAttr); + + rewriter.replaceOp(op, resultOp); + // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no + // compare and select materialization is required. + // + // In the case of "IGNORE" we need to insert a compare and select. Since + // we've already produced a named op we will just take its body and modify + // it to include the appropriate checks. If the current value is NaN the + // old value of pool will be taken otherwise we use the result. 
+    if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
+          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
+          resultOp.getIteratorTypesArray(),
+          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+            IRMapping map;
+            auto oldBlock = resultOp.getRegion().begin();
+            auto oldArgs = oldBlock->getArguments();
+            auto &oldMaxOp = *resultOp.getBlock()->begin();
+            map.map(oldArgs, blockArgs);
+            auto *newOp = opBuilder.clone(oldMaxOp, map);
+            Value isNaN = opBuilder.create<arith::CmpFOp>(
+                op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(),
+                blockArgs.front());
+            auto selectOp = opBuilder.create<arith::SelectOp>(
+                op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0));
+            opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
+          });
+      rewriter.replaceOp(resultOp, genericOp);
+    }
+
     return success();
   }
 };
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a524359b49759..980805ad94b7a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -906,3 +906,27 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
   %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @max_pool2d_nan_propagate
+func.func @max_pool2d_nan_propagate(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
+  // CHECK: linalg.pooling_nhwc_max
+  // CHECK-NOT: linalg.generic
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+  return %0 : tensor<1x4x32x62xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @max_pool2d_nan_ignore
+func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
+  // CHECK-NOT: linalg.pooling_nhwc_max
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+  return %0: tensor<1x4x32x62xf32>
+}
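The max_pool2d "IGNORE" case above rewrites the named pooling op into a linalg.generic whose body is roughly the following sketch (value names illustrative; %in is the current input element, %window the unused window-dimension argument carried over from the named op, and %acc the running maximum):

  ^bb0(%in: f32, %window: f32, %acc: f32):
    %max = arith.maximumf %in, %acc : f32
    %nan = arith.cmpf uno, %in, %in : f32
    %sel = arith.select %nan, %acc, %max : f32
    linalg.yield %sel : f32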
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 17add2d41afe7..86e6f9ed9264b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1992,3 +1992,195 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
   %0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
   return %0: tensor<1xi64>
 }
+
+// -----
+
+// CHECK-LABEL: @reduce_min_nan_propagate
+func.func @reduce_min_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.minimumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  // CHECK-NOT: arith.constant 0x7FC00000
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: select
+  // CHECK: return
+  %3 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_nan_propagate
+func.func @reduce_max_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  // CHECK-NOT: arith.constant 0x7FC00000
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: select
+  // CHECK: return
+  %4 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_min_nan_ignore
+func.func @reduce_min_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.minimumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  // CHECK: arith.constant 0x7FC00000
+  // CHECK: tensor.empty()
+  // CHECK: linalg.fill
+  // CHECK: tensor.empty()
+  // CHECK: select
+  %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_nan_ignore
+func.func @reduce_max_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  // CHECK: arith.constant 0x7FC00000
+  // CHECK: tensor.empty()
+  // CHECK: linalg.fill
+  // CHECK: tensor.empty()
+  // CHECK: select
+  %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @minimum_nan_propagate
+func.func @minimum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %7 = tosa.minimum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @maximum_nan_propagate
+func.func @maximum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %8 = tosa.maximum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @minimum_nan_ignore
+func.func @minimum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @maximum_nan_ignore
+func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @argmax_nan_propagate
+func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpf ogt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<4xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @argmax_nan_ignore
+func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpf ogt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<4xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_nan_propagate
+func.func @clamp_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %13 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_nan_ignore
+func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %14 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  return
+}
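As a final point of reference, the clamp "IGNORE" lowering exercised by @clamp_nan_ignore produces a linalg.generic body roughly like this sketch (f32 assumed; %x is the input element and %min_cst/%max_cst stand for the constants built from min_val/max_val):

  %hi  = arith.minimumf %x, %max_cst : f32
  %res = arith.maximumf %hi, %min_cst : f32
  %nan = arith.cmpf uno, %x, %x : f32
  %out = arith.select %nan, %min_cst, %res : f32
  linalg.yield %out : f32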