@@ -21952,74 +21952,72 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
21952
21952
21953
21953
SDLoc DL(N);
21954
21954
21955
- // The narrower of the two operands. Used as the accumulator
21956
- auto NarrowOp = N->getOperand(1);
21957
- auto MulOp = N ->getOperand(2);
21958
- if (MulOp-> getOpcode() != ISD::MUL )
21955
+ SDValue Op2 = N->getOperand(2);
21956
+ if (Op2->getOpcode() != ISD::MUL ||
21957
+ !ISD::isExtOpcode(Op2 ->getOperand(0)->getOpcode()) ||
21958
+ !ISD::isExtOpcode(Op2->getOperand(1)-> getOpcode()) )
21959
21959
return SDValue();
21960
21960
21961
- auto ExtA = MulOp->getOperand(0);
21962
- auto ExtB = MulOp->getOperand(1);
21963
-
21964
- if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
21965
- !ISD::isExtOpcode(ExtB->getOpcode()))
21966
- return SDValue();
21967
- bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
21968
- bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
21961
+ SDValue Acc = N->getOperand(1);
21962
+ SDValue Mul = N->getOperand(2);
21963
+ SDValue ExtMulOpLHS = Mul->getOperand(0);
21964
+ SDValue ExtMulOpRHS = Mul->getOperand(1);
21969
21965
21970
- auto A = ExtA ->getOperand(0);
21971
- auto B = ExtB ->getOperand(0);
21972
- if (A .getValueType() != B .getValueType())
21966
+ SDValue MulOpLHS = ExtMulOpLHS ->getOperand(0);
21967
+ SDValue MulOpRHS = ExtMulOpRHS ->getOperand(0);
21968
+ if (MulOpLHS .getValueType() != MulOpRHS .getValueType())
21973
21969
return SDValue();
21974
21970
21975
- EVT ReducedType = N->getValueType(0);
21976
- EVT MulSrcType = A .getValueType();
21971
+ EVT ReducedVT = N->getValueType(0);
21972
+ EVT MulSrcVT = MulOpLHS .getValueType();
21977
21973
21978
21974
// Dot products operate on chunks of four elements so there must be four times
21979
21975
// as many elements in the multiply source type as in the reduced result type.
21980
- if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
21981
- !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
21982
- !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
21983
- !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
21984
- !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
21985
- !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
21976
+ if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
21977
+ !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
21978
+ !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
21979
+ !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
21980
+ !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
21981
+ !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
21986
21982
return SDValue();
21987
21983
21984
+ bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
21985
+ bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
21988
21986
// If the extensions are mixed, we should lower it to a usdot instead
21989
21987
unsigned Opcode = 0;
21990
- if (AIsSigned != BIsSigned ) {
21988
+ if (MulOpLHSIsSigned != MulOpRHSIsSigned ) {
21991
21989
if (!Subtarget->hasMatMulInt8())
21992
21990
return SDValue();
21993
21991
21994
21992
bool Scalable = N->getValueType(0).isScalableVT();
21995
21993
// There's no nxv2i64 version of usdot
21996
- if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
21994
+ if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
21997
21995
return SDValue();
21998
21996
21999
21997
Opcode = AArch64ISD::USDOT;
22000
21998
// USDOT expects the signed operand to be last
22001
- if (!BIsSigned )
22002
- std::swap(A, B );
22003
- } else if (AIsSigned )
21999
+ if (!MulOpRHSIsSigned )
22000
+ std::swap(MulOpLHS, MulOpRHS );
22001
+ } else if (MulOpLHSIsSigned )
22004
22002
Opcode = AArch64ISD::SDOT;
22005
22003
else
22006
22004
Opcode = AArch64ISD::UDOT;
22007
22005
22008
22006
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
22009
22007
// product followed by a zero / sign extension
22010
- if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
22011
- (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
22012
- EVT ReducedTypeI32 =
22013
- (ReducedType .isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22008
+ if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22009
+ (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22010
+ EVT ReducedVTI32 =
22011
+ (ReducedVT .isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22014
22012
22015
- auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
22016
- DAG.getConstant(0 , DL, ReducedTypeI32), A, B);
22017
- auto Extended = DAG.getSExtOrTrunc(DotI32 , DL, ReducedType );
22018
- return DAG.getNode(ISD::ADD , DL, NarrowOp.getValueType(), NarrowOp,
22019
- Extended);
22013
+ SDValue DotI32 =
22014
+ DAG.getNode(Opcode , DL, ReducedVTI32,
22015
+ DAG.getConstant(0 , DL, ReducedVTI32), MulOpLHS, MulOpRHS );
22016
+ SDValue Extended = DAG.getSExtOrTrunc(DotI32 , DL, ReducedVT);
22017
+ return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
22020
22018
}
22021
22019
22022
- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B );
22020
+ return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS );
22023
22021
}
22024
22022
22025
22023
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -22036,32 +22034,29 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
22036
22034
22037
22035
SDLoc DL(N);
22038
22036
22039
- auto Acc = N->getOperand(1);
22040
- auto ExtInput = N->getOperand(2);
22041
-
22042
- EVT AccVT = Acc.getValueType();
22043
- EVT AccElemVT = AccVT.getVectorElementType();
22044
-
22045
- if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
22037
+ if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
22046
22038
return SDValue();
22047
-
22048
- unsigned ExtInputOpcode = ExtInput->getOpcode();
22049
- if (!ISD::isExtOpcode(ExtInputOpcode))
22039
+ SDValue Acc = N->getOperand(1);
22040
+ SDValue Ext = N->getOperand(2);
22041
+ EVT AccVT = Acc.getValueType();
22042
+ EVT ExtVT = Ext.getValueType();
22043
+ if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
22050
22044
return SDValue();
22051
22045
22052
- auto Input = ExtInput ->getOperand(0);
22053
- EVT InputVT = Input .getValueType();
22046
+ SDValue ExtOp = Ext ->getOperand(0);
22047
+ EVT ExtOpVT = ExtOp .getValueType();
22054
22048
22055
- if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22056
- !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22057
- !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22049
+ if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22050
+ !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22051
+ !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22058
22052
return SDValue();
22059
22053
22060
- bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
22061
- auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22062
- auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22063
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
22064
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
22054
+ bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
22055
+ unsigned BottomOpcode =
22056
+ ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22057
+ unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22058
+ SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
22059
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
22065
22060
}
22066
22061
22067
22062
static SDValue performIntrinsicCombine(SDNode *N,
@@ -22073,9 +22068,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
22073
22068
default:
22074
22069
break;
22075
22070
case Intrinsic::experimental_vector_partial_reduce_add: {
22076
- if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
22071
+ if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
22077
22072
return Dot;
22078
- if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
22073
+ if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
22079
22074
return WideAdd;
22080
22075
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
22081
22076
N->getOperand(1), N->getOperand(2));
0 commit comments