Commit cf3b036
[mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering (#125668)

Add support for NaN propagation lowering in the `tosa-to-linalg` and
`tosa-to-linalg-named` conversions by conditionally checking for NaN in the
case of ignore semantics and materializing the appropriate select operations.
Note that the default behaviour of "PROPAGATE" matches that of the arith
dialect, so in that case the checks can be avoided altogether.

Add appropriate lit tests, including negative tests which check that the
various comparisons and selects are materialized as appropriate.

This affects the following TOSA operators:
* arg_max
* max_pool_2d
* clamp
* reduce_max
* reduce_min
* maximum
* minimum

Signed-off-by: Jack Frankland <[email protected]>
1 parent 83c6b1a commit cf3b036
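For context, the NaN mode being lowered is an attribute on the TOSA op itself. A minimal illustrative example (shapes and names here are hypothetical, not taken from this commit):

  // With the default "PROPAGATE", a NaN operand makes the result NaN,
  // matching arith.maximumf; with "IGNORE", the other, non-NaN operand wins.
  %0 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>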

File tree

4 files changed: +454, -14 lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 201 additions & 10 deletions
@@ -32,10 +32,47 @@
 #include "llvm/ADT/Sequence.h"

 #include <numeric>
+#include <type_traits>

 using namespace mlir;
 using namespace mlir::tosa;

+// Helper function to materialize the semantically correct compare and select
+// operations given a binary operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operands to the op which will return true for any NaN
+// argument and then selects between the non-NaN operation argument and the
+// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
+// code:
+//
+// binary<op>(lhs, rhs):
+//   result = op(lhs, rhs)
+//   if lhs == NaN return rhs
+//   if rhs == NaN return lhs
+//   return result
+template <typename OpTy>
+static Value
+materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
+                                    Value lhs, Value rhs, Value result) {
+  auto nanMode = op.getNanMode();
+  if (nanMode == "PROPAGATE")
+    return result;
+
+  // Unordered comparison of NaN against itself will always return true.
+  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+  Value rhsOrResult =
+      rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
+  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
+                                          rhsOrResult);
+}
+
 template <typename T>
 static arith::ConstantOp
 createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -367,7 +404,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(

   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
+                                               rewriter, args[0], args[1], max);
   }

   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -376,7 +415,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(

   // tosa::MinimumOp
   if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
+                                               rewriter, args[0], args[1], min);
   }

   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
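For orientation, here is roughly what the helper materializes inside the linalg body for an ignore-mode tosa.maximum; an illustrative sketch with hypothetical SSA names, not output copied from the pass:

  %max  = arith.maximumf %lhs, %rhs : f32
  %lnan = arith.cmpf uno, %lhs, %lhs : f32      // true iff %lhs is NaN
  %rnan = arith.cmpf uno, %rhs, %rhs : f32      // true iff %rhs is NaN
  %tmp  = arith.select %lnan, %rhs, %max : f32  // NaN lhs -> take rhs
  %res  = arith.select %rnan, %lhs, %tmp : f32  // NaN rhs -> take lhs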
@@ -404,7 +445,31 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
     auto max = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
-    return clampFloatHelper(loc, args[0], min, max, rewriter);
+    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+
+    auto clampOp = llvm::cast<tosa::ClampOp>(op);
+    const auto nanMode = clampOp.getNanMode();
+    // In the case of "PROPAGATE" semantics no compare and selection is
+    // required.
+    if (nanMode == "PROPAGATE")
+      return result;
+
+    // In the case of "IGNORE" semantics materialize a comparison
+    // of the current operand to the reduction which will return true for a NaN
+    // argument and then selects between the initial reduction value and the
+    // calculated result based on whether the argument is NaN or not. In pseudo
+    // code:
+    //
+    // reduce<op>(x, init):
+    //   result = op(init, x)
+    //   return init if x == NaN else result
+
+    // Unordered comparison of NaN against itself will always return true.
+    Value isNaN = rewriter.create<arith::CmpFOp>(
+        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+    // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
+    // is NaN.
+    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
   }

   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
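The clamp case therefore lowers to something like the following in ignore mode; a sketch assuming clampFloatHelper emits the usual max-then-min pair, with %lo and %hi standing for the materialized bound constants:

  %m0  = arith.maximumf %x, %lo : f32       // clampFloatHelper: lower bound
  %c   = arith.minimumf %m0, %hi : f32      // clampFloatHelper: upper bound
  %nan = arith.cmpf uno, %x, %x : f32       // true iff %x is NaN
  %res = arith.select %nan, %lo, %c : f32   // NaN input -> clamp minimum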
@@ -1078,7 +1143,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
 // Performs the match and rewrite for reduction operations. This includes
 // declaring a correctly sized initial value, and the linalg.generic operation
 // that reduces across the specified axis.
-static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
+template <typename OpTy>
+static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                  PatternRewriter &rewriter) {
   auto loc = op->getLoc();
   auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
@@ -1096,6 +1162,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
     }
   }

+  SmallVector<Value> inputs, outputs;
+  inputs.push_back(input);
+
   // First fill the output buffer with the init value.
   auto emptyTensor =
       rewriter
@@ -1113,26 +1182,127 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
           .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                   ValueRange{emptyTensor})
           .result();
+  outputs.push_back(filledTensor);
+
+  bool isNanIgnoreMode = false;
+  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
+                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
+    if (op.getNanMode() == "IGNORE") {
+      isNanIgnoreMode = true;
+      // Because the TOSA spec requires the result be NaN iff all elements in
+      // the reduction are NaN we can't simply perform a compare and select.
+      // Additionally we have to keep track of whether we've seen any non-NaN
+      // values and then do a final select based on this predicate.
+      auto trueAttr = rewriter.getBoolAttr(true);
+      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+      auto emptyBoolTensor =
+          rewriter
+              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
+                                       dynDims)
+              .getResult();
+      auto allResultsNaNTensor =
+          rewriter
+              .create<linalg::FillOp>(loc, ValueRange{trueValue},
+                                      ValueRange{emptyBoolTensor})
+              .result();
+      // Note that because the linalg::ReduceOp has two variadic arguments
+      // (inputs and outputs) and it has the SameVariadicOperandSize trait we
+      // need to have the same number of inputs and outputs.
+      //
+      // The second input isn't actually used anywhere since the value used to
+      // update the NaN flag is calculated inside the body of the reduction and
+      // then used to update an out value.
+      // In order to satisfy type constraints we just pass another copy of the
+      // input here.
+      inputs.push_back(input);
+      outputs.push_back(allResultsNaNTensor);
+    }
+  }

   bool didEncounterError = false;
-  auto linalgOp = rewriter.create<linalg::ReduceOp>(
-      loc, input, filledTensor, axis,
+  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
+      loc, inputs, outputs, axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+        std::array<Value, 2> binaryArgs{
+            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
         auto result = createLinalgBodyCalculationForReduceOp(
-            op, blockArgs, elementTy, rewriter);
+            op, binaryArgs, elementTy, rewriter);
         if (result)
           didEncounterError = true;

-        nestedBuilder.create<linalg::YieldOp>(loc, result);
+        SmallVector<Value> resultsToYield;
+        if (isNanIgnoreMode) {
+          auto inputValue = blockArgs[0];
+          auto initialValue = blockArgs[2];
+          auto oldAllResultsNanFlagValue = blockArgs[3];
+
+          // Unordered comparison of NaN against itself will always return true.
+          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
+              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+          // If we've encountered a NaN, take the non-NaN value.
+          auto selectOp = nestedBuilder.create<arith::SelectOp>(
+              op->getLoc(), isNaN, initialValue, result);
+          // Update the flag which keeps track of whether we have seen a non-NaN
+          // value.
+          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
+              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+          resultsToYield.push_back(selectOp);
+          resultsToYield.push_back(newAllResultsNanFlagValue);
+        } else {
+          resultsToYield.push_back(result);
+        }
+        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
       });

   if (!didEncounterError)
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");

+  if (isNanIgnoreMode) {
+    // Materialize a check to see whether we encountered any non-NaN values, if
+    // we didn't we need to select a tensor of NaNs since the result will just
+    // be the initial identity value propagated through all the compares and
+    // selects inside the reduction.
+
+    // Create a tensor full of NaNs.
+    auto nanValueAttr = rewriter.getFloatAttr(
+        elementTy,
+        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
+    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+    auto emptyNanTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+    auto nanFilledTensor =
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{nanValue},
+                                    ValueRange{emptyNanTensor})
+            .result();
+
+    // Create an empty tensor, no need to fill this since it will be
+    // overwritten by the select.
+    auto finalEmptyTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+
+    // Do a selection between the tensors akin to:
+    //   result = NaN if "all results NaN" else result.
+    SmallVector<Value> ins, outs;
+    ins.push_back(linalgOp->getOpResult(1));
+    ins.push_back(nanFilledTensor);
+    ins.push_back(linalgOp->getResult(0));
+    outs.push_back(finalEmptyTensor);
+    auto linalgSelect =
+        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+    linalgOp = linalgSelect;
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
   uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
+      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
   reassociationMap.resize(expandInputRank);

   for (uint64_t i = 0; i < expandInputRank; i++) {
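Condensed, the reduce body built above looks roughly like this in ignore mode; a sketch with hypothetical SSA names, where the block arguments are the two inputs followed by the two inits:

  ^bb0(%in: f32, %in_dup: f32, %acc: f32, %all_nan: i1):
    %red      = arith.maximumf %in, %acc : f32       // the usual reduce_max update
    %nan      = arith.cmpf uno, %in, %in : f32       // true iff %in is NaN
    %sel      = arith.select %nan, %acc, %red : f32  // skip NaN elements
    %all_nan2 = arith.andi %all_nan, %nan : i1       // stays true only while every element is NaN
    linalg.yield %sel, %all_nan2 : f32, i1

The trailing linalg.select then swaps in the NaN-filled tensor wherever the all-NaN flag survived the reduction.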
@@ -1151,7 +1321,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp.getResults()[0], reassociationMap);
+      op, resultTy, linalgOp->getResults()[0], reassociationMap);
   return success();
 }

@@ -2097,6 +2267,27 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
             nestedLoc, predicate, newValue, oldValue);
         auto resultIndex = rewriter.create<arith::SelectOp>(
             nestedLoc, predicate, newIndex, oldIndex);
+
+        // Check if we need to materialize compare and select for the given
+        // NaN propagation mode.
+
+        // "PROPAGATE" matches the default NaN propagation mode of the arith
+        // dialect so no compare and select is required.
+        //
+        // In the case "IGNORE" we check if the current argument is NaN and
+        // select the old index and value otherwise take the updated index and
+        // value.
+        if (const auto nanMode = argmaxOp.getNanMode(); nanMode == "IGNORE") {
+          // Unordered comparison of NaN against itself will always return
+          // true.
+          Value isNaN = rewriter.create<arith::CmpFOp>(
+              argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
+              newValue);
+          resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
+                                                       oldValue, resultMax);
+          resultIndex = rewriter.create<arith::SelectOp>(
+              nestedLoc, isNaN, oldIndex, resultIndex);
+        }
         nestedBuilder.create<linalg::YieldOp>(
            nestedLoc, ValueRange({resultIndex, resultMax}));
       });
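In ignore mode the argmax body therefore gains two extra selects after the ordinary compare/select pair, roughly (a sketch; %vmax and %imax stand for the running max value and index):

  %nan  = arith.cmpf uno, %new, %new : f32          // true iff the candidate is NaN
  %vres = arith.select %nan, %old, %vmax : f32      // NaN -> keep previous max value
  %ires = arith.select %nan, %old_idx, %imax : i32  // NaN -> keep previous index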

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 37 additions & 4 deletions
@@ -724,11 +724,44 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
       rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
           op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
           filledEmptyTensor, strideAttr, dilationAttr);
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-          filledEmptyTensor, strideAttr, dilationAttr);
+      return llvm::success();
     }
+
+    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
+        op->getLoc(), ArrayRef<Type>{resultTy},
+        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
+        dilationAttr);
+
+    rewriter.replaceOp(op, resultOp);
+    // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
+    // compare and select materialization is required.
+    //
+    // In the case of "IGNORE" we need to insert a compare and select. Since
+    // we've already produced a named op we will just take its body and modify
+    // it to include the appropriate checks. If the current value is NaN the
+    // old value of pool will be taken otherwise we use the result.
+    if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
+          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
+          resultOp.getIteratorTypesArray(),
+          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+            IRMapping map;
+            auto oldBlock = resultOp.getRegion().begin();
+            auto oldArgs = oldBlock->getArguments();
+            auto &oldMaxOp = *resultOp.getBlock()->begin();
+            map.map(oldArgs, blockArgs);
+            auto *newOp = opBuilder.clone(oldMaxOp, map);
+            Value isNaN = opBuilder.create<arith::CmpFOp>(
+                op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(),
+                blockArgs.front());
+            auto selectOp = opBuilder.create<arith::SelectOp>(
+                op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0));
+            opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
+          });
+      rewriter.replaceOp(resultOp, genericOp);
+    }
+
     return success();
   }
 };
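The net effect in ignore mode is a linalg.generic whose body wraps the cloned max in a NaN guard, roughly as follows; a sketch with hypothetical names, with the indexing maps and iterator types carried over from the named pooling op:

  ^bb0(%in: f32, %out: f32):
    %max = arith.maximumf %in, %out : f32       // cloned from the named pooling op
    %nan = arith.cmpf uno, %in, %in : f32       // true iff the incoming value is NaN
    %sel = arith.select %nan, %out, %max : f32  // NaN -> keep the old pool value
    linalg.yield %sel : f32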

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 24 additions & 0 deletions
@@ -906,3 +906,27 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
   %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @max_pool2d_nan_propagate
+func.func @max_pool2d_nan_propagate(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
+  // CHECK: linalg.pooling_nhwc_max
+  // CHECK-NOT: linalg.generic
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+  return %0 : tensor<1x4x32x62xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @max_pool2d_nan_ignore
+func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
+  // CHECK-NOT: linalg.pooling_nhwc_max
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+  return %0: tensor<1x4x32x62xf32>
+}
