diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bd4150c87eabb..94c1c52a28462 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11114,7 +11114,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
-/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
+/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
 /// which corresponds to it.
 static unsigned getVecReduceOpcode(unsigned Opc) {
   switch (Opc) {
@@ -11136,6 +11136,9 @@ static unsigned getVecReduceOpcode(unsigned Opc) {
     return ISD::VECREDUCE_OR;
   case ISD::XOR:
     return ISD::VECREDUCE_XOR;
+  case ISD::FADD:
+    // Note: This is the associative form of the generic reduction opcode.
+    return ISD::VECREDUCE_FADD;
   }
 }
 
@@ -11162,12 +11165,16 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
 
   const SDLoc DL(N);
   const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
 
-  // TODO: Handle floating point here.
-  if (!VT.isInteger())
+  // For FADD, we only handle the case with reassociation allowed. We
+  // could handle strict reduction order, but at the moment, there's no
+  // known reason to, and the complexity isn't worth it.
+  // TODO: Handle fminnum and fmaxnum here
+  if (!VT.isInteger() &&
+      (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
     return SDValue();
 
-  const unsigned Opc = N->getOpcode();
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
@@ -11200,7 +11207,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                               DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec);
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11222,7 +11229,9 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
       EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
       SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                                 DAG.getVectorIdxConstant(0, DL));
-      return DAG.getNode(ReduceOpc, DL, VT, Vec);
+      auto Flags = ReduceVec->getFlags();
+      Flags.intersectWith(N->getFlags());
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
     }
   }
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index dd9a1118ab821..76df097a76971 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -764,6 +764,165 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
   %umin3 = call i32 @llvm.umin.i32(i32 %umin2, i32 %e4)
   ret i32 %umin3
 }
+
+define float @reduce_fadd_16xf32_prefix2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %fadd0 = fadd fast float %e0, %e1
+  ret float %fadd0
+}
+
+define float @reduce_fadd_16xi32_prefix5(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xi32_prefix5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, 524288
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v10, a1
+; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 5
+; CHECK-NEXT:    vsetivli zero, 7, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 6
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 7
+; CHECK-NEXT:    vfredusum.vs v8, v8, v10
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %e2 = extractelement <16 x float> %v, i32 2
+  %e3 = extractelement <16 x float> %v, i32 3
+  %e4 = extractelement <16 x float> %v, i32 4
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  %fadd3 = fadd fast float %fadd2, %e4
+  ret float %fadd3
+}
+
+;; Corner case tests for fadd associativity
+
+; Negative test, not associative. Would need strict opcode.
+define float @reduce_fadd_2xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd float %e0, %e1
+  ret float %fadd0
+}
+
+; Positive test - minimal set of fast math flags
+define float @reduce_fadd_2xf32_reassoc_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_reassoc_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd reassoc float %e0, %e1
+  ret float %fadd0
+}
+
+; Negative test - wrong fast math flag.
+define float @reduce_fadd_2xf32_ninf_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_ninf_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd ninf float %e0, %e1
+  ret float %fadd0
+}
+
+
+; Negative test - last fadd is not associative
+define float @reduce_fadd_4xi32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vslidedown.vi v9, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa5, v9
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vslideup.vi v8, v9, 3
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa4, fa5
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd float %fadd1, %e3
+  ret float %fadd2
+}
+
+; Negative test - first fadd is not associative
+; We could form a reduce for elements 2 and 3.
+define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v9, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v9
+; CHECK-NEXT:    vslidedown.vi v9, v8, 2
+; CHECK-NEXT:    vfmv.f.s fa3, v9
+; CHECK-NEXT:    vslidedown.vi v8, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa2, v8
+; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fadd.s fa4, fa3, fa2
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  ret float %fadd2
+}
+
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; RV32: {{.*}}
 ; RV64: {{.*}}
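
Background note (not part of the patch): ISD::VECREDUCE_FADD is the reassociating form of the floating-point reduction, which is why the combine requires the reassoc (AllowReassociation) flag and why the formed reductions select the unordered vfredusum rather than the ordered vfredosum. When the combine widens an existing reduce node, the flags on the new reduce are the intersection of the inner reduce's flags and the outer fadd's flags, so one non-reassoc fadd in the chain stops further growth. A minimal IR sketch of the unordered/ordered distinction follows; the function names are illustrative only, not part of the test file.

declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)

; With reassoc, the elements may be summed in any order, matching
; ISD::VECREDUCE_FADD (unordered; vfredusum on RISC-V).
define float @sketch_unordered_reduce(<4 x float> %v) {
  %r = call reassoc float @llvm.vector.reduce.fadd.v4f32(float -0.0, <4 x float> %v)
  ret float %r
}

; Without reassoc, the in-order semantics correspond to the sequential
; form (ISD::VECREDUCE_SEQ_FADD; vfredosum on RISC-V), which this
; combine deliberately does not try to form.
define float @sketch_ordered_reduce(<4 x float> %v) {
  %r = call float @llvm.vector.reduce.fadd.v4f32(float -0.0, <4 x float> %v)
  ret float %r
}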