36
36
using namespace mlir ;
37
37
using namespace mlir ::tosa;
38
38
39
+ // Helper function to materialize the semantically correct compare and select
40
+ // operations a binary operation with a specific NaN propagation mode.
41
+ //
42
+ // In the case of "PROPAGATE" semantics no compare and selection is required and
43
+ // this function does nothing.
44
+ //
45
+ // In the case of "IGNORE" semantics this function materializes a comparison of
46
+ // the current operands to the op which will return true for any NaN
47
+ // argument and then selects between the non-NaN operation argument and the
48
+ // calculated result based on whether the lhs or rhs is NaN or not. In pseudo
49
+ // code:
50
+ //
51
+ // binary<op>(lhs, rhs):
52
+ // result = op(lhs, rhs)
53
+ // if lhs == NaN return rhs
54
+ // if rhs == NaN return lhs
55
+ // return result
56
+ static Value materializeBinaryNanCheckIfRequired (Operation *op,
57
+ PatternRewriter &rewriter,
58
+ Value lhs, Value rhs,
59
+ Value result) {
60
+ const auto nanMode = getNanMode (op, rewriter);
61
+ if (!nanMode)
62
+ return {};
63
+
64
+ if (*nanMode == " PROPAGATE" )
65
+ return result;
66
+
67
+ assert (*nanMode == " IGNORE" && " Unhandled nan-propagation mode" );
68
+
69
+ // Unordered comparison of NaN against itself will always return true.
70
+ Value lhsIsNaN = rewriter.create <arith::CmpFOp>(
71
+ op->getLoc (), arith::CmpFPredicate::UNO, lhs, lhs);
72
+ Value rhsIsNaN = rewriter.create <arith::CmpFOp>(
73
+ op->getLoc (), arith::CmpFPredicate::UNO, rhs, rhs);
74
+ Value rhsOrResult =
75
+ rewriter.create <arith::SelectOp>(op->getLoc (), lhsIsNaN, rhs, result);
76
+ return rewriter.create <arith::SelectOp>(op->getLoc (), rhsIsNaN, lhs,
77
+ rhsOrResult);
78
+ }
79
+
39
80
template <typename T>
40
81
static arith::ConstantOp
41
82
createConstFromIntAttribute (Operation *op, const std::string &attrName,
@@ -367,7 +408,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
367
408
368
409
// tosa::MaximumOp
369
410
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
370
- return rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
411
+ auto max = rewriter.create <arith::MaximumFOp>(loc, args[0 ], args[1 ]);
412
+ return materializeBinaryNanCheckIfRequired (op, rewriter, args[0 ], args[1 ],
413
+ max);
371
414
}
372
415
373
416
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger ()) {
@@ -376,7 +419,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
376
419
377
420
// tosa::MinimumOp
378
421
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
379
- return rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
422
+ auto min = rewriter.create <arith::MinimumFOp>(loc, args[0 ], args[1 ]);
423
+ return materializeBinaryNanCheckIfRequired (op, rewriter, args[0 ], args[1 ],
424
+ min);
380
425
}
381
426
382
427
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger ()) {
@@ -404,7 +449,34 @@ static Value createLinalgBodyCalculationForElementwiseOp(
404
449
loc, elementTy, rewriter.getFloatAttr (elementTy, minApf));
405
450
auto max = rewriter.create <arith::ConstantOp>(
406
451
loc, elementTy, rewriter.getFloatAttr (elementTy, maxApf));
407
- return clampFloatHelper (loc, args[0 ], min, max, rewriter);
452
+ auto result = clampFloatHelper (loc, args[0 ], min, max, rewriter);
453
+
454
+ const auto nanMode = getNanMode (op, rewriter);
455
+ if (!nanMode)
456
+ return {};
457
+
458
+ // In the case of "PROPAGATE" semantics no compare and selection is
459
+ // required.
460
+ if (*nanMode == " PROPAGATE" )
461
+ return result;
462
+
463
+ // In the case of "IGNORE" semantics materialize a comparison
464
+ // of the clamp input against itself which will return true for a NaN
465
+ // input and then select between the lower clamp bound and the
466
+ // calculated result based on whether the input is NaN or not. In pseudo
467
+ // code:
468
+ //
469
+ // clamp(x, min, max):
470
+ // result = clampFloatHelper(x, min, max)
471
+ // return min if x == NaN else result
472
+ assert (*nanMode == " IGNORE" && " Unhandled nan-propagation mode" );
473
+
474
+ // Unordered comparison of NaN against itself will always return true.
475
+ Value isNaN = rewriter.create <arith::CmpFOp>(
476
+ op->getLoc (), arith::CmpFPredicate::UNO, args[0 ], args[0 ]);
477
+ // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
478
+ // is NaN.
479
+ return rewriter.create <arith::SelectOp>(op->getLoc (), isNaN, min, result);
408
480
}
409
481
410
482
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1096,6 +1168,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1096
1168
}
1097
1169
}
1098
1170
1171
+ SmallVector<Value> inputs, outputs;
1172
+ inputs.push_back (input);
1173
+
1099
1174
// First fill the output buffer with the init value.
1100
1175
auto emptyTensor =
1101
1176
rewriter
@@ -1113,26 +1188,124 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1113
1188
.create <linalg::FillOp>(loc, ValueRange{fillValue},
1114
1189
ValueRange{emptyTensor})
1115
1190
.result ();
1191
+ outputs.push_back (filledTensor);
1192
+
1193
+ const auto nanMode = getNanMode (op, rewriter);
1194
+ const bool isNanIgnoreMode = (nanMode && *nanMode == " IGNORE" );
1195
+ if (isNanIgnoreMode) {
1196
+ // Because the TOSA spec requires the result be NaN iff all elements in the
1197
+ // reduction are NaN we can't simply perform a compare and select.
1198
+ // Additionally we have to keep track of whether we've seen any non-NaN
1199
+ // values and then do a final select based on this predicate.
1200
+ auto trueAttr = rewriter.getBoolAttr (true );
1201
+ auto trueValue = rewriter.create <arith::ConstantOp>(loc, trueAttr);
1202
+ auto emptyBoolTensor =
1203
+ rewriter
1204
+ .create <tensor::EmptyOp>(loc, reduceShape, trueValue.getType (),
1205
+ dynDims)
1206
+ .getResult ();
1207
+ auto allResultsNaNTensor =
1208
+ rewriter
1209
+ .create <linalg::FillOp>(loc, ValueRange{trueValue},
1210
+ ValueRange{emptyBoolTensor})
1211
+ .result ();
1212
+ // Note that because the linalg::ReduceOp has two variadic arguments (inputs
1213
+ // and outputs) and it has the SameVariadicOperandSize trait we need to have
1214
+ // the same number of inputs and outputs.
1215
+ //
1216
+ // The second input isn't actually used anywhere since the value used to
1217
+ // update the NaN flag is calculated inside the body of the reduction and
1218
+ // then used to update an out value.
1219
+ // In order to satisfy type constraints we just pass another copy of the
1220
+ // input here.
1221
+ inputs.push_back (input);
1222
+ outputs.push_back (allResultsNaNTensor);
1223
+ }
1116
1224
1117
1225
bool didEncounterError = false ;
1118
- auto linalgOp = rewriter.create <linalg::ReduceOp>(
1119
- loc, input, filledTensor , axis,
1226
+ linalg::LinalgOp linalgOp = rewriter.create <linalg::ReduceOp>(
1227
+ loc, inputs, outputs , axis,
1120
1228
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1229
+ std::array<Value, 2 > binaryArgs{
1230
+ blockArgs[0 ], isNanIgnoreMode ? blockArgs[2 ] : blockArgs[1 ]};
1121
1231
auto result = createLinalgBodyCalculationForReduceOp (
1122
- op, blockArgs , elementTy, rewriter);
1232
+ op, binaryArgs , elementTy, rewriter);
1123
1233
if (result)
1124
1234
didEncounterError = true ;
1125
1235
1126
- nestedBuilder.create <linalg::YieldOp>(loc, result);
1236
+ SmallVector<Value> resultsToYield;
1237
+ if (isNanIgnoreMode) {
1238
+ auto inputValue = blockArgs[0 ];
1239
+ auto initialValue = blockArgs[2 ];
1240
+ auto oldAllResultsNanFlagValue = blockArgs[3 ];
1241
+
1242
+ // Unordered comparison of NaN against itself will always return true.
1243
+ Value isNaN = nestedBuilder.create <arith::CmpFOp>(
1244
+ op->getLoc (), arith::CmpFPredicate::UNO, inputValue, inputValue);
1245
+ // If we've encountered a NaN, take the non-NaN value.
1246
+ auto selectOp = nestedBuilder.create <arith::SelectOp>(
1247
+ op->getLoc (), isNaN, initialValue, result);
1248
+ // Update the flag which keeps track of whether we have seen a non-NaN
1249
+ // value.
1250
+ auto newAllResultsNanFlagValue = nestedBuilder.create <arith::AndIOp>(
1251
+ op->getLoc (), oldAllResultsNanFlagValue, isNaN);
1252
+ resultsToYield.push_back (selectOp);
1253
+ resultsToYield.push_back (newAllResultsNanFlagValue);
1254
+ } else {
1255
+ resultsToYield.push_back (result);
1256
+ }
1257
+ nestedBuilder.create <linalg::YieldOp>(loc, resultsToYield);
1127
1258
});
1128
1259
1129
1260
if (!didEncounterError)
1130
1261
return rewriter.notifyMatchFailure (
1131
1262
op, " unable to create linalg.generic body for reduce op" );
1132
1263
1264
+ if (isNanIgnoreMode) {
1265
+ // Materialize a check to see whether we encountered any non-NaN values, if
1266
+ // we didn't we need to select a tensor of NaNs since the result will just
1267
+ // be the initial identity value propagated through all the compares and
1268
+ // selects inside the reduction.
1269
+
1270
+ // Create a tensor full of NaNs.
1271
+ auto nanValueAttr = rewriter.getFloatAttr (
1272
+ elementTy,
1273
+ APFloat::getNaN (cast<FloatType>(elementTy).getFloatSemantics (), false ));
1274
+ auto nanValue = rewriter.create <arith::ConstantOp>(loc, nanValueAttr);
1275
+ auto emptyNanTensor =
1276
+ rewriter
1277
+ .create <tensor::EmptyOp>(loc, reduceShape,
1278
+ resultTy.getElementType (), dynDims)
1279
+ .getResult ();
1280
+ auto nanFilledTensor =
1281
+ rewriter
1282
+ .create <linalg::FillOp>(loc, ValueRange{nanValue},
1283
+ ValueRange{emptyNanTensor})
1284
+ .result ();
1285
+
1286
+ // Create an empty tensor, no need to fill this since it will be
1287
+ // overwritten by the select.
1288
+ auto finalEmptyTensor =
1289
+ rewriter
1290
+ .create <tensor::EmptyOp>(loc, reduceShape,
1291
+ resultTy.getElementType (), dynDims)
1292
+ .getResult ();
1293
+
1294
+ // Do a selection between the tensors akin to:
1295
+ // result = NaN if "all results NaN" else result.
1296
+ SmallVector<Value> ins, outs;
1297
+ ins.push_back (linalgOp->getOpResult (1 ));
1298
+ ins.push_back (nanFilledTensor);
1299
+ ins.push_back (linalgOp->getResult (0 ));
1300
+ outs.push_back (finalEmptyTensor);
1301
+ auto linalgSelect =
1302
+ rewriter.create <linalg::SelectOp>(op->getLoc (), ins, outs);
1303
+ linalgOp = linalgSelect;
1304
+ }
1305
+
1133
1306
SmallVector<ReassociationExprs, 4 > reassociationMap;
1134
1307
uint64_t expandInputRank =
1135
- cast<ShapedType>(linalgOp. getResults ()[0 ].getType ()).getRank ();
1308
+ cast<ShapedType>(linalgOp-> getResults ()[0 ].getType ()).getRank ();
1136
1309
reassociationMap.resize (expandInputRank);
1137
1310
1138
1311
for (uint64_t i = 0 ; i < expandInputRank; i++) {
@@ -1151,7 +1324,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
1151
1324
// not have access to such information. This matters when handling dynamically
1152
1325
// sized tensors.
1153
1326
rewriter.replaceOpWithNewOp <tensor::ExpandShapeOp>(
1154
- op, resultTy, linalgOp. getResults ()[0 ], reassociationMap);
1327
+ op, resultTy, linalgOp-> getResults ()[0 ], reassociationMap);
1155
1328
return success ();
1156
1329
}
1157
1330
@@ -2088,6 +2261,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2088
2261
nestedLoc, predicate, newValue, oldValue);
2089
2262
auto resultIndex = rewriter.create <arith::SelectOp>(
2090
2263
nestedLoc, predicate, newIndex, oldIndex);
2264
+
2265
+ // Check if we need to materialize compare and select for the given
2266
+ // NaN propagation mode.
2267
+ const auto nanMode = getNanMode (argmaxOp, rewriter);
2268
+ if (!nanMode) {
2269
+ didEncounterError = true ;
2270
+ return ;
2271
+ }
2272
+
2273
+ // "PROPAGATE" matches the default NaN propagation mode of the arith
2274
+ // dialect so no compare and select is required.
2275
+ //
2276
+ // In the case "IGNORE" we check if the current argument is NaN and
2277
+ // select the old index and value otherwise take the updated index and
2278
+ // value.
2279
+ if (*nanMode == " IGNORE" ) {
2280
+ // Unordered comparison of NaN against itself will always return
2281
+ // true.
2282
+ Value isNaN = rewriter.create <arith::CmpFOp>(
2283
+ argmaxOp.getLoc (), arith::CmpFPredicate::UNO, newValue,
2284
+ newValue);
2285
+ resultMax = rewriter.create <arith::SelectOp>(nestedLoc, isNaN,
2286
+ oldValue, resultMax);
2287
+ resultIndex = rewriter.create <arith::SelectOp>(
2288
+ nestedLoc, isNaN, oldIndex, resultIndex);
2289
+ }
2091
2290
nestedBuilder.create <linalg::YieldOp>(
2092
2291
nestedLoc, ValueRange ({resultIndex, resultMax}));
2093
2292
});
0 commit comments