-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering #125668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Jack Frankland (FranklandJack) ChangesAdd support for NaN propagation lowering in the Add appropriate lit tests including negative tests which check the various comparisons and selects are materialized as appropriate. This affects the following TOSA operators:
Full diff: https://github.com/llvm/llvm-project/pull/125668.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 069073bc2d164ac..1cd0d6a63114244 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -141,6 +141,22 @@ namespace tosa {
bool isa_tosa_shape_type(mlir::Type t);
+// Helper function to materialize the semantically correct compare and select
+// operations a reduction operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operand to the reduction which will return true for a NaN
+// argument and then selects between the initial reduction value and the
+// calculated result based on whether the argument is NaN or not. In pseudo
+// code:
+//
+// reduce<op>(x, init):
+// result = op(init, x)
+// return init if x == NaN else result
+
} // namespace tosa
} // namespace mlir
@@ -267,6 +283,31 @@ extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
}
+
+// Helper to determine if an operation should have the "nan_mode" string
+// attribute.
+inline bool shouldHaveNanPropagation(Operation *op) {
+ return isa<tosa::ClampOp>(op) || isa<tosa::MaxPool2dOp>(op) ||
+ isa<tosa::ReduceMinOp>(op) || isa<tosa::ReduceMaxOp>(op) ||
+ isa<tosa::MaximumOp>(op) || isa<tosa::MinimumOp>(op);
+}
+
+// Helper function to extract the NaN propagation mode from an operation.
+// Note that the for operations which support NaN mode propagation the attribute
+// is optional and its default value is "PROPAGATE".
+//
+// If the function is called with an operator that doesn't support the NaN mode
+// attribute it will return a std::nullopt.
+inline std::optional<std::string> getNanMode(Operation *op,
+ PatternRewriter &rewriter) {
+ if (shouldHaveNanPropagation(op))
+ return op->hasAttr("nan_mode") ? op->getAttrOfType<StringAttr>(
+ rewriter.getStringAttr("nan_mode"))
+ .str()
+ : "PROPAGATE";
+ return std::nullopt;
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b63..43e3f286c5fda18 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -36,6 +36,81 @@
using namespace mlir;
using namespace mlir::tosa;
+// Helper function to materialize the semantically correct compare and select
+// operations a reduction operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operand to the reduction which will return true for a NaN
+// argument and then selects between the initial reduction value and the
+// calculated result based on whether the argument is NaN or not. In pseudo
+// code:
+//
+// reduce<op>(x, init):
+// result = op(init, x)
+// return init if x == NaN else result
+static Value materializeReductionNanCheckIfRequired(Operation *op,
+ PatternRewriter &rewriter,
+ Value in, Value init,
+ Value result) {
+ const auto nanMode = getNanMode(op, rewriter);
+ if (!nanMode)
+ return {};
+
+ if (*nanMode == "PROPAGATE")
+ return result;
+
+ assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
+
+ // Unordered comparison of NaN against itself will always return true.
+ Value isNaN = rewriter.create<arith::CmpFOp>(
+ op->getLoc(), arith::CmpFPredicate::UNO, in, in);
+ return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, init, result);
+}
+
+// Helper function to materialize the semantically correct compare and select
+// operations a binary operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operands to the op which will return true for any NaN
+// argument and then selects between the non-NaN operation argument and the
+// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
+// code:
+//
+// binary<op>(lhs, rhs):
+// result = op(lhs, rhs)
+// if lhs == NaN return rhs
+// if rhs == NaN return lhs
+// return result
+static Value materializeBinaryNanCheckIfRequired(Operation *op,
+ PatternRewriter &rewriter,
+ Value lhs, Value rhs,
+ Value result) {
+ const auto nanMode = getNanMode(op, rewriter);
+ if (!nanMode)
+ return {};
+
+ if (*nanMode == "PROPAGATE")
+ return result;
+
+ assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
+
+ // Unordered comparison of NaN against itself will always return true.
+ Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+ op->getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+ Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+ op->getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+ Value rhsOrResult =
+ rewriter.create<arith::SelectOp>(op->getLoc(), lhsIsNaN, rhs, result);
+ return rewriter.create<arith::SelectOp>(op->getLoc(), rhsIsNaN, lhs,
+ rhsOrResult);
+}
+
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -358,7 +433,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
+ max);
}
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -367,7 +444,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MinimumOp
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
+ min);
}
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -395,7 +474,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
auto max = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
- return clampFloatHelper(loc, args[0], min, max, rewriter);
+ auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+ // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
+ // is NaN.
+ return materializeReductionNanCheckIfRequired(op, rewriter, args[0], min,
+ result);
}
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1042,7 +1125,9 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}
if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ return materializeReductionNanCheckIfRequired(op, rewriter, args[0],
+ args[1], min);
}
if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1050,7 +1135,9 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}
if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ return materializeReductionNanCheckIfRequired(op, rewriter, args[0],
+ args[1], max);
}
if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
@@ -2078,6 +2165,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
nestedLoc, predicate, newValue, oldValue);
auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
+
+ // Check if we need to materialize compare and select for the given
+ // NaN propagation mode.
+ const auto nanMode = getNanMode(argmaxOp, rewriter);
+ if (!nanMode) {
+ didEncounterError = true;
+ return;
+ }
+
+ // "PROPAGATE" matches the default NaN propagation mode of the arith
+ // dialect so no compare and select is required.
+ //
+ // In the case "IGNORE" we check if the current argument is NaN and
+ // select the old index and value otherwise take the updated index and
+ // value.
+ if (*nanMode == "IGNORE") {
+ // Unordered comparison of NaN against itself will always return
+ // true.
+ Value isNaN = rewriter.create<arith::CmpFOp>(
+ argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
+ newValue);
+ resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
+ oldValue, resultMax);
+ resultIndex = rewriter.create<arith::SelectOp>(
+ nestedLoc, isNaN, oldIndex, resultIndex);
+ }
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultIndex, resultMax}));
});
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e05cf7c9f..67d8f38c18ae949 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -807,11 +807,47 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
filledEmptyTensor, strideAttr, dilationAttr);
- } else {
- rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
- op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
- filledEmptyTensor, strideAttr, dilationAttr);
+ return llvm::success();
}
+
+ auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
+ op->getLoc(), ArrayRef<Type>{resultTy},
+ ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
+ dilationAttr);
+
+ // Check the NaN propgation mode is present.
+ const auto nanMode = getNanMode(op, rewriter);
+ if (!nanMode)
+ return failure();
+
+ // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
+ // compare and select materialization is required.
+ //
+ // In the case of "IGNORE" we need to insert a compare and select. Since
+ // we've already produced a named op we will just take its body and modify
+ // it to include the appropriate checks. If the current value is NaN the
+ // old value of pool will be taken otherwise we use the result.
+ if (nanMode == "IGNORE") {
+ auto *block = resultOp.getBlock();
+ rewriter.setInsertionPointToEnd(block);
+
+ auto in = block->getArgument(0);
+ auto out = block->getArgument(1);
+
+ auto *oldYieldOp = &*block->rbegin();
+ auto result = oldYieldOp->getOperand(0);
+
+ Value isNaN = rewriter.create<arith::CmpFOp>(
+ op->getLoc(), arith::CmpFPredicate::UNO, in, in);
+
+ auto selectOp =
+ rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, out, result);
+ auto newYieldOp = rewriter.create<linalg::YieldOp>(oldYieldOp->getLoc(),
+ selectOp.getResult());
+ rewriter.replaceOp(oldYieldOp, newYieldOp);
+ }
+
+ rewriter.replaceOp(op, resultOp);
return success();
}
};
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 116cd045aa0d3af..785819dd68499bc 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,linalg-generalize-named-ops))" %s -verify-diagnostics -o -| FileCheck %s --check-prefix="CHECK-NAN"
// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -977,3 +978,24 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
%1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
return
}
+
+// -----
+
+// CHECK-NAN-LABEL: @nan_propagation_modes
+func.func @nan_propagation_modes(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>) {
+ // CHECK-NAN: linalg.generic
+ // CHECK-NAN-NOT: arith.maximumf
+ // CHECK-NAN-NOT: arith.cmpf uno
+ // CHECK-NAN-NOT: arith.select
+ // CHECK-NAN: linalg.yield
+ %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+
+ // CHECK-NAN: linalg.generic
+ // CHECK-NAN: arith.maximumf
+ // CHECK-NAN: arith.cmpf uno
+ // CHECK-NAN: arith.select
+ // CHECK-NAN: linalg.yield
+ %1 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+
+ return %0, %1 : tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317aa..f5ced07561d278b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1991,3 +1991,107 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
%0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
return %0: tensor<1xi64>
}
+
+// -----
+
+// CHECK-LABEL: @nan_propagation_modes
+func.func @nan_propagation_modes(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.reduce
+ // CHECK: arith.minimumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %3 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+ // CHECK: linalg.reduce
+ // CHECK: arith.maximumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %4 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+ // CHECK: linalg.reduce
+ // CHECK: arith.minimumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+ // CHECK: linalg.reduce
+ // CHECK: arith.maximumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: arith.minimumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %7 = tosa.minimum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: arith.maximumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %8 = tosa.maximum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: arith.minimumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: arith.maximumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: arith.cmpf ogt
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<4xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: arith.cmpf ogt
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<4xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: arith.minimumf
+ // CHECK: arith.maximumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %13 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: arith.minimumf
+ // CHECK: arith.maximumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %14 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+ return
+}
|
199ae05
to
0c944b6
Compare
Please do not merge this yet! I've detected a bug in the case that all input values are NaN in which case this implementation doesn't match the behvaiour as defined in the TOSA specification. |
0c944b6
to
63d2502
Compare
This has now been resolved. |
63d2502
to
96ae7bc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lhutton1 can you have a look as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Eyeballing the numerical behaviour it LGTM, though I'm not very familiar with linalg so I won't explicitly approve
96ae7bc
to
7c747c0
Compare
7a6bd0d
to
70919f0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comments
Add support for NaN propagation lowering in the `tosa-to-linalg` and `tosa-to-linalg-named` conversions by conditionally checking for NaN in the case of ignore semantics and materializing the appropriate select operations. Note that the default behviour of "propagate" matches that of the arith dialect and so in that case we can avoid creating the checks altogether. Add appropriate lit tests including negative tests which check the various comparisons and selects are materialized as appropriate. This affects the following TOSA operators: * arg_max * max_pool_2d * clamp * reduce_max * reduce_min * maximum * minimum Signed-off-by: Jack Frankland <[email protected]>
70919f0
to
13ab98c
Compare
Add support for NaN propagation lowering in the
tosa-to-linalg
andtosa-to-linalg-named
conversions by conditionally checking for NaN in the case of ignore semantics and materializing the appropriate select operations. Note that the default behviour of "propagate" matches that of the arith dialect and so in that case we can avoid creating the checks altogether.Add appropriate lit tests including negative tests which check the various comparisons and selects are materialized as appropriate.
This affects the following TOSA operators: