@@ -495,6 +495,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::INTRINSIC_VOID);
   setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
   setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
+  setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
 
   MaxStoresPerMemset = MaxStoresPerMemsetOptSize = 8;
   MaxStoresPerMemcpy = MaxStoresPerMemcpyOptSize = 4;
@@ -8584,6 +8585,102 @@ static SDValue performPostLD1Combine(SDNode *N,
   return SDValue();
 }
 
+/// Target-specific DAG combine for the across vector reduction.
+/// This function specifically handles the final clean-up step of a vector
+/// reduction produced by the LoopVectorizer. It is the log2-shuffle pattern,
+/// consisting of log2(NumVectorElements) steps and, in each step, 2^(s)
+/// elements are reduced, where s is an induction variable from 0
+/// to log2(NumVectorElements).
+/// For example,
+///   %1 = vector_shuffle %0, <2,3,u,u>
+///   %2 = add %0, %1
+///   %3 = vector_shuffle %2, <1,u,u,u>
+///   %4 = add %2, %3
+///   %5 = extract_vector_elt %4, 0
+/// becomes:
+///   %0 = uaddv %0
+///   %1 = extract_vector_elt %0, 0
+///
+/// FIXME: Currently this function is implemented and tested specifically
+/// for the add reduction. We could also support other types of across lane
+/// reduction available in AArch64, including SMAXV, SMINV, UMAXV, UMINV,
+/// SADDLV, UADDLV, FMAXNMV, FMAXV, FMINNMV, FMINV.
+static SDValue
+performAcrossLaneReductionCombine(SDNode *N, SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasNEON())
+    return SDValue();
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // Check if the input vector is fed by the operator we want to handle.
+  // We specifically check only ADD for now.
+  if (N0->getOpcode() != ISD::ADD)
+    return SDValue();
+
+  // The vector extract index must be constant zero because we only expect
+  // the final result of the reduction to be placed in lane 0.
+  if (!isa<ConstantSDNode>(N1) || cast<ConstantSDNode>(N1)->getZExtValue())
+    return SDValue();
+
+  EVT EltTy = N0.getValueType().getVectorElementType();
+  if (EltTy != MVT::i32 && EltTy != MVT::i16 && EltTy != MVT::i8)
+    return SDValue();
+
+  int NumVecElts = N0.getValueType().getVectorNumElements();
+  if (NumVecElts != 4 && NumVecElts != 8 && NumVecElts != 16)
+    return SDValue();
+
+  int NumExpectedSteps = APInt(8, NumVecElts).logBase2();
+  SDValue PreOp = N0;
+  // Iterate over each step of the across vector reduction.
+  for (int CurStep = 0; CurStep != NumExpectedSteps; ++CurStep) {
+    // We specifically check ADD for now.
+    if (PreOp.getOpcode() != ISD::ADD)
+      return SDValue();
+    SDValue CurOp = PreOp.getOperand(0);
+    SDValue Shuffle = PreOp.getOperand(1);
+    if (Shuffle.getOpcode() != ISD::VECTOR_SHUFFLE) {
+      // Try to swap the 1st and 2nd operands, as add is commutative.
+      CurOp = PreOp.getOperand(1);
+      Shuffle = PreOp.getOperand(0);
+      if (Shuffle.getOpcode() != ISD::VECTOR_SHUFFLE)
+        return SDValue();
+    }
+    // Check if it forms one step of the across vector reduction.
+    // E.g.,
+    //   %cur = add %1, %0
+    //   %shuffle = vector_shuffle %cur, <2, 3, u, u>
+    //   %pre = add %cur, %shuffle
+    if (Shuffle.getOperand(0) != CurOp)
+      return SDValue();
+
+    int NumMaskElts = 1 << CurStep;
+    ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Shuffle)->getMask();
+    // Check the mask values in each step.
+    // We expect the shuffle mask in each step to follow a specific pattern,
+    // denoted here by the <M, U> form, where M is a sequence of integers
+    // starting from NumMaskElts and increasing by 1, and the number of
+    // integers in M should be NumMaskElts. U is a sequence of UNDEFs, and
+    // the number of UNDEFs in U should be NumVecElts - NumMaskElts.
+    // E.g., for <8 x i16>, the mask values in each step should be:
+    //   step 0: <1,u,u,u,u,u,u,u>
+    //   step 1: <2,3,u,u,u,u,u,u>
+    //   step 2: <4,5,6,7,u,u,u,u>
+    for (int i = 0; i < NumVecElts; ++i)
+      if ((i < NumMaskElts && Mask[i] != (NumMaskElts + i)) ||
+          (i >= NumMaskElts && !(Mask[i] < 0)))
+        return SDValue();
+
+    PreOp = CurOp;
+  }
+  SDLoc DL(N);
+  return DAG.getNode(
+      ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0),
+      DAG.getNode(AArch64ISD::UADDV, DL, PreOp.getSimpleValueType(), PreOp),
+      DAG.getConstant(0, DL, MVT::i64));
+}
+
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
 static SDValue performNEONPostLDSTCombine(SDNode *N,
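
Editorial note, not part of the patch: the pattern the new combine targets typically comes from an integer add reduction that the LoopVectorizer has vectorized, whose horizontal clean-up is the log2-shuffle sequence shown in the doc comment above. A minimal sketch of such a source loop follows, assuming auto-vectorization is enabled (e.g. -O3 targeting AArch64); the function name is purely illustrative.

// Hypothetical example, not taken from the patch or its tests.
// When the LoopVectorizer vectorizes this i32 add reduction, the final
// horizontal reduction of the vector accumulator is emitted as the
// log2-shuffle pattern matched above; with this combine it should lower
// to an across-lane add (addv) plus a lane-0 extract instead of a chain
// of vector shuffles and adds.
int sum_i32(const int *a, int n) {
  int s = 0;
  for (int i = 0; i < n; ++i)
    s += a[i];
  return s;
}
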
@@ -9178,6 +9275,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performNVCASTCombine(N);
   case ISD::INSERT_VECTOR_ELT:
     return performPostLD1Combine(N, DCI, true);
+  case ISD::EXTRACT_VECTOR_ELT:
+    return performAcrossLaneReductionCombine(N, DAG, Subtarget);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
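
To make the per-step mask requirement concrete, here is a small standalone sketch (not part of the patch, and only an illustration) that mirrors the check performed inside the loop of performAcrossLaneReductionCombine. The helper name isExpectedStepMask is made up for this example; undef mask lanes are encoded as -1, matching how shuffle masks are represented in the SelectionDAG.

#include <cstdio>
#include <vector>

// In step CurStep, the first 2^CurStep mask values must be
// 2^CurStep, 2^CurStep + 1, ..., and every remaining lane must be
// undef (negative), exactly as the combine above requires.
static bool isExpectedStepMask(const std::vector<int> &Mask, int NumVecElts,
                               int CurStep) {
  int NumMaskElts = 1 << CurStep;
  for (int i = 0; i < NumVecElts; ++i) {
    if (i < NumMaskElts && Mask[i] != NumMaskElts + i)
      return false;
    if (i >= NumMaskElts && Mask[i] >= 0)
      return false;
  }
  return true;
}

int main() {
  // <8 x i16> reduction, step 1 expects <2,3,u,u,u,u,u,u>.
  std::vector<int> Step1 = {2, 3, -1, -1, -1, -1, -1, -1};
  // <8 x i16> reduction, step 2 expects <4,5,6,7,u,u,u,u>.
  std::vector<int> Step2 = {4, 5, 6, 7, -1, -1, -1, -1};
  std::printf("%d %d\n", isExpectedStepMask(Step1, 8, 1),
              isExpectedStepMask(Step2, 8, 2)); // prints: 1 1
  return 0;
}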