@@ -573,35 +573,31 @@ static Value createIntegerReductionComparisonOpLowering(
573
573
return result;
574
574
}
575
575
576
- // / Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
577
- // / with vector types.
578
- static Value createMinMaxF (OpBuilder &builder, Location loc, Value lhs,
579
- Value rhs, bool isMin) {
580
- auto floatType = cast<FloatType>(getElementTypeOrSelf (lhs.getType ()));
581
- Type i1Type = builder.getI1Type ();
582
- if (auto vecType = dyn_cast<VectorType>(lhs.getType ()))
583
- i1Type = VectorType::get (vecType.getShape (), i1Type);
584
- Value cmp = builder.create <LLVM::FCmpOp>(
585
- loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
586
- lhs, rhs);
587
- Value sel = builder.create <LLVM::SelectOp>(loc, cmp, lhs, rhs);
588
- Value isNan = builder.create <LLVM::FCmpOp>(
589
- loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
590
- Value nan = builder.create <LLVM::ConstantOp>(
591
- loc, lhs.getType (),
592
- builder.getFloatAttr (floatType,
593
- APFloat::getQNaN (floatType.getFloatSemantics ())));
594
- return builder.create <LLVM::SelectOp>(loc, isNan, nan , sel);
595
- }
576
+ namespace {
577
+ template <typename Source>
578
+ struct VectorToScalarMapper ;
579
+ template <>
580
+ struct VectorToScalarMapper <LLVM::vector_reduce_fmaximum> {
581
+ using Type = LLVM::MaximumOp;
582
+ };
583
+ template <>
584
+ struct VectorToScalarMapper <LLVM::vector_reduce_fminimum> {
585
+ using Type = LLVM::MinimumOp;
586
+ };
587
+ } // namespace
596
588
597
589
template <class LLVMRedIntrinOp >
598
- static Value createFPReductionComparisonOpLowering (
599
- ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
600
- Value vectorOperand, Value accumulator, bool isMin) {
590
+ static Value
591
+ createFPReductionComparisonOpLowering (ConversionPatternRewriter &rewriter,
592
+ Location loc, Type llvmType,
593
+ Value vectorOperand, Value accumulator) {
601
594
Value result = rewriter.create <LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
602
595
603
- if (accumulator)
604
- result = createMinMaxF (rewriter, loc, result, accumulator, /* isMin=*/ isMin);
596
+ if (accumulator) {
597
+ result =
598
+ rewriter.create <typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
599
+ loc, result, accumulator);
600
+ }
605
601
606
602
return result;
607
603
}
@@ -774,17 +770,13 @@ class VectorReductionOpConversion
774
770
ReductionNeutralFPOne>(
775
771
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
776
772
} else if (kind == vector::CombiningKind::MINF) {
777
- // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
778
- // NaNs/-0.0/+0.0 in the same way.
779
- result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
780
- rewriter, loc, llvmType, operand, acc,
781
- /* isMin=*/ true );
773
+ result =
774
+ createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
775
+ rewriter, loc, llvmType, operand, acc);
782
776
} else if (kind == vector::CombiningKind::MAXF) {
783
- // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
784
- // NaNs/-0.0/+0.0 in the same way.
785
- result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
786
- rewriter, loc, llvmType, operand, acc,
787
- /* isMin=*/ false );
777
+ result =
778
+ createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
779
+ rewriter, loc, llvmType, operand, acc);
788
780
} else
789
781
return failure ();
790
782
0 commit comments