diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index f2ca33c581433..ba24143e0b5b6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -711,6 +711,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe { R->getVPDefID() == VPRecipeBase::VPWidenGEPSC || R->getVPDefID() == VPRecipeBase::VPWidenCastSC || R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC || + R->getVPDefID() == VPRecipeBase::VPReductionSC || + R->getVPDefID() == VPRecipeBase::VPReductionEVLSC || R->getVPDefID() == VPRecipeBase::VPReplicateSC || R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC || R->getVPDefID() == VPRecipeBase::VPVectorPointerSC; @@ -2236,7 +2238,7 @@ class VPInterleaveRecipe : public VPRecipeBase { /// A recipe to represent inloop reduction operations, performing a reduction on /// a vector operand into a scalar value, and adding the result to a chain. /// The Operands are {ChainOp, VecOp, [Condition]}. -class VPReductionRecipe : public VPSingleDefRecipe { +class VPReductionRecipe : public VPRecipeWithIRFlags { /// The recurrence decriptor for the reduction in question. const RecurrenceDescriptor &RdxDesc; bool IsOrdered; @@ -2247,12 +2249,17 @@ class VPReductionRecipe : public VPSingleDefRecipe { VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R, Instruction *I, ArrayRef Operands, VPValue *CondOp, bool IsOrdered, DebugLoc DL) - : VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R), - IsOrdered(IsOrdered) { + : VPRecipeWithIRFlags(SC, Operands, + isa_and_nonnull(I) + ? R.getFastMathFlags() + : FastMathFlags(), + DL), + RdxDesc(R), IsOrdered(IsOrdered) { if (CondOp) { IsConditional = true; addOperand(CondOp); } + setUnderlyingValue(I); } public: @@ -2318,12 +2325,13 @@ class VPReductionRecipe : public VPSingleDefRecipe { /// The Operands are {ChainOp, VecOp, EVL, [Condition]}. class VPReductionEVLRecipe : public VPReductionRecipe { public: - VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp) + VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp, + DebugLoc DL = {}) : VPReductionRecipe( VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(), cast_or_null(R.getUnderlyingValue()), ArrayRef({R.getChainOp(), R.getVecOp(), &EVL}), CondOp, - R.isOrdered(), R.getDebugLoc()) {} + R.isOrdered(), DL) {} ~VPReductionEVLRecipe() override = default; diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 1b6894376f73b..d315dbe9b4170 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2290,7 +2290,7 @@ void VPReductionRecipe::execute(VPTransformState &State) { "In-loop AnyOf reductions aren't currently supported"); // Propagate the fast-math flags carried by the underlying instruction. IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); - State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags()); + State.Builder.setFastMathFlags(getFastMathFlags()); State.setDebugLocFrom(getDebugLoc()); Value *NewVecOp = State.get(getVecOp()); if (VPValue *Cond = getCondOp()) { @@ -2337,7 +2337,7 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) { // Propagate the fast-math flags carried by the underlying instruction. IRBuilderBase::FastMathFlagGuard FMFGuard(Builder); const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor(); - Builder.setFastMathFlags(RdxDesc.getFastMathFlags()); + Builder.setFastMathFlags(getFastMathFlags()); RecurKind Kind = RdxDesc.getRecurrenceKind(); Value *Prev = State.get(getChainOp(), /*IsScalar*/ true); @@ -2374,6 +2374,7 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF, Type *ElementTy = Ctx.Types.inferScalarType(this); auto *VectorTy = cast(toVectorTy(ElementTy, VF)); unsigned Opcode = RdxDesc.getOpcode(); + FastMathFlags FMFs = getFastMathFlags(); // TODO: Support any-of and in-loop reductions. assert( @@ -2393,12 +2394,12 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF, Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind); if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) { Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind); - return Cost + Ctx.TTI.getMinMaxReductionCost( - Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind); + return Cost + + Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind); } - return Cost + Ctx.TTI.getArithmeticReductionCost( - Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind); + return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs, + Ctx.CostKind); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -2409,8 +2410,7 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, O << " = "; getChainOp()->printAsOperand(O, SlotTracker); O << " +"; - if (isa(getUnderlyingInstr())) - O << getUnderlyingInstr()->getFastMathFlags(); + printFlags(O); O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " ("; getVecOp()->printAsOperand(O, SlotTracker); if (isConditional()) { @@ -2431,8 +2431,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent, O << " = "; getChainOp()->printAsOperand(O, SlotTracker); O << " +"; - if (isa(getUnderlyingInstr())) - O << getUnderlyingInstr()->getFastMathFlags(); + printFlags(O); O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " ("; getVecOp()->printAsOperand(O, SlotTracker); O << ", "; diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp index f9a85869e3142..ca1e48290f25b 100644 --- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp @@ -1165,22 +1165,27 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) { } { + auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32), + PoisonValue::get(Int32)); VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1)); VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2)); VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3)); - VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp, + VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp, false); EXPECT_FALSE(Recipe.mayHaveSideEffects()); EXPECT_FALSE(Recipe.mayReadFromMemory()); EXPECT_FALSE(Recipe.mayWriteToMemory()); EXPECT_FALSE(Recipe.mayReadOrWriteMemory()); + delete Add; } { + auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32), + PoisonValue::get(Int32)); VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1)); VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2)); VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3)); - VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp, + VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp, false); VPValue *EVL = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 4)); VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp); @@ -1188,6 +1193,7 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) { EXPECT_FALSE(EVLRecipe.mayReadFromMemory()); EXPECT_FALSE(EVLRecipe.mayWriteToMemory()); EXPECT_FALSE(EVLRecipe.mayReadOrWriteMemory()); + delete Add; } { @@ -1529,28 +1535,34 @@ TEST_F(VPRecipeTest, dumpRecipeUnnamedVPValuesNotInPlanOrBlock) { TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) { IntegerType *Int32 = IntegerType::get(C, 32); + auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32), + PoisonValue::get(Int32)); VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1)); VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2)); VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3)); - VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp, - VecOp, false); + VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp, + false); EXPECT_TRUE(isa(&Recipe)); VPRecipeBase *BaseR = &Recipe; EXPECT_TRUE(isa(BaseR)); + delete Add; } TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) { IntegerType *Int32 = IntegerType::get(C, 32); + auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32), + PoisonValue::get(Int32)); VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1)); VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2)); VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3)); - VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp, - VecOp, false); + VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp, VecOp, + false); VPValue *EVL = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 0)); VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp); EXPECT_TRUE(isa(&EVLRecipe)); VPRecipeBase *BaseR = &EVLRecipe; EXPECT_TRUE(isa(BaseR)); + delete Add; } } // namespace