diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2fef4993f6ec..812bb26f201a0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13316,12 +13316,15 @@ namespace {
 // apply a combine.
 struct CombineResult;
 
+enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
+
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
 /// add | add_vl -> vwadd(u) | vwadd(u)_w
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
-///
+/// fadd -> vfwadd | vfwadd_w
+/// fsub -> vfwsub | vfwsub_w
+/// fmul -> vfwmul
 /// An object of this class represents an operand of the operation we want to
 /// combine.
 /// E.g., when trying to combine `mul_vl a, b`, we will have one instance of
@@ -13335,7 +13338,8 @@ struct CombineResult;
 /// - VWADDU_W == add(op0, zext(op1))
 /// - VWSUB_W == sub(op0, sext(op1))
 /// - VWSUBU_W == sub(op0, zext(op1))
-///
+/// - VFWADD_W == fadd(op0, fpext(op1))
+/// - VFWSUB_W == fsub(op0, fpext(op1))
 /// And VMV_V_X_VL, depending on the value, is conceptually equivalent to
 /// zext|sext(smaller_value).
 struct NodeExtensionHelper {
@@ -13346,6 +13350,8 @@ struct NodeExtensionHelper {
   /// instance, a splat constant (e.g., 3), would support being both sign and
   /// zero extended.
   bool SupportsSExt;
+  /// Records if this operand is like being floating-point extended.
+  bool SupportsFPExt;
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
@@ -13369,6 +13375,7 @@ struct NodeExtensionHelper {
     case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
     default:
       return OrigOperand;
@@ -13380,22 +13387,34 @@ struct NodeExtensionHelper {
     return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
   }
 
+  /// Get the extended opcode.
+  unsigned getExtOpc(ExtKind SupportsExt) const {
+    switch (SupportsExt) {
+    case ExtKind::SExt:
+      return RISCVISD::VSEXT_VL;
+    case ExtKind::ZExt:
+      return RISCVISD::VZEXT_VL;
+    case ExtKind::FPExt:
+      return RISCVISD::FP_EXTEND_VL;
+    }
+  }
+
   /// Get or create a value that can feed \p Root with the given extension \p
-  /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
-  /// \see ::getSource().
+  /// SupportsExt. If \p SupportsExt is std::nullopt, this returns the source
+  /// of this operand. \see ::getSource().
   SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
                                 const RISCVSubtarget &Subtarget,
-                                std::optional<bool> SExt) const {
-    if (!SExt.has_value())
+                                std::optional<ExtKind> SupportsExt) const {
+    if (!SupportsExt.has_value())
       return OrigOperand;
 
-    MVT NarrowVT = getNarrowType(Root);
+    MVT NarrowVT = getNarrowType(Root, *SupportsExt);
 
     SDValue Source = getSource();
     if (Source.getValueType() == NarrowVT)
       return Source;
 
-    unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+    unsigned ExtOpc = getExtOpc(*SupportsExt);
 
     // If we need an extension, we should be changing the type.
     SDLoc DL(OrigOperand);
@@ -13405,6 +13424,7 @@ struct NodeExtensionHelper {
     case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
    case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
@@ -13420,41 +13440,79 @@ struct NodeExtensionHelper {
   /// Helper function to get the narrow type for \p Root.
   /// The narrow type is the type of \p Root where we divided the size of each
   /// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
-  /// \pre The size of the type of the elements of Root must be a multiple of 2
-  /// and be greater than 16.
-  static MVT getNarrowType(const SDNode *Root) {
+  /// \pre Both the narrow type and the original type should be legal.
+  static MVT getNarrowType(const SDNode *Root, ExtKind SupportsExt) {
     MVT VT = Root->getSimpleValueType(0);
 
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-    assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
-    MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
-                                    VT.getVectorElementCount());
+
+    MVT EltVT = SupportsExt == ExtKind::FPExt
+                    ? MVT::getFloatingPointVT(NarrowSize)
+                    : MVT::getIntegerVT(NarrowSize);
+
+    assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
+           "Trying to extend something we can't represent");
+    MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
     return NarrowVT;
   }
 
-  /// Return the opcode required to materialize the folding of the sign
-  /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
-  /// both operands for \p Opcode.
-  /// Put differently, get the opcode to materialize:
-  /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
-  /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
-  /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
-  static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+  /// Get the opcode to materialize:
+  /// Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+  static unsigned getSExtOpcode(unsigned Opcode) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
-      return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+      return RISCVISD::VWADD_VL;
+    case ISD::SUB:
+    case RISCVISD::SUB_VL:
+    case RISCVISD::VWSUB_W_VL:
+    case RISCVISD::VWSUBU_W_VL:
+      return RISCVISD::VWSUB_VL;
     case ISD::MUL:
     case RISCVISD::MUL_VL:
-      return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+      return RISCVISD::VWMUL_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
+
+  /// Get the opcode to materialize:
+  /// Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+  static unsigned getZExtOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case ISD::ADD:
+    case RISCVISD::ADD_VL:
+    case RISCVISD::VWADD_W_VL:
+    case RISCVISD::VWADDU_W_VL:
+      return RISCVISD::VWADDU_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+      return RISCVISD::VWSUBU_VL;
+    case ISD::MUL:
+    case RISCVISD::MUL_VL:
+      return RISCVISD::VWMULU_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
+
+  /// Get the opcode to materialize:
+  /// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
+  static unsigned getFPExtOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case RISCVISD::FADD_VL:
+    case RISCVISD::VFWADD_W_VL:
+      return RISCVISD::VFWADD_VL;
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
+      return RISCVISD::VFWSUB_VL;
+    case RISCVISD::FMUL_VL:
+      return RISCVISD::VFWMUL_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
@@ -13468,16 +13526,22 @@ struct NodeExtensionHelper {
     return RISCVISD::VWMULSU_VL;
   }
 
-  /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
-  /// newOpcode(a, b).
-  static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+  /// Get the opcode to materialize
+  /// \p Opcode(a, s|z|fpext(b)) -> newOpcode(a, b).
+  static unsigned getWOpcode(unsigned Opcode, ExtKind SupportsExt) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
-      return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
+                                          : RISCVISD::VWADDU_W_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
-      return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
+                                          : RISCVISD::VWSUBU_W_VL;
+    case RISCVISD::FADD_VL:
+      return RISCVISD::VFWADD_W_VL;
+    case RISCVISD::FSUB_VL:
+      return RISCVISD::VFWSUB_W_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
@@ -13497,6 +13561,7 @@ struct NodeExtensionHelper {
                               const RISCVSubtarget &Subtarget) {
     SupportsZExt = false;
     SupportsSExt = false;
+    SupportsFPExt = false;
    EnforceOneUse = true;
     CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
@@ -13538,6 +13603,11 @@ struct NodeExtensionHelper {
       Mask = OrigOperand.getOperand(1);
       VL = OrigOperand.getOperand(2);
       break;
+    case RISCVISD::FP_EXTEND_VL:
+      SupportsFPExt = true;
+      Mask = OrigOperand.getOperand(1);
+      VL = OrigOperand.getOperand(2);
+      break;
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
@@ -13584,15 +13654,16 @@ struct NodeExtensionHelper {
   /// Check if \p Root supports any extension folding combines.
   static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL: {
-      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
       if (!TLI.isTypeLegal(Root->getValueType(0)))
         return false;
       return Root->getValueType(0).isScalableVector();
     }
+    // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13600,7 +13671,13 @@ struct NodeExtensionHelper {
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return true;
+    // Vector Widening Floating-Point Add/Sub/Mul Instructions
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
+      return TLI.isTypeLegal(Root->getValueType(0));
     default:
       return false;
     }
@@ -13616,16 +13693,23 @@ struct NodeExtensionHelper {
 
     unsigned Opc = Root->getOpcode();
     switch (Opc) {
-    // We consider VW<ADD|SUB>(U)_W(LHS, RHS) as if they were
-    // <ADD|SUB>(LHS, S|ZEXT(RHS))
+    // We consider
+    // VW<ADD|SUB>_W(LHS, RHS) -> <ADD|SUB>(LHS, SEXT(RHS))
+    // VW<ADD|SUB>U_W(LHS, RHS) -> <ADD|SUB>(LHS, ZEXT(RHS))
+    // VFW<ADD|SUB>_W(LHS, RHS) -> F<ADD|SUB>(LHS, FPEXT(RHS))
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
       if (OperandIdx == 1) {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
-        SupportsSExt = !SupportsZExt;
+        SupportsSExt =
+            Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+        SupportsFPExt =
+            Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
         std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
@@ -13685,11 +13769,16 @@ struct NodeExtensionHelper {
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
       return true;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
      return false;
     default:
       llvm_unreachable("Unexpected opcode");
@@ -13711,10 +13800,9 @@ struct NodeExtensionHelper {
 struct CombineResult {
   /// Opcode to be generated when materializing the combine.
   unsigned TargetOpcode;
-  // No value means no extension is needed. If extension is needed, the value
-  // indicates if it needs to be sign extended.
-  std::optional<bool> SExtLHS;
-  std::optional<bool> SExtRHS;
+  // No value means no extension is needed.
+  std::optional<ExtKind> LHSExt;
+  std::optional<ExtKind> RHSExt;
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
@@ -13723,10 +13811,10 @@ struct CombineResult {
   NodeExtensionHelper RHS;
 
   CombineResult(unsigned TargetOpcode, SDNode *Root,
-                const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
-                const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
-      : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
-        Root(Root), LHS(LHS), RHS(RHS) {}
+                const NodeExtensionHelper &LHS, std::optional<ExtKind> LHSExt,
+                const NodeExtensionHelper &RHS, std::optional<ExtKind> RHSExt)
+      : TargetOpcode(TargetOpcode), LHSExt(LHSExt), RHSExt(RHSExt), Root(Root),
+        LHS(LHS), RHS(RHS) {}
 
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
@@ -13747,8 +13835,8 @@ struct CombineResult {
       break;
     }
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, LHSExt),
+                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, RHSExt),
                        Merge, Mask, VL);
   }
 };
@@ -13756,7 +13844,7 @@ struct CombineResult {
 /// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
 /// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
 /// are zext) and LHS and RHS can be folded into Root.
-/// AllowSExt and AllozZExt define which form `ext` can take in this pattern.
+/// AllowExtMask defines which forms `ext` can take in this pattern.
 ///
 /// \note If the pattern can match with both zext and sext, the returned
 /// CombineResult will feature the zext result.
@@ -13765,22 +13853,24 @@ struct CombineResult {
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
-                                 const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt, SelectionDAG &DAG,
+                                 const NodeExtensionHelper &RHS,
+                                 uint8_t AllowExtMask, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
-  assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
-  if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
-  if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/true),
-                         Root, LHS, /*SExtLHS=*/true, RHS,
-                         /*SExtRHS=*/true);
+  if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
+    return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
+                         Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
+                         /*RHSExt=*/{ExtKind::ZExt});
+  if ((AllowExtMask & ExtKind::SExt) && LHS.SupportsSExt && RHS.SupportsSExt)
+    return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
+                         Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+                         /*RHSExt=*/{ExtKind::SExt});
+  if ((AllowExtMask & ExtKind::FPExt) && LHS.SupportsFPExt && RHS.SupportsFPExt)
+    return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
+                         Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
+                         /*RHSExt=*/{ExtKind::FPExt});
   return std::nullopt;
 }
 
@@ -13794,8 +13884,9 @@ static std::optional<CombineResult>
 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
                              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                              const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/true, DAG, Subtarget);
+  return canFoldToVWWithSameExtensionImpl(
+      Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
+      Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13809,18 +13900,23 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
   if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
 
+  if (RHS.SupportsFPExt)
+    return CombineResult(
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
+        Root, LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::FPExt});
+
   // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
   // sext/zext?
   // Control this behavior behind an option (AllowSplatInVW_W) for testing
   // purposes.
   if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
     return CombineResult(
-        NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
-        Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false);
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::ZExt), Root,
+        LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::ZExt});
   if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
     return CombineResult(
-        NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
-        Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true);
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::SExt), Root,
+        LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::SExt});
   return std::nullopt;
 }
 
@@ -13832,8 +13928,8 @@ static std::optional<CombineResult>
 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/false, DAG, Subtarget);
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
+                                          Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13844,8 +13940,20 @@ static std::optional<CombineResult>
 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/true, DAG, Subtarget);
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
+                                          Subtarget);
+}
+
+/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                     const RISCVSubtarget &Subtarget) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
+                                          Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13863,7 +13971,8 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
-                       Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
+                       Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+                       /*RHSExt=*/{ExtKind::ZExt});
 }
 
 SmallVector<NodeExtensionHelper::CombineToTry>
@@ -13874,11 +13983,16 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case ISD::ADD:
   case ISD::SUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
-    // add|sub -> vwadd(u)|vwsub(u)
+  case RISCVISD::FADD_VL:
+  case RISCVISD::FSUB_VL:
+    // add|sub|fadd|fsub -> vwadd(u)|vwsub(u)|vfwadd|vfwsub
     Strategies.push_back(canFoldToVWWithSameExtension);
-    // add|sub -> vwadd(u)_w|vwsub(u)_w
+    // add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w|vfwadd_w|vfwsub_w
     Strategies.push_back(canFoldToVW_W);
     break;
+  case RISCVISD::FMUL_VL:
+    Strategies.push_back(canFoldToVWWithSameExtension);
+    break;
   case ISD::MUL:
   case RISCVISD::MUL_VL:
     // mul -> vwmul(u)
@@ -13896,6 +14010,11 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // vwaddu_w|vwsubu_w -> vwaddu|vwsubu
     Strategies.push_back(canFoldToVWWithZEXT);
     break;
+  case RISCVISD::VFWADD_W_VL:
+  case RISCVISD::VFWSUB_W_VL:
+    // vfwadd_w|vfwsub_w -> vfwadd|vfwsub
+    Strategies.push_back(canFoldToVWWithFPEXT);
+    break;
   default:
     llvm_unreachable("Unexpected opcode");
   }
@@ -13908,8 +14027,13 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// add_vl -> vwadd(u) | vwadd(u)_w
 /// sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul_vl -> vwmul(u) | vwmul_su
+/// fadd_vl -> vfwadd | vfwadd_w
+/// fsub_vl -> vfwsub | vfwsub_w
+/// fmul_vl -> vfwmul
 /// vwadd_w(u) -> vwadd(u)
-/// vwub_w(u) -> vwadd(u)
+/// vwsub_w(u) -> vwsub(u)
+/// vfwadd_w -> vfwadd
+/// vfwsub_w -> vfwsub
 static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
                                            TargetLowering::DAGCombinerInfo &DCI,
                                            const RISCVSubtarget &Subtarget) {
@@ -13965,9 +14089,9 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
         // All the inputs that are extended need to be folded, otherwise
         // we would be leaving the old input (since it is may still be used),
         // and the new one.
-        if (Res->SExtLHS.has_value())
+        if (Res->LHSExt.has_value())
          AppendUsersIfNeeded(LHS);
-        if (Res->SExtRHS.has_value())
+        if (Res->RHSExt.has_value())
           AppendUsersIfNeeded(RHS);
         break;
       }
@@ -14532,107 +14656,6 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
                      N->getOperand(2), Mask, VL);
 }
 
-static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG,
-                                      const RISCVSubtarget &Subtarget) {
-  if (N->getValueType(0).isScalableVector() &&
-      N->getValueType(0).getVectorElementType() == MVT::f32 &&
-      (Subtarget.hasVInstructionsF16Minimal() &&
-       !Subtarget.hasVInstructionsF16())) {
-    return SDValue();
-  }
-
-  // FIXME: Ignore strict opcodes for now.
-  assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode");
-
-  // Try to form widening multiply.
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  SDValue Merge = N->getOperand(2);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
-
-  if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
-      Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
-    return SDValue();
-
-  // TODO: Refactor to handle more complex cases similar to
-  // combineBinOp_VLToVWBinOp_VL.
-  if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
-      (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
-    return SDValue();
-
-  // Check the mask and VL are the same.
-  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
-      Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
-    return SDValue();
-
-  Op0 = Op0.getOperand(0);
-  Op1 = Op1.getOperand(0);
-
-  return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0,
-                     Op1, Merge, Mask, VL);
-}
-
-static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG,
-                                        const RISCVSubtarget &Subtarget) {
-  if (N->getValueType(0).isScalableVector() &&
-      N->getValueType(0).getVectorElementType() == MVT::f32 &&
-      (Subtarget.hasVInstructionsF16Minimal() &&
-       !Subtarget.hasVInstructionsF16())) {
-    return SDValue();
-  }
-
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  SDValue Merge = N->getOperand(2);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
-
-  bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL;
-
-  // Look for foldable FP_EXTENDS.
-  bool Op0IsExtend =
-      Op0.getOpcode() == RISCVISD::FP_EXTEND_VL &&
-      (Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0)));
-  bool Op1IsExtend =
-      (Op0 == Op1 && Op0IsExtend) ||
-      (Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse());
-
-  // Check the mask and VL.
-  if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL))
-    Op0IsExtend = false;
-  if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL))
-    Op1IsExtend = false;
-
-  // Canonicalize.
-  if (!Op1IsExtend) {
-    // Sub requires at least operand 1 to be an extend.
-    if (!IsAdd)
-      return SDValue();
-
-    // Add is commutable, if the other operand is foldable, swap them.
-    if (!Op0IsExtend)
-      return SDValue();
-
-    std::swap(Op0, Op1);
-    std::swap(Op0IsExtend, Op1IsExtend);
-  }
-
-  // Op1 is a foldable extend. Op0 might be foldable.
-  Op1 = Op1.getOperand(0);
-  if (Op0IsExtend)
-    Op0 = Op0.getOperand(0);
-
-  unsigned Opc;
-  if (IsAdd)
-    Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL;
-  else
-    Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL;
-
-  return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask,
-                     VL);
-}
-
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -16165,11 +16188,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::STRICT_VFMSUB_VL:
   case RISCVISD::STRICT_VFNMSUB_VL:
     return performVFMADD_VLCombine(N, DAG, Subtarget);
-  case RISCVISD::FMUL_VL:
-    return performVFMUL_VLCombine(N, DAG, Subtarget);
   case RISCVISD::FADD_VL:
   case RISCVISD::FSUB_VL:
-    return performFADDSUB_VLCombine(N, DAG, Subtarget);
+  case RISCVISD::FMUL_VL:
+  case RISCVISD::VFWADD_W_VL:
+  case RISCVISD::VFWSUB_W_VL: {
+    if (N->getValueType(0).isScalableVector() &&
+        N->getValueType(0).getVectorElementType() == MVT::f32 &&
+        (Subtarget.hasVInstructionsF16Minimal() &&
+         !Subtarget.hasVInstructionsF16()))
+      return SDValue();
+    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+  }
   case ISD::LOAD:
   case ISD::STORE: {
     if (DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
new file mode 100644
index 0000000000000..26f77225dbb0e
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f116_multiple_users(ptr %x, ptr %y, ptr %z, <2 x half> %a, <2 x half> %b, <2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f116_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vse32.v v10, (a0)
+; NO_FOLDING-NEXT:    vse32.v v11, (a1)
+; NO_FOLDING-NEXT:    vse32.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; ZVFHMIN-NEXT:    vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT:    vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT:    vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT:    vse32.v v10, (a0)
+; ZVFHMIN-NEXT:    vse32.v v11, (a1)
+; ZVFHMIN-NEXT:    vse32.v v8, (a2)
+; ZVFHMIN-NEXT:    ret
+  %c = fpext <2 x half> %a to <2 x float>
+  %d = fpext <2 x half> %b to <2 x float>
+  %d2 = fpext <2 x half> %b2 to <2 x float>
+  %e = fmul <2 x float> %c, %d
+  %f = fadd <2 x float> %c, %d2
+  %g = fsub <2 x float> %d, %d2
+  store <2 x float> %e, ptr %x
+  store <2 x float> %f, ptr %y
+  store <2 x float> %g, ptr %z
+  ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vse64.v v10, (a0)
+; NO_FOLDING-NEXT:    vse64.v v11, (a1)
+; NO_FOLDING-NEXT:    vse64.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT:    vfwmul.vv v11, v8, v9
+; FOLDING-NEXT:    vfwadd.vv v12, v8, v10
+; FOLDING-NEXT:    vfwsub.vv v8, v9, v10
+; FOLDING-NEXT:    vse64.v v11, (a0)
+; FOLDING-NEXT:    vse64.v v12, (a1)
+; FOLDING-NEXT:    vse64.v v8, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <2 x float> %a to <2 x double>
+  %d = fpext <2 x float> %b to <2 x double>
+  %d2 = fpext <2 x float> %b2 to <2 x double>
+  %e = fmul <2 x double> %c, %d
+  %f = fadd <2 x double> %c, %d2
+  %g = fsub <2 x double> %d, %d2
+  store <2 x double> %e, ptr %x
+  store <2 x double> %f, ptr %y
+  store <2 x double> %g, ptr %z
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
index c9dc75e18774f..dd3a50cfd7737 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
@@ -396,12 +396,10 @@ define <32 x double> @vfwadd_vf_v32f32(ptr %x, float %y) {
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
 ; CHECK-NEXT:    vle32.v v24, (a0)
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT:    vslidedown.vi v0, v24, 16
+; CHECK-NEXT:    vslidedown.vi v8, v24, 16
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfmv.v.f v16, fa0
-; CHECK-NEXT:    vfwcvt.f.f.v v8, v16
-; CHECK-NEXT:    vfwadd.wv v16, v8, v0
-; CHECK-NEXT:    vfwadd.wv v8, v8, v24
+; CHECK-NEXT:    vfwadd.vf v16, v8, fa0
+; CHECK-NEXT:    vfwadd.vf v8, v24, fa0
 ; CHECK-NEXT:    ret
   %a = load <32 x float>, ptr %x
   %b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
index 8ad858d4c7659..7eaa1856ce221 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwmul_vf_v32f32(ptr %x, float %y) {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    li a1, 32
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT:    vle32.v v16, (a0)
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v8, v16
+; CHECK-NEXT:    vle32.v v24, (a0)
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT:    vslidedown.vi v16, v16, 16
+; CHECK-NEXT:    vslidedown.vi v8, v24, 16
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v24, v16
-; CHECK-NEXT:    vfmv.v.f v16, fa0
-; CHECK-NEXT:    vfwcvt.f.f.v v0, v16
-; CHECK-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT:    vfmul.vv v16, v24, v0
-; CHECK-NEXT:    vfmul.vv v8, v8, v0
+; CHECK-NEXT:    vfwmul.vf v16, v8, fa0
+; CHECK-NEXT:    vfwmul.vf v8, v24, fa0
 ; CHECK-NEXT:    ret
   %a = load <32 x float>, ptr %x
   %b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
index d22781d6a97ac..8cf7c5f175865 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwsub_vf_v32f32(ptr %x, float %y) {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    li a1, 32
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT:    vle32.v v16, (a0)
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v8, v16
+; CHECK-NEXT:    vle32.v v24, (a0)
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT:    vslidedown.vi v16, v16, 16
+; CHECK-NEXT:    vslidedown.vi v8, v24, 16
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v24, v16
-; CHECK-NEXT:    vfmv.v.f v16, fa0
-; CHECK-NEXT:    vfwcvt.f.f.v v0, v16
-; CHECK-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT:    vfsub.vv v16, v24, v0
-; CHECK-NEXT:    vfsub.vv v8, v8, v0
+; CHECK-NEXT:    vfwsub.vf v16, v8, fa0
+; CHECK-NEXT:    vfwsub.vf v8, v24, fa0
 ; CHECK-NEXT:    ret
   %a = load <32 x float>, ptr %x
   %b = insertelement <32 x float> poison, float %y, i32 0