Skip to content

Commit 641124a

Browse files
[mlir][spirv] Add conversions for Arith's maxnumf and minnumf (#66696)
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. In this commit, we add conversion patterns for the newly introduced operations `arith.minnumf` and `arith.maxnumf`. When converting to `spirv.CL`, there is no need to insert additional guards to propagate non-NaN values when one of the arguments is NaN because `CL` ops do exactly the same. However, `GL` ops have undefined behavior when one of the arguments is NaN, so we should insert additional guards to enforce the semantics of Arith's ops. This patch addresses the 1.5 task of the mentioned RFC.
1 parent 80c01dd commit 641124a

File tree

3 files changed

+135
-12
lines changed

3 files changed

+135
-12
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/BuiltinTypes.h"
2020
#include "llvm/ADT/APInt.h"
2121
#include "llvm/ADT/ArrayRef.h"
22+
#include "llvm/ADT/STLExtras.h"
2223
#include "llvm/Support/Debug.h"
2324
#include "llvm/Support/MathExtras.h"
2425
#include <cassert>
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
10861087
}
10871088
};
10881089

1090+
//===----------------------------------------------------------------------===//
1091+
// MinNumFOp, MaxNumFOp
1092+
//===----------------------------------------------------------------------===//
1093+
1094+
/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1095+
/// spirv.CL.fmax/fmin.
1096+
template <typename Op, typename SPIRVOp>
1097+
class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1098+
template <typename TargetOp>
1099+
constexpr bool shouldInsertNanGuards() const {
1100+
return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1101+
}
1102+
1103+
public:
1104+
using OpConversionPattern<Op>::OpConversionPattern;
1105+
LogicalResult
1106+
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1107+
ConversionPatternRewriter &rewriter) const override {
1108+
auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1109+
Type dstType = converter->convertType(op.getType());
1110+
if (!dstType)
1111+
return getTypeConversionFailure(rewriter, op);
1112+
1113+
// arith.maxnumf/minnumf:
1114+
// "If one of the arguments is NaN, then the result is the other
1115+
// argument."
1116+
// spirv.GL.FMax/FMin
1117+
// "which operand is the result is undefined if one of the operands
1118+
// is a NaN."
1119+
// spirv.CL.fmax/fmin:
1120+
// "If one argument is a NaN, Fmin returns the other argument."
1121+
1122+
Location loc = op.getLoc();
1123+
Value spirvOp =
1124+
rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1125+
1126+
if (!shouldInsertNanGuards<SPIRVOp>() ||
1127+
converter->getOptions().enableFastMathMode) {
1128+
rewriter.replaceOp(op, spirvOp);
1129+
return success();
1130+
}
1131+
1132+
Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1133+
Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1134+
1135+
Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1136+
adaptor.getRhs(), spirvOp);
1137+
Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1138+
adaptor.getLhs(), select1);
1139+
1140+
rewriter.replaceOp(op, select2);
1141+
return success();
1142+
}
1143+
};
1144+
10891145
} // namespace
10901146

10911147
//===----------------------------------------------------------------------===//
@@ -1138,13 +1194,17 @@ void mlir::arith::populateArithToSPIRVPatterns(
11381194

11391195
MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
11401196
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1197+
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1198+
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
11411199
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
11421200
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
11431201
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
11441202
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,
11451203

11461204
MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
11471205
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1206+
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1207+
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
11481208
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
11491209
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
11501210
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
11241124
return
11251125
}
11261126

1127-
// CHECK-LABEL: @float32_minf_scalar
1127+
// CHECK-LABEL: @float32_minimumf_scalar
11281128
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
1129-
func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
1129+
func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
11301130
// CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
11311131
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
11321132
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
11371137
return %0: f32
11381138
}
11391139

1140-
// CHECK-LABEL: @float32_maxf_scalar
1140+
// CHECK-LABEL: @float32_minnumf_scalar
1141+
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
1142+
func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
1143+
// CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
1144+
%0 = arith.minnumf %arg0, %arg1 : f32
1145+
// CHECK: return %[[MIN]]
1146+
return %0: f32
1147+
}
1148+
1149+
// CHECK-LABEL: @float32_maximumf_scalar
11411150
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
1142-
func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
1151+
func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
11431152
// CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
11441153
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
11451154
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
11501159
return %0: vector<2xf32>
11511160
}
11521161

1162+
// CHECK-LABEL: @float32_maxnumf_scalar
1163+
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
1164+
func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
1165+
// CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
1166+
%0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
1167+
// CHECK: return %[[MAX]]
1168+
return %0: vector<2xf32>
1169+
}
1170+
1171+
11531172
// CHECK-LABEL: @scalar_srem
11541173
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
11551174
func.func @scalar_srem(%lhs: i32, %rhs: i32) {
@@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
12701289
return
12711290
}
12721291

1273-
// CHECK-LABEL: @float32_minf_scalar
1292+
// CHECK-LABEL: @float32_minimumf_scalar
12741293
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
1275-
func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
1294+
func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
12761295
// CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
12771296
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
12781297
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
12831302
return %0: f32
12841303
}
12851304

1286-
// CHECK-LABEL: @float32_maxf_scalar
1305+
// CHECK-LABEL: @float32_minnumf_scalar
1306+
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
1307+
func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
1308+
// CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
1309+
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
1310+
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
1311+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]]
1312+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
1313+
%0 = arith.minnumf %arg0, %arg1 : f32
1314+
// CHECK: return %[[SELECT2]]
1315+
return %0: f32
1316+
}
1317+
1318+
// CHECK-LABEL: @float32_maximumf_scalar
12871319
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
1288-
func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
1320+
func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
12891321
// CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
12901322
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
12911323
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
12961328
return %0: vector<2xf32>
12971329
}
12981330

1331+
// CHECK-LABEL: @float32_maxnumf_scalar
1332+
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
1333+
func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
1334+
// CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
1335+
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
1336+
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
1337+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]]
1338+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
1339+
%0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
1340+
// CHECK: return %[[SELECT2]]
1341+
return %0: vector<2xf32>
1342+
}
1343+
12991344
// Check int vector types.
13001345
// CHECK-LABEL: @int_vector234
13011346
func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {

mlir/test/Conversion/ArithToSPIRV/fast-math.mlir

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,40 @@ module attributes {
3030
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
3131
} {
3232

33-
// CHECK-LABEL: @minf
33+
// CHECK-LABEL: @minimumf
3434
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
35-
func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
35+
func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
3636
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
3737
%0 = arith.minimumf %arg0, %arg1 : f32
3838
// CHECK: return %[[F]]
3939
return %0: f32
4040
}
4141

42-
// CHECK-LABEL: @maxf
42+
// CHECK-LABEL: @maximumf
4343
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
44-
func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
44+
func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
4545
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
4646
%0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
4747
// CHECK: return %[[F]]
4848
return %0: vector<4xf32>
4949
}
5050

51+
// CHECK-LABEL: @minnumf
52+
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
53+
func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
54+
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
55+
%0 = arith.minnumf %arg0, %arg1 : f32
56+
// CHECK: return %[[F]]
57+
return %0: f32
58+
}
59+
60+
// CHECK-LABEL: @maxnumf
61+
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
62+
func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
63+
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
64+
%0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
65+
// CHECK: return %[[F]]
66+
return %0: vector<4xf32>
67+
}
68+
5169
} // end module

0 commit comments

Comments
 (0)