#include "llvm/ADT/Sequence.h"
#include <numeric>
+ #include <type_traits>

using namespace mlir;
using namespace mlir::tosa;

+ // Helper function to materialize the semantically correct compare and select
+ // operations given a binary operation with a specific NaN propagation mode.
+ //
+ // In the case of "PROPAGATE" semantics no compare and selection is required and
+ // this function does nothing.
+ //
+ // In the case of "IGNORE" semantics this function materializes a comparison of
+ // the current operands to the op, which will return true for any NaN
+ // argument, and then selects between the non-NaN operation argument and the
+ // calculated result based on whether the lhs or rhs is NaN. In pseudo
+ // code:
+ //
+ // binary<op>(lhs, rhs):
+ //   result = op(lhs, rhs)
+ //   if lhs == NaN return rhs
+ //   if rhs == NaN return lhs
+ //   return result
+ template <typename OpTy>
+ static Value
+ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
+                                     Value lhs, Value rhs, Value result) {
+   auto nanMode = op.getNanMode();
+   if (nanMode == "PROPAGATE")
+     return result;
+
+   // Unordered comparison of NaN against itself will always return true.
+   Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+       op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+   Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+       op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+   Value rhsOrResult =
+       rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
+   return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
+                                           rhsOrResult);
+ }
+
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
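For orientation, a minimal scalar sketch of the "IGNORE" semantics that materializeBinaryNanCheckIfRequired lowers to, assuming IEEE-754 floats; the function below is illustrative only and not part of the patch:

#include <cmath>

// Scalar reference only: mirrors the pseudo code in the comment above.
static float ignoreModeMaximum(float lhs, float rhs) {
  float result = std::fmax(lhs, rhs);
  if (std::isnan(lhs))
    return rhs; // a NaN lhs is ignored
  if (std::isnan(rhs))
    return lhs; // a NaN rhs is ignored
  return result;
}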
@@ -367,7 +404,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
-     return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+     auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+     return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
+                                                rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -376,7 +415,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
-     return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+     auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+     return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
+                                                rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -404,7 +445,31 @@ static Value createLinalgBodyCalculationForElementwiseOp(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
-     return clampFloatHelper(loc, args[0], min, max, rewriter);
+     auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+
+     auto clampOp = llvm::cast<tosa::ClampOp>(op);
+     const auto nanMode = clampOp.getNanMode();
+     // In the case of "PROPAGATE" semantics no compare and selection is
+     // required.
+     if (nanMode == "PROPAGATE")
+       return result;
+
+     // In the case of "IGNORE" semantics materialize a comparison of the
+     // current operand, which will return true for a NaN argument, and then
+     // select between the initial reduction value and the calculated result
+     // based on whether the argument is NaN or not. In pseudo code:
+     //
+     // reduce<op>(x, init):
+     //   result = op(init, x)
+     //   return init if x == NaN else result
+
+     // Unordered comparison of NaN against itself will always return true.
+     Value isNaN = rewriter.create<arith::CmpFOp>(
+         op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+     // TOSA specifies that in "ignore" NaN mode the result is "min" if the
+     // input is NaN.
+     return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1078,7 +1143,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
- static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
+ template <typename OpTy>
+ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                   PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
@@ -1096,6 +1162,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
    }
  }

+   SmallVector<Value> inputs, outputs;
+   inputs.push_back(input);
+
  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
@@ -1113,26 +1182,127 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                  ValueRange{emptyTensor})
          .result();
+   outputs.push_back(filledTensor);
+
+   bool isNanIgnoreMode = false;
+   if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
+                 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
+     if (op.getNanMode() == "IGNORE") {
+       isNanIgnoreMode = true;
+       // Because the TOSA spec requires the result to be NaN iff all elements
+       // in the reduction are NaN, we can't simply perform a compare and
+       // select. We additionally have to keep track of whether we've seen any
+       // non-NaN values and then do a final select based on this predicate.
+       auto trueAttr = rewriter.getBoolAttr(true);
+       auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+       auto emptyBoolTensor =
+           rewriter
+               .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
+                                        dynDims)
+               .getResult();
+       auto allResultsNaNTensor =
+           rewriter
+               .create<linalg::FillOp>(loc, ValueRange{trueValue},
+                                       ValueRange{emptyBoolTensor})
+               .result();
+       // Note that because the linalg::ReduceOp has two variadic arguments
+       // (inputs and outputs) and it has the SameVariadicOperandSize trait we
+       // need to have the same number of inputs and outputs.
+       //
+       // The second input isn't actually used anywhere since the value used to
+       // update the NaN flag is calculated inside the body of the reduction
+       // and then used to update an out value.
+       // In order to satisfy type constraints we just pass another copy of the
+       // input here.
+       inputs.push_back(input);
+       outputs.push_back(allResultsNaNTensor);
+     }
+   }

  bool didEncounterError = false;
-   auto linalgOp = rewriter.create<linalg::ReduceOp>(
-       loc, input, filledTensor, axis,
+   linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
+       loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+         std::array<Value, 2> binaryArgs{
+             blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        auto result = createLinalgBodyCalculationForReduceOp(
-             op, blockArgs, elementTy, rewriter);
+             op, binaryArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

-         nestedBuilder.create<linalg::YieldOp>(loc, result);
+         SmallVector<Value> resultsToYield;
+         if (isNanIgnoreMode) {
+           auto inputValue = blockArgs[0];
+           auto initialValue = blockArgs[2];
+           auto oldAllResultsNanFlagValue = blockArgs[3];
+
+           // Unordered comparison of NaN against itself will always return
+           // true.
+           Value isNaN = nestedBuilder.create<arith::CmpFOp>(
+               op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+           // If we've encountered a NaN, keep the accumulated non-NaN value.
+           auto selectOp = nestedBuilder.create<arith::SelectOp>(
+               op->getLoc(), isNaN, initialValue, result);
+           // Update the flag which tracks whether every value seen so far has
+           // been NaN.
+           auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
+               op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+           resultsToYield.push_back(selectOp);
+           resultsToYield.push_back(newAllResultsNanFlagValue);
+         } else {
+           resultsToYield.push_back(result);
+         }
+         nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

+   if (isNanIgnoreMode) {
+     // Materialize a check to see whether we encountered any non-NaN values;
+     // if we didn't, we need to select a tensor of NaNs since the result will
+     // just be the initial identity value propagated through all the compares
+     // and selects inside the reduction.
+
+     // Create a tensor full of NaNs.
+     auto nanValueAttr = rewriter.getFloatAttr(
+         elementTy,
+         APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
+     auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+     auto emptyNanTensor =
+         rewriter
+             .create<tensor::EmptyOp>(loc, reduceShape,
+                                      resultTy.getElementType(), dynDims)
+             .getResult();
+     auto nanFilledTensor =
+         rewriter
+             .create<linalg::FillOp>(loc, ValueRange{nanValue},
+                                     ValueRange{emptyNanTensor})
+             .result();
+
+     // Create an empty tensor; no need to fill it since it will be overwritten
+     // by the select.
+     auto finalEmptyTensor =
+         rewriter
+             .create<tensor::EmptyOp>(loc, reduceShape,
+                                      resultTy.getElementType(), dynDims)
+             .getResult();
+
+     // Do a selection between the tensors akin to:
+     //   result = NaN if "all results NaN" else result.
+     SmallVector<Value> ins, outs;
+     ins.push_back(linalgOp->getOpResult(1));
+     ins.push_back(nanFilledTensor);
+     ins.push_back(linalgOp->getResult(0));
+     outs.push_back(finalEmptyTensor);
+     auto linalgSelect =
+         rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+     linalgOp = linalgSelect;
+   }
+
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
-     cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
+     cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1151,7 +1321,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
  // not have access to such information. This matters when handling dynamically
  // sized tensors.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-     op, resultTy, linalgOp.getResults()[0], reassociationMap);
+     op, resultTy, linalgOp->getResults()[0], reassociationMap);
  return success();
}

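As a cross-check of the flag-based lowering above, a minimal scalar sketch of reduce_max under "IGNORE" NaN mode, assuming IEEE-754 floats; illustrative only, not part of the patch:

#include <cmath>
#include <limits>
#include <vector>

// Scalar reference only: NaN elements are skipped, and the result is NaN
// iff every element of the reduction is NaN, as the TOSA spec requires.
static float ignoreModeReduceMax(const std::vector<float> &values) {
  float acc = -std::numeric_limits<float>::infinity(); // identity value
  bool allNaN = true; // stays true only while every element seen is NaN
  for (float v : values) {
    bool isNaN = std::isnan(v);
    if (!isNaN)
      acc = std::fmax(acc, v);
    allNaN &= isNaN;
  }
  return allNaN ? std::numeric_limits<float>::quiet_NaN() : acc;
}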
@@ -2097,6 +2267,27 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
+
+         // Check if we need to materialize compare and select for the given
+         // NaN propagation mode.
+
+         // "PROPAGATE" matches the default NaN propagation mode of the arith
+         // dialect so no compare and select is required.
+         //
+         // In the "IGNORE" case we check if the current argument is NaN and,
+         // if so, select the old index and value; otherwise we take the
+         // updated index and value.
+         if (const auto nanMode = argmaxOp.getNanMode(); nanMode == "IGNORE") {
+           // Unordered comparison of NaN against itself will always return
+           // true.
+           Value isNaN = rewriter.create<arith::CmpFOp>(
+               argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
+               newValue);
+           resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
+                                                        oldValue, resultMax);
+           resultIndex = rewriter.create<arith::SelectOp>(
+               nestedLoc, isNaN, oldIndex, resultIndex);
+         }
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });