Skip to content

Commit dad9de0

Browse files
committed
[mlir][vector] Improve lowering to LLVM for minf, maxf reductions
This patch improves the lowering by changing target LLVM intrinsics from `reduce.fmax` and `reduce.fmin`, which have different semantic for handling NaN, to `reduce.fmaximum` and `reduce.fminimum` ones. Fixes #63969 Depends on D155869 Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D155877
1 parent 346c1f2 commit dad9de0

File tree

2 files changed

+31
-47
lines changed

2 files changed

+31
-47
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -573,35 +573,31 @@ static Value createIntegerReductionComparisonOpLowering(
573573
return result;
574574
}
575575

576-
/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
577-
/// with vector types.
578-
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
579-
Value rhs, bool isMin) {
580-
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
581-
Type i1Type = builder.getI1Type();
582-
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
583-
i1Type = VectorType::get(vecType.getShape(), i1Type);
584-
Value cmp = builder.create<LLVM::FCmpOp>(
585-
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
586-
lhs, rhs);
587-
Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
588-
Value isNan = builder.create<LLVM::FCmpOp>(
589-
loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
590-
Value nan = builder.create<LLVM::ConstantOp>(
591-
loc, lhs.getType(),
592-
builder.getFloatAttr(floatType,
593-
APFloat::getQNaN(floatType.getFloatSemantics())));
594-
return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
595-
}
576+
namespace {
577+
template <typename Source>
578+
struct VectorToScalarMapper;
579+
template <>
580+
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
581+
using Type = LLVM::MaximumOp;
582+
};
583+
template <>
584+
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
585+
using Type = LLVM::MinimumOp;
586+
};
587+
} // namespace
596588

597589
template <class LLVMRedIntrinOp>
598-
static Value createFPReductionComparisonOpLowering(
599-
ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
600-
Value vectorOperand, Value accumulator, bool isMin) {
590+
static Value
591+
createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
592+
Location loc, Type llvmType,
593+
Value vectorOperand, Value accumulator) {
601594
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
602595

603-
if (accumulator)
604-
result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
596+
if (accumulator) {
597+
result =
598+
rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
599+
loc, result, accumulator);
600+
}
605601

606602
return result;
607603
}
@@ -774,17 +770,13 @@ class VectorReductionOpConversion
774770
ReductionNeutralFPOne>(
775771
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
776772
} else if (kind == vector::CombiningKind::MINF) {
777-
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
778-
// NaNs/-0.0/+0.0 in the same way.
779-
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
780-
rewriter, loc, llvmType, operand, acc,
781-
/*isMin=*/true);
773+
result =
774+
createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
775+
rewriter, loc, llvmType, operand, acc);
782776
} else if (kind == vector::CombiningKind::MAXF) {
783-
// FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
784-
// NaNs/-0.0/+0.0 in the same way.
785-
result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
786-
rewriter, loc, llvmType, operand, acc,
787-
/*isMin=*/false);
777+
result =
778+
createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
779+
rewriter, loc, llvmType, operand, acc);
788780
} else
789781
return failure();
790782

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,12 +1374,8 @@ func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
13741374
}
13751375
// CHECK-LABEL: @reduce_fmax_f32(
13761376
// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
1377-
// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<16xf32>) -> f32
1378-
// CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32
1379-
// CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
1380-
// CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
1381-
// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
1382-
// CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
1377+
// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmaximum(%[[A]]) : (vector<16xf32>) -> f32
1378+
// CHECK: %[[R:.*]] = llvm.intr.maximum(%[[V]], %[[B]]) : (f32, f32) -> f32
13831379
// CHECK: return %[[R]] : f32
13841380

13851381
// -----
@@ -1390,12 +1386,8 @@ func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
13901386
}
13911387
// CHECK-LABEL: @reduce_fmin_f32(
13921388
// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
1393-
// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<16xf32>) -> f32
1394-
// CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32
1395-
// CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
1396-
// CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
1397-
// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
1398-
// CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
1389+
// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fminimum(%[[A]]) : (vector<16xf32>) -> f32
1390+
// CHECK: %[[R:.*]] = llvm.intr.minimum(%[[V]], %[[B]]) : (f32, f32) -> f32
13991391
// CHECK: return %[[R]] : f32
14001392

14011393
// -----

0 commit comments

Comments
 (0)