Skip to content

Commit 96ae7bc

Browse files
committed
[mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering
Add support for NaN propagation lowering in the `tosa-to-linalg` and `tosa-to-linalg-named` conversions by conditionally checking for NaN in the case of ignore semantics and materializing the appropriate select operations. Note that the default behviour of "propagate" matches that of the arith dialect and so in that case we can avoid creating the checks altogether. Add appropriate lit tests including negative tests which check the various comparisons and selects are materialized as appropriate. This affects the following TOSA operators: * arg_max * max_pool_2d * clamp * reduce_max * reduce_min * maximum * minimum Signed-off-by: Jack Frankland <[email protected]>
1 parent 499d6da commit 96ae7bc

File tree

5 files changed

+482
-13
lines changed

5 files changed

+482
-13
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,26 @@ extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
267267

268268
return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
269269
}
270+
271+
// Helper function to extract the NaN propagation mode from an operation.
272+
// Note that the for operations which support NaN mode propagation the attribute
273+
// is optional and its default value is "PROPAGATE".
274+
//
275+
// If the function is called with an operator that doesn't support the NaN mode
276+
// attribute it will return a std::nullopt.
277+
inline std::optional<std::string> getNanMode(Operation *op,
278+
PatternRewriter &rewriter) {
279+
if (isa<tosa::ClampOp>(op) || isa<tosa::MaxPool2dOp>(op) ||
280+
isa<tosa::ReduceMinOp>(op) || isa<tosa::ReduceMaxOp>(op) ||
281+
isa<tosa::MaximumOp>(op) || isa<tosa::MinimumOp>(op) ||
282+
isa<tosa::ArgMaxOp>(op))
283+
return op->hasAttr("nan_mode") ? op->getAttrOfType<StringAttr>(
284+
rewriter.getStringAttr("nan_mode"))
285+
.str()
286+
: "PROPAGATE";
287+
return std::nullopt;
288+
}
289+
270290
} // namespace tosa
271291
} // namespace mlir
272292

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 208 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,47 @@
3636
using namespace mlir;
3737
using namespace mlir::tosa;
3838

39+
// Helper function to materialize the semantically correct compare and select
40+
// operations a binary operation with a specific NaN propagation mode.
41+
//
42+
// In the case of "PROPAGATE" semantics no compare and selection is required and
43+
// this function does nothing.
44+
//
45+
// In the case of "IGNORE" semantics this function materializes a comparison of
46+
// the current operands to the op which will return true for any NaN
47+
// argument and then selects between the non-NaN operation argument and the
48+
// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
49+
// code:
50+
//
51+
// binary<op>(lhs, rhs):
52+
// result = op(lhs, rhs)
53+
// if lhs == NaN return rhs
54+
// if rhs == NaN return lhs
55+
// return result
56+
static Value materializeBinaryNanCheckIfRequired(Operation *op,
57+
PatternRewriter &rewriter,
58+
Value lhs, Value rhs,
59+
Value result) {
60+
const auto nanMode = getNanMode(op, rewriter);
61+
if (!nanMode)
62+
return {};
63+
64+
if (*nanMode == "PROPAGATE")
65+
return result;
66+
67+
assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
68+
69+
// Unordered comparison of NaN against itself will always return true.
70+
Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
71+
op->getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
72+
Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
73+
op->getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
74+
Value rhsOrResult =
75+
rewriter.create<arith::SelectOp>(op->getLoc(), lhsIsNaN, rhs, result);
76+
return rewriter.create<arith::SelectOp>(op->getLoc(), rhsIsNaN, lhs,
77+
rhsOrResult);
78+
}
79+
3980
template <typename T>
4081
static arith::ConstantOp
4182
createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -367,7 +408,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
367408

368409
// tosa::MaximumOp
369410
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
370-
return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
411+
auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
412+
return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
413+
max);
371414
}
372415

373416
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -376,7 +419,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
376419

377420
// tosa::MinimumOp
378421
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
379-
return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
422+
auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
423+
return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
424+
min);
380425
}
381426

382427
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -404,7 +449,34 @@ static Value createLinalgBodyCalculationForElementwiseOp(
404449
loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
405450
auto max = rewriter.create<arith::ConstantOp>(
406451
loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
407-
return clampFloatHelper(loc, args[0], min, max, rewriter);
452+
auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
453+
454+
const auto nanMode = getNanMode(op, rewriter);
455+
if (!nanMode)
456+
return {};
457+
458+
// In the case of "PROPAGATE" semantics no compare and selection is
459+
// required.
460+
if (*nanMode == "PROPAGATE")
461+
return result;
462+
463+
// In the case of "IGNORE" semantics materialize a comparison
464+
// of the current operand to the reduction which will return true for a NaN
465+
// argument and then selects between the initial reduction value and the
466+
// calculated result based on whether the argument is NaN or not. In pseudo
467+
// code:
468+
//
469+
// reduce<op>(x, init):
470+
// result = op(init, x)
471+
// return init if x == NaN else result
472+
assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
473+
474+
// Unordered comparison of NaN against itself will always return true.
475+
Value isNaN = rewriter.create<arith::CmpFOp>(
476+
op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
477+
// TOSA specifies that in "ignore" NaN mode the result is "min" if the input
478+
// is NaN.
479+
return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
408480
}
409481

410482
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1096,6 +1168,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
10961168
}
10971169
}
10981170

1171+
SmallVector<Value> inputs, outputs;
1172+
inputs.push_back(input);
1173+
10991174
// First fill the output buffer with the init value.
11001175
auto emptyTensor =
11011176
rewriter
@@ -1113,26 +1188,124 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
11131188
.create<linalg::FillOp>(loc, ValueRange{fillValue},
11141189
ValueRange{emptyTensor})
11151190
.result();
1191+
outputs.push_back(filledTensor);
1192+
1193+
const auto nanMode = getNanMode(op, rewriter);
1194+
const bool isNanIgnoreMode = (nanMode && *nanMode == "IGNORE");
1195+
if (isNanIgnoreMode) {
1196+
// Because the TOSA spec requires the result be NaN iff all elements in the
1197+
// reduction are NaN we can't simply perform a compare and select.
1198+
// Additionally we have to keep track of whether we've seen any non-NaN
1199+
// values and then do a final select based on this predicate.
1200+
auto trueAttr = rewriter.getBoolAttr(true);
1201+
auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
1202+
auto emptyBoolTensor =
1203+
rewriter
1204+
.create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
1205+
dynDims)
1206+
.getResult();
1207+
auto allResultsNaNTensor =
1208+
rewriter
1209+
.create<linalg::FillOp>(loc, ValueRange{trueValue},
1210+
ValueRange{emptyBoolTensor})
1211+
.result();
1212+
// Note that because the linalg::ReduceOp has two variadic arguments (inputs
1213+
// and outputs) and it has the SameVariadicOperandSize trait we need to have
1214+
// the same number of inputs and outputs.
1215+
//
1216+
// The second input isn't actully used anywhere since the value used to
1217+
// update the NaN flag is calculated inside the body of the reduction and
1218+
// then used to update an out value.
1219+
// In order to satisfy type constraints we just pass another copy of the
1220+
// input here.
1221+
inputs.push_back(input);
1222+
outputs.push_back(allResultsNaNTensor);
1223+
}
11161224

11171225
bool didEncounterError = false;
1118-
auto linalgOp = rewriter.create<linalg::ReduceOp>(
1119-
loc, input, filledTensor, axis,
1226+
linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
1227+
loc, inputs, outputs, axis,
11201228
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1229+
std::array<Value, 2> binaryArgs{
1230+
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
11211231
auto result = createLinalgBodyCalculationForReduceOp(
1122-
op, blockArgs, elementTy, rewriter);
1232+
op, binaryArgs, elementTy, rewriter);
11231233
if (result)
11241234
didEncounterError = true;
11251235

1126-
nestedBuilder.create<linalg::YieldOp>(loc, result);
1236+
SmallVector<Value> resultsToYield;
1237+
if (isNanIgnoreMode) {
1238+
auto inputValue = blockArgs[0];
1239+
auto initialValue = blockArgs[2];
1240+
auto oldAllResultsNanFlagValue = blockArgs[3];
1241+
1242+
// Unordered comparison of NaN against itself will always return true.
1243+
Value isNaN = nestedBuilder.create<arith::CmpFOp>(
1244+
op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
1245+
// If we've encountered a NaN, take the non-NaN value.
1246+
auto selectOp = nestedBuilder.create<arith::SelectOp>(
1247+
op->getLoc(), isNaN, initialValue, result);
1248+
// Update the flag which keeps track of whether we have seen a non-NaN
1249+
// value.
1250+
auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
1251+
op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1252+
resultsToYield.push_back(selectOp);
1253+
resultsToYield.push_back(newAllResultsNanFlagValue);
1254+
} else {
1255+
resultsToYield.push_back(result);
1256+
}
1257+
nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
11271258
});
11281259

11291260
if (!didEncounterError)
11301261
return rewriter.notifyMatchFailure(
11311262
op, "unable to create linalg.generic body for reduce op");
11321263

1264+
if (isNanIgnoreMode) {
1265+
// Materialize a check to see whether we encountered any non-NaN values, if
1266+
// we didn't we need to select a tensor of NaNs since the result will just
1267+
// be the initial identity value propagated through all the compares and
1268+
// selects inside the reduction.
1269+
1270+
// Create a tensor full of NaNs.
1271+
auto nanValueAttr = rewriter.getFloatAttr(
1272+
elementTy,
1273+
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1274+
auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
1275+
auto emptyNanTensor =
1276+
rewriter
1277+
.create<tensor::EmptyOp>(loc, reduceShape,
1278+
resultTy.getElementType(), dynDims)
1279+
.getResult();
1280+
auto nanFilledTensor =
1281+
rewriter
1282+
.create<linalg::FillOp>(loc, ValueRange{nanValue},
1283+
ValueRange{emptyNanTensor})
1284+
.result();
1285+
1286+
// Create an empty tensor, non need to fill this since it will be
1287+
// overwritten by the select.
1288+
auto finalEmptyTensor =
1289+
rewriter
1290+
.create<tensor::EmptyOp>(loc, reduceShape,
1291+
resultTy.getElementType(), dynDims)
1292+
.getResult();
1293+
1294+
// Do a selection between the tensors akin to:
1295+
// result = NaN if "all results NaN" else result.
1296+
SmallVector<Value> ins, outs;
1297+
ins.push_back(linalgOp->getOpResult(1));
1298+
ins.push_back(nanFilledTensor);
1299+
ins.push_back(linalgOp->getResult(0));
1300+
outs.push_back(finalEmptyTensor);
1301+
auto linalgSelect =
1302+
rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
1303+
linalgOp = linalgSelect;
1304+
}
1305+
11331306
SmallVector<ReassociationExprs, 4> reassociationMap;
11341307
uint64_t expandInputRank =
1135-
cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
1308+
cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
11361309
reassociationMap.resize(expandInputRank);
11371310

11381311
for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1151,7 +1324,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
11511324
// not have access to such information. This matters when handling dynamically
11521325
// sized tensors.
11531326
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1154-
op, resultTy, linalgOp.getResults()[0], reassociationMap);
1327+
op, resultTy, linalgOp->getResults()[0], reassociationMap);
11551328
return success();
11561329
}
11571330

@@ -2088,6 +2261,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
20882261
nestedLoc, predicate, newValue, oldValue);
20892262
auto resultIndex = rewriter.create<arith::SelectOp>(
20902263
nestedLoc, predicate, newIndex, oldIndex);
2264+
2265+
// Check if we need to materialize compare and select for the given
2266+
// NaN propagation mode.
2267+
const auto nanMode = getNanMode(argmaxOp, rewriter);
2268+
if (!nanMode) {
2269+
didEncounterError = true;
2270+
return;
2271+
}
2272+
2273+
// "PROPAGATE" matches the default NaN propagation mode of the arith
2274+
// dialect so no compare and select is required.
2275+
//
2276+
// In the case "IGNORE" we check if the current argument is NaN and
2277+
// select the old index and value otherwise take the updated index and
2278+
// value.
2279+
if (*nanMode == "IGNORE") {
2280+
// Unordered comparison of NaN against itself will always return
2281+
// true.
2282+
Value isNaN = rewriter.create<arith::CmpFOp>(
2283+
argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
2284+
newValue);
2285+
resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
2286+
oldValue, resultMax);
2287+
resultIndex = rewriter.create<arith::SelectOp>(
2288+
nestedLoc, isNaN, oldIndex, resultIndex);
2289+
}
20912290
nestedBuilder.create<linalg::YieldOp>(
20922291
nestedLoc, ValueRange({resultIndex, resultMax}));
20932292
});

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -802,11 +802,47 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
802802
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
803803
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
804804
filledEmptyTensor, strideAttr, dilationAttr);
805-
} else {
806-
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
807-
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
808-
filledEmptyTensor, strideAttr, dilationAttr);
805+
return llvm::success();
809806
}
807+
808+
auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
809+
op->getLoc(), ArrayRef<Type>{resultTy},
810+
ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
811+
dilationAttr);
812+
813+
// Check the NaN propgation mode is present.
814+
const auto nanMode = getNanMode(op, rewriter);
815+
if (!nanMode)
816+
return failure();
817+
818+
// "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
819+
// compare and select materialization is required.
820+
//
821+
// In the case of "IGNORE" we need to insert a compare and select. Since
822+
// we've already produced a named op we will just take its body and modify
823+
// it to include the appropriate checks. If the current value is NaN the
824+
// old value of pool will be taken otherwise we use the result.
825+
if (nanMode == "IGNORE") {
826+
auto *block = resultOp.getBlock();
827+
rewriter.setInsertionPointToEnd(block);
828+
829+
auto in = block->getArgument(0);
830+
auto out = block->getArgument(2);
831+
832+
auto *oldYieldOp = &*block->rbegin();
833+
auto result = oldYieldOp->getOperand(0);
834+
835+
Value isNaN = rewriter.create<arith::CmpFOp>(
836+
op->getLoc(), arith::CmpFPredicate::UNO, in, in);
837+
838+
auto selectOp =
839+
rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, out, result);
840+
auto newYieldOp = rewriter.create<linalg::YieldOp>(oldYieldOp->getLoc(),
841+
selectOp.getResult());
842+
rewriter.replaceOp(oldYieldOp, newYieldOp);
843+
}
844+
845+
rewriter.replaceOp(op, resultOp);
810846
return success();
811847
}
812848
};

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
22
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
33
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
4+
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,linalg-generalize-named-ops))" %s -verify-diagnostics -o -| FileCheck %s --check-prefix="CHECK-NAN"
45

56
// CHECK-LABEL: @matmul
67
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -977,3 +978,24 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
977978
%1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
978979
return
979980
}
981+
982+
// -----
983+
984+
// CHECK-NAN-LABEL: @nan_propagation_modes
985+
func.func @nan_propagation_modes(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>) {
986+
// CHECK-NAN: linalg.generic
987+
// CHECK-NAN-NOT: arith.maximumf
988+
// CHECK-NAN-NOT: arith.cmpf uno
989+
// CHECK-NAN-NOT: arith.select
990+
// CHECK-NAN: linalg.yield
991+
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
992+
993+
// CHECK-NAN: linalg.generic
994+
// CHECK-NAN: arith.maximumf
995+
// CHECK-NAN: arith.cmpf uno
996+
// CHECK-NAN: arith.select
997+
// CHECK-NAN: linalg.yield
998+
%1 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
999+
1000+
return %0, %1 : tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>
1001+
}

0 commit comments

Comments
 (0)