[mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering #125668


Merged
merged 1 commit into llvm:main from tosa_nan_mode_lowering on Feb 25, 2025

Conversation

FranklandJack
Contributor

Add support for NaN propagation lowering in the tosa-to-linalg and tosa-to-linalg-named conversions by conditionally checking for NaN in the case of ignore semantics and materializing the appropriate select operations. Note that the default behaviour of "propagate" matches that of the arith dialect, so in that case we can avoid creating the checks altogether.

Add appropriate lit tests, including negative tests, which check that the various comparisons and selects are materialized (or omitted) as appropriate.

This affects the following TOSA operators:

  • arg_max
  • max_pool_2d
  • clamp
  • reduce_max
  • reduce_min
  • maximum
  • minimum
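
For illustration only (not part of the patch), here is a minimal sketch of the scalar body this lowering builds for a floating-point maximum with "IGNORE" semantics; the value names are hypothetical:

```mlir
// Inside the element-wise linalg body; %lhs and %rhs are hypothetical names
// for the two scalar operands.
%max     = arith.maximumf %lhs, %rhs : f32
// An unordered self-comparison is true only when the operand is NaN.
%lhs_nan = arith.cmpf uno, %lhs, %lhs : f32
%rhs_nan = arith.cmpf uno, %rhs, %rhs : f32
// If one operand is NaN, yield the other operand instead of the computed max.
%tmp     = arith.select %lhs_nan, %rhs, %max : f32
%res     = arith.select %rhs_nan, %lhs, %tmp : f32
linalg.yield %res : f32
```

With "PROPAGATE" none of the extra compare/select operations are emitted, since arith.maximumf already propagates NaN.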

@llvmbot
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tosa

Author: Jack Frankland (FranklandJack)

Full diff: https://github.com/llvm/llvm-project/pull/125668.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+41)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+118-5)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+40-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+22)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+104)
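
As a quick orientation before the diff: a minimal sketch (illustrative only, with hypothetical value names) of the reduction body the "IGNORE" lowering produces, matching the pseudocode in the helper comments below:

```mlir
// %x is the incoming element, %acc the running reduction value.
%new = arith.maximumf %acc, %x : f32
// Unordered self-comparison is true only when %x is NaN.
%nan = arith.cmpf uno, %x, %x : f32
// Keep the previous accumulator when the incoming element is NaN.
%res = arith.select %nan, %acc, %new : f32
linalg.yield %res : f32
```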
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 069073bc2d164ac..1cd0d6a63114244 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -141,6 +141,22 @@ namespace tosa {
 
 bool isa_tosa_shape_type(mlir::Type t);
 
+// Helper function to materialize the semantically correct compare and select
+// operations for a reduction operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operand to the reduction which will return true for a NaN
+// argument and then selects between the initial reduction value and the
+// calculated result based on whether the argument is NaN or not. In pseudo
+// code:
+//
+// reduce<op>(x, init):
+//   result = op(init, x)
+//   return init if x == NaN else result
+
 } // namespace tosa
 
 } // namespace mlir
@@ -267,6 +283,31 @@ extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
 
   return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
 }
+
+// Helper to determine if an operation should have the "nan_mode" string
+// attribute.
+inline bool shouldHaveNanPropagation(Operation *op) {
+  return isa<tosa::ClampOp>(op) || isa<tosa::MaxPool2dOp>(op) ||
+         isa<tosa::ReduceMinOp>(op) || isa<tosa::ReduceMaxOp>(op) ||
+         isa<tosa::MaximumOp>(op) || isa<tosa::MinimumOp>(op);
+}
+
+// Helper function to extract the NaN propagation mode from an operation.
+// Note that for operations which support NaN mode propagation the attribute
+// is optional and its default value is "PROPAGATE".
+//
+// If the function is called with an operator that doesn't support the NaN mode
+// attribute it will return std::nullopt.
+inline std::optional<std::string> getNanMode(Operation *op,
+                                             PatternRewriter &rewriter) {
+  if (shouldHaveNanPropagation(op))
+    return op->hasAttr("nan_mode") ? op->getAttrOfType<StringAttr>(
+                                           rewriter.getStringAttr("nan_mode"))
+                                         .str()
+                                   : "PROPAGATE";
+  return std::nullopt;
+}
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b63..43e3f286c5fda18 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -36,6 +36,81 @@
 using namespace mlir;
 using namespace mlir::tosa;
 
+// Helper function to materialize the semantically correct compare and select
+// operations for a reduction operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operand to the reduction which will return true for a NaN
+// argument and then selects between the initial reduction value and the
+// calculated result based on whether the argument is NaN or not. In pseudo
+// code:
+//
+// reduce<op>(x, init):
+//   result = op(init, x)
+//   return init if x == NaN else result
+static Value materializeReductionNanCheckIfRequired(Operation *op,
+                                                    PatternRewriter &rewriter,
+                                                    Value in, Value init,
+                                                    Value result) {
+  const auto nanMode = getNanMode(op, rewriter);
+  if (!nanMode)
+    return {};
+
+  if (*nanMode == "PROPAGATE")
+    return result;
+
+  assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
+
+  // Unordered comparison of NaN against itself will always return true.
+  Value isNaN = rewriter.create<arith::CmpFOp>(
+      op->getLoc(), arith::CmpFPredicate::UNO, in, in);
+  return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, init, result);
+}
+
+// Helper function to materialize the semantically correct compare and select
+// operations for a binary operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operands to the op which will return true for any NaN
+// argument and then selects between the non-NaN operation argument and the
+// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
+// code:
+//
+// binary<op>(lhs, rhs):
+//   result = op(lhs, rhs)
+//   if lhs == NaN return rhs
+//   if rhs == NaN return lhs
+//   return result
+static Value materializeBinaryNanCheckIfRequired(Operation *op,
+                                                 PatternRewriter &rewriter,
+                                                 Value lhs, Value rhs,
+                                                 Value result) {
+  const auto nanMode = getNanMode(op, rewriter);
+  if (!nanMode)
+    return {};
+
+  if (*nanMode == "PROPAGATE")
+    return result;
+
+  assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
+
+  // Unordered comparison of NaN against itself will always return true.
+  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op->getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op->getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+  Value rhsOrResult =
+      rewriter.create<arith::SelectOp>(op->getLoc(), lhsIsNaN, rhs, result);
+  return rewriter.create<arith::SelectOp>(op->getLoc(), rhsIsNaN, lhs,
+                                          rhsOrResult);
+}
+
 template <typename T>
 static arith::ConstantOp
 createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -358,7 +433,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
+                                               max);
   }
 
   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -367,7 +444,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MinimumOp
   if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
+                                               min);
   }
 
   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -395,7 +474,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
     auto max = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
-    return clampFloatHelper(loc, args[0], min, max, rewriter);
+    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+    // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
+    // is NaN.
+    return materializeReductionNanCheckIfRequired(op, rewriter, args[0], min,
+                                                  result);
   }
 
   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1042,7 +1125,9 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   }
 
   if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    return materializeReductionNanCheckIfRequired(op, rewriter, args[0],
+                                                  args[1], min);
   }
 
   if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1050,7 +1135,9 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   }
 
   if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    return materializeReductionNanCheckIfRequired(op, rewriter, args[0],
+                                                  args[1], max);
   }
 
   if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
@@ -2078,6 +2165,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
               nestedLoc, predicate, newValue, oldValue);
           auto resultIndex = rewriter.create<arith::SelectOp>(
               nestedLoc, predicate, newIndex, oldIndex);
+
+          // Check if we need to materialize compare and select for the given
+          // NaN propagation mode.
+          const auto nanMode = getNanMode(argmaxOp, rewriter);
+          if (!nanMode) {
+            didEncounterError = true;
+            return;
+          }
+
+          // "PROPAGATE" matches the default NaN propagation mode of the arith
+          // dialect so no compare and select is required.
+          //
+          // In the case "IGNORE" we check if the current argument is NaN and
+          // select the old index and value otherwise take the updated index and
+          // value.
+          if (*nanMode == "IGNORE") {
+            // Unordered comparison of NaN against itself will always return
+            // true.
+            Value isNaN = rewriter.create<arith::CmpFOp>(
+                argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
+                newValue);
+            resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
+                                                         oldValue, resultMax);
+            resultIndex = rewriter.create<arith::SelectOp>(
+                nestedLoc, isNaN, oldIndex, resultIndex);
+          }
           nestedBuilder.create<linalg::YieldOp>(
               nestedLoc, ValueRange({resultIndex, resultMax}));
         });
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e05cf7c9f..67d8f38c18ae949 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -807,11 +807,47 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
       rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
           op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
           filledEmptyTensor, strideAttr, dilationAttr);
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-          filledEmptyTensor, strideAttr, dilationAttr);
+      return llvm::success();
     }
+
+    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
+        op->getLoc(), ArrayRef<Type>{resultTy},
+        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
+        dilationAttr);
+
+    // Check that the NaN propagation mode is present.
+    const auto nanMode = getNanMode(op, rewriter);
+    if (!nanMode)
+      return failure();
+
+    // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
+    // compare and select materialization is required.
+    //
+    // In the case of "IGNORE" we need to insert a compare and select. Since
+    // we've already produced a named op we will just take its body and modify
+    // it to include the appropriate checks. If the current value is NaN the
+    // old value of pool will be taken otherwise we use the result.
+    if (nanMode == "IGNORE") {
+      auto *block = resultOp.getBlock();
+      rewriter.setInsertionPointToEnd(block);
+
+      auto in = block->getArgument(0);
+      auto out = block->getArgument(1);
+
+      auto *oldYieldOp = &*block->rbegin();
+      auto result = oldYieldOp->getOperand(0);
+
+      Value isNaN = rewriter.create<arith::CmpFOp>(
+          op->getLoc(), arith::CmpFPredicate::UNO, in, in);
+
+      auto selectOp =
+          rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, out, result);
+      auto newYieldOp = rewriter.create<linalg::YieldOp>(oldYieldOp->getLoc(),
+                                                         selectOp.getResult());
+      rewriter.replaceOp(oldYieldOp, newYieldOp);
+    }
+
+    rewriter.replaceOp(op, resultOp);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 116cd045aa0d3af..785819dd68499bc 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
 // RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
 // RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,linalg-generalize-named-ops))" %s -verify-diagnostics -o -| FileCheck %s --check-prefix="CHECK-NAN"
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -977,3 +978,24 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
   %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-NAN-LABEL: @nan_propagation_modes
+func.func @nan_propagation_modes(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>) {
+  // CHECK-NAN: linalg.generic
+  // CHECK-NAN-NOT: arith.maximumf
+  // CHECK-NAN-NOT: arith.cmpf uno
+  // CHECK-NAN-NOT: arith.select
+  // CHECK-NAN: linalg.yield
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+
+  // CHECK-NAN: linalg.generic
+  // CHECK-NAN: arith.maximumf
+  // CHECK-NAN: arith.cmpf uno
+  // CHECK-NAN: arith.select
+  // CHECK-NAN: linalg.yield
+  %1 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+
+  return %0, %1 : tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317aa..f5ced07561d278b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1991,3 +1991,107 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
   %0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
   return %0: tensor<1xi64>
 }
+
+// -----
+
+// CHECK-LABEL: @nan_propagation_modes
+func.func @nan_propagation_modes(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.minimumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %3 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+  // CHECK: linalg.reduce
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %4 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+  // CHECK: linalg.reduce
+  // CHECK: arith.minimumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+  // CHECK: linalg.reduce
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %7 = tosa.minimum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %8 = tosa.maximum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpf ogt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpf ogt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %13 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %14 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  return
+}

@FranklandJack force-pushed the tosa_nan_mode_lowering branch from 199ae05 to 0c944b6 on February 4, 2025 at 14:27
@FranklandJack
Contributor Author

Please do not merge this yet!

I've detected a bug in the case where all input values are NaN; in that case this implementation doesn't match the behaviour defined in the TOSA specification.

@FranklandJack
Contributor Author

Please do not merge this yet!

I've detected a bug in the case where all input values are NaN; in that case this implementation doesn't match the behaviour defined in the TOSA specification.

This has now been resolved.
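
For context, a minimal sketch of the kind of input that exercises this corner case; the NaN splat encoding and the shapes are my own choices, not taken from the thread:

```mlir
// Every element is NaN (0x7FC00000). A lowering that always keeps the non-NaN
// accumulator never updates it, so the reduction would return its identity
// value here rather than the result the TOSA specification defines.
%nan = arith.constant dense<0x7FC00000> : tensor<4xf32>
%r = tosa.reduce_max %nan {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<4xf32>) -> tensor<1xf32>
```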

@GeorgeARM (Contributor) left a comment

@lhutton1 can you have a look as well?

@lhutton1 (Contributor) left a comment

Eyeballing the numerical behaviour it LGTM, though I'm not very familiar with linalg so I won't explicitly approve

@FranklandJack force-pushed the tosa_nan_mode_lowering branch 2 times, most recently from 7a6bd0d to 70919f0 on February 24, 2025 at 15:49
@GeorgeARM (Contributor) left a comment

Minor comments

Add support for NaN propagation lowering in the `tosa-to-linalg` and
`tosa-to-linalg-named` conversions by conditionally checking for NaN in
the case of ignore semantics and materializing the appropriate select
operations. Note that the default behaviour of "propagate" matches
that of the arith dialect and so in that case we can avoid creating the
checks altogether.

Add appropriate lit tests including negative tests which check the
various comparisons and selects are materialized as appropriate.

This affects the following TOSA operators:
* arg_max
* max_pool_2d
* clamp
* reduce_max
* reduce_min
* maximum
* minimum

Signed-off-by: Jack Frankland <[email protected]>
@FranklandJack merged commit cf3b036 into llvm:main on Feb 25, 2025
11 checks passed