-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[VPlan] Implement VPExtendedReduction, VPMulAccumulateReductionRecipe and corresponding vplan transformations. #137746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ElvisWang123
merged 10 commits into
llvm:main
from
ElvisWang123:add-mulacc-transform-NFC
May 16, 2025
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
3769e1f
[VPlan] Implement transformation for widen-cast/widen-mul + reduction…
ElvisWang123 a4077bc
Fixup, Address comments.
ElvisWang123 ce95f18
!fixup, Remove `computeCost()` for new recipes.
ElvisWang123 d267411
!fixup, address comments.
ElvisWang123 06ef087
!fixup, fix assertion of getResultType().
ElvisWang123 34a6f3b
!fixup getResultType() in VPMulAccumulateReductionRecipe.
ElvisWang123 bfc5fc2
Fixup! Remove IterT and always add result type in VPMulAccumulateRedu…
ElvisWang123 a0515c3
!fixup update VPMulAccumulateReduction::clone().
ElvisWang123 fca5a28
Address comments.
ElvisWang123 a1bf71c
Merge branch 'main' into add-mulacc-transform-NFC
ElvisWang123 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -517,6 +517,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue { | |
case VPRecipeBase::VPInstructionSC: | ||
case VPRecipeBase::VPReductionEVLSC: | ||
case VPRecipeBase::VPReductionSC: | ||
case VPRecipeBase::VPMulAccumulateReductionSC: | ||
case VPRecipeBase::VPExtendedReductionSC: | ||
case VPRecipeBase::VPReplicateSC: | ||
case VPRecipeBase::VPScalarIVStepsSC: | ||
case VPRecipeBase::VPVectorPointerSC: | ||
|
@@ -601,13 +603,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe { | |
DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {} | ||
}; | ||
|
||
struct NonNegFlagsTy { | ||
char NonNeg : 1; | ||
NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {} | ||
}; | ||
|
||
private: | ||
struct ExactFlagsTy { | ||
char IsExact : 1; | ||
}; | ||
struct NonNegFlagsTy { | ||
char NonNeg : 1; | ||
}; | ||
struct FastMathFlagsTy { | ||
char AllowReassoc : 1; | ||
char NoNaNs : 1; | ||
|
@@ -697,6 +701,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe { | |
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp), | ||
DisjointFlags(DisjointFlags) {} | ||
|
||
template <typename IterT> | ||
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, | ||
NonNegFlagsTy NonNegFlags, DebugLoc DL = {}) | ||
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp), | ||
NonNegFlags(NonNegFlags) {} | ||
|
||
protected: | ||
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands, | ||
GEPNoWrapFlags GEPFlags, DebugLoc DL = {}) | ||
|
@@ -715,7 +725,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe { | |
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC || | ||
R->getVPDefID() == VPRecipeBase::VPReplicateSC || | ||
R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC || | ||
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC; | ||
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC || | ||
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC || | ||
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC; | ||
} | ||
|
||
static inline bool classof(const VPUser *U) { | ||
|
@@ -812,6 +824,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe { | |
|
||
FastMathFlags getFastMathFlags() const; | ||
|
||
/// Returns true if the recipe has non-negative flag. | ||
bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; } | ||
|
||
bool isNonNeg() const { | ||
assert(OpType == OperationType::NonNegOp && | ||
"recipe doesn't have a NNEG flag"); | ||
return NonNegFlags.NonNeg; | ||
} | ||
|
||
bool hasNoUnsignedWrap() const { | ||
assert(OpType == OperationType::OverflowingBinOp && | ||
"recipe doesn't have a NUW flag"); | ||
|
@@ -1294,10 +1315,19 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata { | |
: VPRecipeWithIRFlags(VPDefOpcode, Operands, I), VPIRMetadata(I), | ||
Opcode(I.getOpcode()) {} | ||
|
||
VPWidenRecipe(unsigned VPDefOpcode, unsigned Opcode, | ||
ArrayRef<VPValue *> Operands, bool NUW, bool NSW, DebugLoc DL) | ||
: VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL), | ||
Opcode(Opcode) {} | ||
|
||
public: | ||
VPWidenRecipe(Instruction &I, ArrayRef<VPValue *> Operands) | ||
: VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {} | ||
|
||
VPWidenRecipe(unsigned Opcode, ArrayRef<VPValue *> Operands, bool NUW, | ||
bool NSW, DebugLoc DL) | ||
: VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {} | ||
|
||
~VPWidenRecipe() override = default; | ||
|
||
VPWidenRecipe *clone() override { | ||
|
@@ -1342,8 +1372,15 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata { | |
"opcode of underlying cast doesn't match"); | ||
} | ||
|
||
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy) | ||
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPIRMetadata(), | ||
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy, | ||
DebugLoc DL = {}) | ||
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(), | ||
Opcode(Opcode), ResultTy(ResultTy) {} | ||
|
||
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy, | ||
bool IsNonNeg, DebugLoc DL = {}) | ||
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg), | ||
DL), | ||
Opcode(Opcode), ResultTy(ResultTy) {} | ||
|
||
~VPWidenCastRecipe() override = default; | ||
|
@@ -2394,6 +2431,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags { | |
setUnderlyingValue(I); | ||
} | ||
|
||
/// For VPExtendedReductionRecipe. | ||
/// Note that the debug location is from the extend. | ||
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind, | ||
ArrayRef<VPValue *> Operands, VPValue *CondOp, | ||
bool IsOrdered, DebugLoc DL) | ||
: VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind), | ||
IsOrdered(IsOrdered), IsConditional(CondOp) { | ||
if (CondOp) | ||
addOperand(CondOp); | ||
} | ||
|
||
/// For VPMulAccumulateReductionRecipe. | ||
/// Note that the NUW/NSW flags and the debug location are from the Mul. | ||
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind, | ||
ArrayRef<VPValue *> Operands, VPValue *CondOp, | ||
bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL) | ||
: VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind), | ||
IsOrdered(IsOrdered), IsConditional(CondOp) { | ||
if (CondOp) | ||
addOperand(CondOp); | ||
} | ||
|
||
public: | ||
VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I, | ||
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp, | ||
|
@@ -2402,6 +2461,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags { | |
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp, | ||
IsOrdered, DL) {} | ||
|
||
VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs, | ||
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp, | ||
bool IsOrdered, DebugLoc DL = {}) | ||
: VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr, | ||
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp, | ||
IsOrdered, DL) {} | ||
|
||
~VPReductionRecipe() override = default; | ||
|
||
VPReductionRecipe *clone() override { | ||
|
@@ -2412,7 +2478,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags { | |
|
||
static inline bool classof(const VPRecipeBase *R) { | ||
return R->getVPDefID() == VPRecipeBase::VPReductionSC || | ||
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC; | ||
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC || | ||
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC || | ||
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC; | ||
} | ||
|
||
static inline bool classof(const VPUser *U) { | ||
|
@@ -2551,6 +2619,182 @@ class VPReductionEVLRecipe : public VPReductionRecipe { | |
} | ||
}; | ||
|
||
/// A recipe to represent inloop extended reduction operations, performing a | ||
/// reduction on a extended vector operand into a scalar value, and adding the | ||
/// result to a chain. This recipe is abstract and needs to be lowered to | ||
/// concrete recipes before codegen. The operands are {ChainOp, VecOp, | ||
/// [Condition]}. | ||
class VPExtendedReductionRecipe : public VPReductionRecipe { | ||
/// Opcode of the extend for VecOp. | ||
Instruction::CastOps ExtOp; | ||
|
||
/// The scalar type after extending. | ||
Type *ResultTy; | ||
|
||
/// For cloning VPExtendedReductionRecipe. | ||
VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed) | ||
: VPReductionRecipe( | ||
VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(), | ||
{ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(), | ||
ExtRed->isOrdered(), ExtRed->getDebugLoc()), | ||
ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) { | ||
ElvisWang123 marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note: underlying value not set here, but that should be fine as cost is computed before cloning at the moment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added, thanks! |
||
transferFlags(*ExtRed); | ||
setUnderlyingValue(ExtRed->getUnderlyingValue()); | ||
} | ||
|
||
public: | ||
VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext) | ||
: VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(), | ||
{R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(), | ||
R->isOrdered(), Ext->getDebugLoc()), | ||
ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) { | ||
assert((ExtOp == Instruction::CastOps::ZExt || | ||
ExtOp == Instruction::CastOps::SExt) && | ||
"VPExtendedReductionRecipe only supports zext and sext."); | ||
|
||
transferFlags(*Ext); | ||
setUnderlyingValue(R->getUnderlyingValue()); | ||
} | ||
|
||
~VPExtendedReductionRecipe() override = default; | ||
|
||
VPExtendedReductionRecipe *clone() override { | ||
return new VPExtendedReductionRecipe(this); | ||
} | ||
|
||
VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC); | ||
|
||
void execute(VPTransformState &State) override { | ||
llvm_unreachable("VPExtendedReductionRecipe should be transform to " | ||
"VPExtendedRecipe + VPReductionRecipe before execution."); | ||
}; | ||
|
||
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) | ||
/// Print the recipe. | ||
void print(raw_ostream &O, const Twine &Indent, | ||
VPSlotTracker &SlotTracker) const override; | ||
#endif | ||
|
||
/// The scalar type after extending. | ||
Type *getResultType() const { return ResultTy; } | ||
|
||
/// Is the extend ZExt? | ||
bool isZExt() const { return getExtOpcode() == Instruction::ZExt; } | ||
|
||
/// Get the opcode of the extend for VecOp. | ||
Instruction::CastOps getExtOpcode() const { return ExtOp; } | ||
}; | ||
|
||
/// A recipe to represent inloop MulAccumulateReduction operations, multiplying | ||
/// the vector operands (which may be extended), performing a reduction.add on | ||
/// the result, and adding the scalar result to a chain. This recipe is abstract | ||
/// and needs to be lowered to concrete recipes before codegen. The operands are | ||
/// {ChainOp, VecOp1, VecOp2, [Condition]}. | ||
class VPMulAccumulateReductionRecipe : public VPReductionRecipe { | ||
/// Opcode of the extend for VecOp1 and VecOp2. | ||
Instruction::CastOps ExtOp; | ||
|
||
/// Non-neg flag of the extend recipe. | ||
bool IsNonNeg = false; | ||
|
||
/// The scalar type after extending. | ||
Type *ResultTy = nullptr; | ||
|
||
/// For cloning VPMulAccumulateReductionRecipe. | ||
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc) | ||
: VPReductionRecipe( | ||
VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(), | ||
{MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()}, | ||
MulAcc->getCondOp(), MulAcc->isOrdered(), | ||
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()), | ||
MulAcc->getDebugLoc()), | ||
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()), | ||
ResultTy(MulAcc->getResultType()) { | ||
transferFlags(*MulAcc); | ||
setUnderlyingValue(MulAcc->getUnderlyingValue()); | ||
} | ||
|
||
public: | ||
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul, | ||
VPWidenCastRecipe *Ext0, | ||
VPWidenCastRecipe *Ext1, Type *ResultTy) | ||
: VPReductionRecipe( | ||
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(), | ||
{R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)}, | ||
R->getCondOp(), R->isOrdered(), | ||
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()), | ||
R->getDebugLoc()), | ||
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) { | ||
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) == | ||
Instruction::Add && | ||
"The reduction instruction in MulAccumulateteReductionRecipe must " | ||
"be Add"); | ||
assert((ExtOp == Instruction::CastOps::ZExt || | ||
ExtOp == Instruction::CastOps::SExt) && | ||
"VPMulAccumulateReductionRecipe only supports zext and sext."); | ||
setUnderlyingValue(R->getUnderlyingValue()); | ||
// Only set the non-negative flag if the original recipe contains. | ||
if (Ext0->hasNonNegFlag()) | ||
IsNonNeg = Ext0->isNonNeg(); | ||
} | ||
|
||
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul, | ||
Type *ResultTy) | ||
: VPReductionRecipe( | ||
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(), | ||
{R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)}, | ||
R->getCondOp(), R->isOrdered(), | ||
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()), | ||
R->getDebugLoc()), | ||
ExtOp(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) { | ||
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) == | ||
Instruction::Add && | ||
"The reduction instruction in MulAccumulateReductionRecipe must be " | ||
"Add"); | ||
setUnderlyingValue(R->getUnderlyingValue()); | ||
} | ||
|
||
~VPMulAccumulateReductionRecipe() override = default; | ||
|
||
VPMulAccumulateReductionRecipe *clone() override { | ||
return new VPMulAccumulateReductionRecipe(this); | ||
} | ||
|
||
VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC); | ||
|
||
void execute(VPTransformState &State) override { | ||
llvm_unreachable("VPMulAccumulateReductionRecipe should transform to " | ||
"VPWidenCastRecipe + " | ||
"VPWidenRecipe + VPReductionRecipe before execution"); | ||
} | ||
|
||
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) | ||
/// Print the recipe. | ||
void print(raw_ostream &O, const Twine &Indent, | ||
VPSlotTracker &SlotTracker) const override; | ||
#endif | ||
|
||
Type *getResultType() const { return ResultTy; } | ||
|
||
/// The first vector value to be extended and reduced. | ||
VPValue *getVecOp0() const { return getOperand(1); } | ||
|
||
/// The second vector value to be extended and reduced. | ||
VPValue *getVecOp1() const { return getOperand(2); } | ||
|
||
/// Return true if this recipe contains extended operands. | ||
bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; } | ||
|
||
/// Return the opcode of the extends for the operands. | ||
Instruction::CastOps getExtOpcode() const { return ExtOp; } | ||
|
||
/// Return if the operands are zero-extended. | ||
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; } | ||
|
||
/// Return true if the operand extends have the non-negative flag. | ||
bool isNonNeg() const { return IsNonNeg; } | ||
}; | ||
|
||
/// VPReplicateRecipe replicates a given instruction producing multiple scalar | ||
/// copies of the original scalar type, one per lane, instead of producing a | ||
/// single copy of widened type for all lanes. If the instruction is known to be | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.