Skip to content

Commit ca8cba7

Browse files
[mlir][arith] Introduce minnumf and maxnumf operations (#66429)
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. Here we introduce new operations for floating-point numbers: `minnum` and `maxnum`. These operations have different semantics than `minumumf` and `maximumf` ops. They follow the eponymous LLVM intrinsics semantics, which differs in the handling positive and negative zeros and NaNs. This patch addresses the 1.3 task from the RFC.
1 parent 72bbac4 commit ca8cba7

File tree

4 files changed

+142
-14
lines changed

4 files changed

+142
-14
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,34 @@ def Arith_MaximumFOp : Arith_FloatBinaryOp<"maximumf", [Commutative]> {
857857
let hasFolder = 1;
858858
}
859859

860+
//===----------------------------------------------------------------------===//
861+
// MaxNumFOp
862+
//===----------------------------------------------------------------------===//
863+
864+
def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
865+
let summary = "floating-point maximum operation";
866+
let description = [{
867+
Syntax:
868+
869+
```
870+
operation ::= ssa-id `=` `arith.maxnumf` ssa-use `,` ssa-use `:` type
871+
```
872+
873+
Returns the maximum of the two arguments.
874+
If the arguments are -0.0 and +0.0, then the result is either of them.
875+
If one of the arguments is NaN, then the result is the other argument.
876+
877+
Example:
878+
879+
```mlir
880+
// Scalar floating-point maximum.
881+
%a = arith.maxnumf %b, %c : f64
882+
```
883+
}];
884+
let hasFolder = 1;
885+
}
886+
887+
860888
//===----------------------------------------------------------------------===//
861889
// MaxSIOp
862890
//===----------------------------------------------------------------------===//
@@ -901,6 +929,33 @@ def Arith_MinimumFOp : Arith_FloatBinaryOp<"minimumf", [Commutative]> {
901929
let hasFolder = 1;
902930
}
903931

932+
//===----------------------------------------------------------------------===//
933+
// MinNumFOp
934+
//===----------------------------------------------------------------------===//
935+
936+
def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
937+
let summary = "floating-point minimum operation";
938+
let description = [{
939+
Syntax:
940+
941+
```
942+
operation ::= ssa-id `=` `arith.minnumf` ssa-use `,` ssa-use `:` type
943+
```
944+
945+
Returns the minimum of the two arguments.
946+
If the arguments are -0.0 and +0.0, then the result is either of them.
947+
If one of the arguments is NaN, then the result is the other argument.
948+
949+
Example:
950+
951+
```mlir
952+
// Scalar floating-point minimum.
953+
%a = arith.minnumf %b, %c : f64
954+
```
955+
}];
956+
let hasFolder = 1;
957+
}
958+
904959
//===----------------------------------------------------------------------===//
905960
// MinSIOp
906961
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -927,11 +927,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
927927
//===----------------------------------------------------------------------===//
928928

929929
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
930-
// maxf(x,x) -> x
930+
// maximumf(x,x) -> x
931931
if (getLhs() == getRhs())
932932
return getRhs();
933933

934-
// maxf(x, -inf) -> x
934+
// maximumf(x, -inf) -> x
935935
if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
936936
return getLhs();
937937

@@ -940,6 +940,25 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
940940
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
941941
}
942942

943+
//===----------------------------------------------------------------------===//
944+
// MaxNumFOp
945+
//===----------------------------------------------------------------------===//
946+
947+
OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
948+
// maxnumf(x,x) -> x
949+
if (getLhs() == getRhs())
950+
return getRhs();
951+
952+
// maxnumf(x, -inf) -> x
953+
if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
954+
return getLhs();
955+
956+
return constFoldBinaryOp<FloatAttr>(
957+
adaptor.getOperands(),
958+
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
959+
}
960+
961+
943962
//===----------------------------------------------------------------------===//
944963
// MaxSIOp
945964
//===----------------------------------------------------------------------===//
@@ -995,11 +1014,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
9951014
//===----------------------------------------------------------------------===//
9961015

9971016
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
998-
// minf(x,x) -> x
1017+
// minimumf(x,x) -> x
9991018
if (getLhs() == getRhs())
10001019
return getRhs();
10011020

1002-
// minf(x, +inf) -> x
1021+
// minimumf(x, +inf) -> x
10031022
if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
10041023
return getLhs();
10051024

@@ -1008,6 +1027,24 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
10081027
[](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
10091028
}
10101029

1030+
//===----------------------------------------------------------------------===//
1031+
// MinNumFOp
1032+
//===----------------------------------------------------------------------===//
1033+
1034+
OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1035+
// minnumf(x,x) -> x
1036+
if (getLhs() == getRhs())
1037+
return getRhs();
1038+
1039+
// minnumf(x, +inf) -> x
1040+
if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1041+
return getLhs();
1042+
1043+
return constFoldBinaryOp<FloatAttr>(
1044+
adaptor.getOperands(),
1045+
[](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1046+
}
1047+
10111048
//===----------------------------------------------------------------------===//
10121049
// MinSIOp
10131050
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,8 +1635,8 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
16351635

16361636
// -----
16371637

1638-
// CHECK-LABEL: @test_minf(
1639-
func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
1638+
// CHECK-LABEL: @test_minimumf(
1639+
func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
16401640
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
16411641
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
16421642
// CHECK-NEXT: return %[[X]], %arg0, %arg0
@@ -1650,8 +1650,8 @@ func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
16501650

16511651
// -----
16521652

1653-
// CHECK-LABEL: @test_maxf(
1654-
func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
1653+
// CHECK-LABEL: @test_maximumf(
1654+
func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
16551655
// CHECK-DAG: %[[C0:.+]] = arith.constant
16561656
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
16571657
// CHECK-NEXT: return %[[X]], %arg0, %arg0
@@ -1665,6 +1665,36 @@ func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
16651665

16661666
// -----
16671667

1668+
// CHECK-LABEL: @test_minnumf(
1669+
func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
1670+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
1671+
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
1672+
// CHECK-NEXT: return %[[X]], %arg0, %arg0
1673+
%c0 = arith.constant 0.0 : f32
1674+
%inf = arith.constant 0x7F800000 : f32
1675+
%0 = arith.minnumf %c0, %arg0 : f32
1676+
%1 = arith.minnumf %arg0, %arg0 : f32
1677+
%2 = arith.minnumf %inf, %arg0 : f32
1678+
return %0, %1, %2 : f32, f32, f32
1679+
}
1680+
1681+
// -----
1682+
1683+
// CHECK-LABEL: @test_maxnumf(
1684+
func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
1685+
// CHECK-DAG: %[[C0:.+]] = arith.constant
1686+
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
1687+
// CHECK-NEXT: return %[[X]], %arg0, %arg0
1688+
%c0 = arith.constant 0.0 : f32
1689+
%-inf = arith.constant 0xFF800000 : f32
1690+
%0 = arith.maxnumf %c0, %arg0 : f32
1691+
%1 = arith.maxnumf %arg0, %arg0 : f32
1692+
%2 = arith.maxnumf %-inf, %arg0 : f32
1693+
return %0, %1, %2 : f32, f32, f32
1694+
}
1695+
1696+
// -----
1697+
16681698
// CHECK-LABEL: @test_addf(
16691699
func.func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
16701700
// CHECK-DAG: %[[C2:.+]] = arith.constant 2.0

mlir/test/Dialect/Arith/ops.mlir

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,9 +1071,12 @@ func.func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
10711071
%sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
10721072
%f1: f32, %f2: f32,
10731073
%i1: i32, %i2: i32) {
1074-
%max_vector = arith.maximumf %v1, %v2 : vector<4xf32>
1075-
%max_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
1076-
%max_float = arith.maximumf %f1, %f2 : f32
1074+
%maximum_vector = arith.maximumf %v1, %v2 : vector<4xf32>
1075+
%maximum_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
1076+
%maximum_float = arith.maximumf %f1, %f2 : f32
1077+
%maxnum_vector = arith.maxnumf %v1, %v2 : vector<4xf32>
1078+
%maxnum_scalable_vector = arith.maxnumf %sv1, %sv2 : vector<[4]xf32>
1079+
%maxnum_float = arith.maxnumf %f1, %f2 : f32
10771080
%max_signed = arith.maxsi %i1, %i2 : i32
10781081
%max_unsigned = arith.maxui %i1, %i2 : i32
10791082
return
@@ -1084,9 +1087,12 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
10841087
%sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
10851088
%f1: f32, %f2: f32,
10861089
%i1: i32, %i2: i32) {
1087-
%min_vector = arith.minimumf %v1, %v2 : vector<4xf32>
1088-
%min_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
1089-
%min_float = arith.minimumf %f1, %f2 : f32
1090+
%minimum_vector = arith.minimumf %v1, %v2 : vector<4xf32>
1091+
%minimum_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
1092+
%minimum_float = arith.minimumf %f1, %f2 : f32
1093+
%minnum_vector = arith.minnumf %v1, %v2 : vector<4xf32>
1094+
%minnum_scalable_vector = arith.minnumf %sv1, %sv2 : vector<[4]xf32>
1095+
%minnum_float = arith.minnumf %f1, %f2 : f32
10901096
%min_signed = arith.minsi %i1, %i2 : i32
10911097
%min_unsigned = arith.minui %i1, %i2 : i32
10921098
return

0 commit comments

Comments
 (0)