@@ -8149,15 +8149,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8149
8149
// something that isn't another partial reduction. This is because the
8150
8150
// extends are intended to be lowered along with the reduction itself.
8151
8151
8152
- // Build up a set of partial reduction bin ops for efficient use checking.
8153
- SmallSet<User *, 4 > PartialReductionBinOps ;
8152
+ // Build up a set of partial reduction ops for efficient use checking.
8153
+ SmallSet<User *, 4 > PartialReductionOps ;
8154
8154
for (const auto &[PartialRdx, _] : PartialReductionChains)
8155
- PartialReductionBinOps .insert (PartialRdx.BinOp );
8155
+ PartialReductionOps .insert (PartialRdx.ExtendUser );
8156
8156
8157
8157
auto ExtendIsOnlyUsedByPartialReductions =
8158
- [&PartialReductionBinOps ](Instruction *Extend) {
8158
+ [&PartialReductionOps ](Instruction *Extend) {
8159
8159
return all_of (Extend->users (), [&](const User *U) {
8160
- return PartialReductionBinOps .contains (U);
8160
+ return PartialReductionOps .contains (U);
8161
8161
});
8162
8162
};
8163
8163
@@ -8166,15 +8166,14 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8166
8166
for (auto Pair : PartialReductionChains) {
8167
8167
PartialReductionChain Chain = Pair.first ;
8168
8168
if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8169
- ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8169
+ (!Chain. ExtendB || ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ) ))
8170
8170
ScaledReductionMap.try_emplace (Chain.Reduction , Pair.second );
8171
8171
}
8172
8172
}
8173
8173
8174
8174
bool VPRecipeBuilder::getScaledReductions (
8175
8175
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8176
8176
SmallVectorImpl<std::pair<PartialReductionChain, unsigned >> &Chains) {
8177
-
8178
8177
if (!CM.TheLoop ->contains (RdxExitInstr))
8179
8178
return false ;
8180
8179
@@ -8203,43 +8202,71 @@ bool VPRecipeBuilder::getScaledReductions(
8203
8202
if (PhiOp != PHI)
8204
8203
return false ;
8205
8204
8206
- auto *BinOp = dyn_cast<BinaryOperator>(Op);
8207
- if (!BinOp || !BinOp->hasOneUse ())
8208
- return false ;
8209
-
8210
8205
using namespace llvm ::PatternMatch;
8211
- // Use the side-effect of match to replace BinOp only if the pattern is
8212
- // matched, we don't care at this point whether it actually matched.
8213
- match (BinOp, m_Neg (m_BinOp (BinOp)));
8214
8206
8215
- Value *A, *B;
8216
- if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8217
- !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8218
- return false ;
8207
+ // If the update is a binary operator, check both of its operands to see if
8208
+ // they are extends. Otherwise, see if the update comes directly from an
8209
+ // extend.
8210
+ Instruction *Exts[2 ] = {nullptr };
8211
+ BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
8212
+ std::optional<unsigned > BinOpc;
8213
+ Type *ExtOpTypes[2 ] = {nullptr };
8214
+
8215
+ auto CollectExtInfo = [&Exts,
8216
+ &ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
8217
+ unsigned I = 0 ;
8218
+ for (Value *OpI : Ops) {
8219
+ Value *ExtOp;
8220
+ if (!match (OpI, m_ZExtOrSExt (m_Value (ExtOp))))
8221
+ return false ;
8222
+ Exts[I] = cast<Instruction>(OpI);
8223
+ ExtOpTypes[I] = ExtOp->getType ();
8224
+ I++;
8225
+ }
8226
+ return true ;
8227
+ };
8228
+
8229
+ if (ExtendUser) {
8230
+ if (!ExtendUser->hasOneUse ())
8231
+ return false ;
8219
8232
8220
- Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8221
- Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8233
+ // Use the side-effect of match to replace BinOp only if the pattern is
8234
+ // matched, we don't care at this point whether it actually matched.
8235
+ match (ExtendUser, m_Neg (m_BinOp (ExtendUser)));
8236
+
8237
+ SmallVector<Value *> Ops (ExtendUser->operands ());
8238
+ if (!CollectExtInfo (Ops))
8239
+ return false ;
8240
+
8241
+ BinOpc = std::make_optional (ExtendUser->getOpcode ());
8242
+ } else if (match (Update, m_Add (m_Value (), m_Value ()))) {
8243
+ // We already know the operands for Update are Op and PhiOp.
8244
+ SmallVector<Value *> Ops ({Op});
8245
+ if (!CollectExtInfo (Ops))
8246
+ return false ;
8247
+
8248
+ ExtendUser = Update;
8249
+ BinOpc = std::nullopt;
8250
+ } else
8251
+ return false ;
8222
8252
8223
8253
TTI::PartialReductionExtendKind OpAExtend =
8224
- TargetTransformInfo ::getPartialReductionExtendKind (ExtA );
8254
+ TTI ::getPartialReductionExtendKind (Exts[ 0 ] );
8225
8255
TTI::PartialReductionExtendKind OpBExtend =
8226
- TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8227
-
8228
- PartialReductionChain Chain (RdxExitInstr, ExtA, ExtB, BinOp);
8256
+ Exts[1 ] ? TTI::getPartialReductionExtendKind (Exts[1 ]) : TTI::PR_None;
8257
+ PartialReductionChain Chain (RdxExitInstr, Exts[0 ], Exts[1 ], ExtendUser);
8229
8258
8230
8259
TypeSize PHISize = PHI->getType ()->getPrimitiveSizeInBits ();
8231
- TypeSize ASize = A->getType ()->getPrimitiveSizeInBits ();
8232
-
8260
+ TypeSize ASize = ExtOpTypes[0 ]->getPrimitiveSizeInBits ();
8233
8261
if (!PHISize.hasKnownScalarFactor (ASize))
8234
8262
return false ;
8235
-
8236
8263
unsigned TargetScaleFactor = PHISize.getKnownScalarFactor (ASize);
8237
8264
8238
8265
if (LoopVectorizationPlanner::getDecisionAndClampRange (
8239
8266
[&](ElementCount VF) {
8240
8267
InstructionCost Cost = TTI->getPartialReductionCost (
8241
- Update->getOpcode (), A-> getType (), B-> getType (), PHI-> getType () ,
8242
- VF, OpAExtend, OpBExtend, BinOp-> getOpcode () , CM.CostKind );
8268
+ Update->getOpcode (), ExtOpTypes[ 0 ], ExtOpTypes[ 1 ] ,
8269
+ PHI-> getType (), VF, OpAExtend, OpBExtend, BinOpc , CM.CostKind );
8243
8270
return Cost.isValid ();
8244
8271
},
8245
8272
Range)) {
0 commit comments