Skip to content

Commit 8dc23ef

Browse files
[NFC][AArch64][SVE] Rename variables in partial reduction lowering functions (#120589)
1 parent 385b144 commit 8dc23ef

File tree

1 file changed

+55
-60
lines changed

1 file changed

+55
-60
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 55 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21952,74 +21952,72 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2195221952

2195321953
SDLoc DL(N);
2195421954

21955-
// The narrower of the two operands. Used as the accumulator
21956-
auto NarrowOp = N->getOperand(1);
21957-
auto MulOp = N->getOperand(2);
21958-
if (MulOp->getOpcode() != ISD::MUL)
21955+
SDValue Op2 = N->getOperand(2);
21956+
if (Op2->getOpcode() != ISD::MUL ||
21957+
!ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
21958+
!ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
2195921959
return SDValue();
2196021960

21961-
auto ExtA = MulOp->getOperand(0);
21962-
auto ExtB = MulOp->getOperand(1);
21963-
21964-
if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
21965-
!ISD::isExtOpcode(ExtB->getOpcode()))
21966-
return SDValue();
21967-
bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
21968-
bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
21961+
SDValue Acc = N->getOperand(1);
21962+
SDValue Mul = N->getOperand(2);
21963+
SDValue ExtMulOpLHS = Mul->getOperand(0);
21964+
SDValue ExtMulOpRHS = Mul->getOperand(1);
2196921965

21970-
auto A = ExtA->getOperand(0);
21971-
auto B = ExtB->getOperand(0);
21972-
if (A.getValueType() != B.getValueType())
21966+
SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
21967+
SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
21968+
if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
2197321969
return SDValue();
2197421970

21975-
EVT ReducedType = N->getValueType(0);
21976-
EVT MulSrcType = A.getValueType();
21971+
EVT ReducedVT = N->getValueType(0);
21972+
EVT MulSrcVT = MulOpLHS.getValueType();
2197721973

2197821974
// Dot products operate on chunks of four elements so there must be four times
2197921975
// as many elements in the wide type
21980-
if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
21981-
!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
21982-
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
21983-
!(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
21984-
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
21985-
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
21976+
if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
21977+
!(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
21978+
!(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
21979+
!(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
21980+
!(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
21981+
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
2198621982
return SDValue();
2198721983

21984+
bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
21985+
bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
2198821986
// If the extensions are mixed, we should lower it to a usdot instead
2198921987
unsigned Opcode = 0;
21990-
if (AIsSigned != BIsSigned) {
21988+
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
2199121989
if (!Subtarget->hasMatMulInt8())
2199221990
return SDValue();
2199321991

2199421992
bool Scalable = N->getValueType(0).isScalableVT();
2199521993
// There's no nxv2i64 version of usdot
21996-
if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
21994+
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
2199721995
return SDValue();
2199821996

2199921997
Opcode = AArch64ISD::USDOT;
2200021998
// USDOT expects the signed operand to be last
22001-
if (!BIsSigned)
22002-
std::swap(A, B);
22003-
} else if (AIsSigned)
21999+
if (!MulOpRHSIsSigned)
22000+
std::swap(MulOpLHS, MulOpRHS);
22001+
} else if (MulOpLHSIsSigned)
2200422002
Opcode = AArch64ISD::SDOT;
2200522003
else
2200622004
Opcode = AArch64ISD::UDOT;
2200722005

2200822006
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
2200922007
// product followed by a zero / sign extension
22010-
if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
22011-
(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
22012-
EVT ReducedTypeI32 =
22013-
(ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22008+
if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22009+
(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22010+
EVT ReducedVTI32 =
22011+
(ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
2201422012

22015-
auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
22016-
DAG.getConstant(0, DL, ReducedTypeI32), A, B);
22017-
auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
22018-
return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
22019-
Extended);
22013+
SDValue DotI32 =
22014+
DAG.getNode(Opcode, DL, ReducedVTI32,
22015+
DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
22016+
SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
22017+
return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
2202022018
}
2202122019

22022-
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
22020+
return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
2202322021
}
2202422022

2202522023
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -22036,32 +22034,29 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
2203622034

2203722035
SDLoc DL(N);
2203822036

22039-
auto Acc = N->getOperand(1);
22040-
auto ExtInput = N->getOperand(2);
22041-
22042-
EVT AccVT = Acc.getValueType();
22043-
EVT AccElemVT = AccVT.getVectorElementType();
22044-
22045-
if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
22037+
if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
2204622038
return SDValue();
22047-
22048-
unsigned ExtInputOpcode = ExtInput->getOpcode();
22049-
if (!ISD::isExtOpcode(ExtInputOpcode))
22039+
SDValue Acc = N->getOperand(1);
22040+
SDValue Ext = N->getOperand(2);
22041+
EVT AccVT = Acc.getValueType();
22042+
EVT ExtVT = Ext.getValueType();
22043+
if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
2205022044
return SDValue();
2205122045

22052-
auto Input = ExtInput->getOperand(0);
22053-
EVT InputVT = Input.getValueType();
22046+
SDValue ExtOp = Ext->getOperand(0);
22047+
EVT ExtOpVT = ExtOp.getValueType();
2205422048

22055-
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22056-
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22057-
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22049+
if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22050+
!(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22051+
!(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
2205822052
return SDValue();
2205922053

22060-
bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
22061-
auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22062-
auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22063-
auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
22064-
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
22054+
bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
22055+
unsigned BottomOpcode =
22056+
ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22057+
unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22058+
SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
22059+
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
2206522060
}
2206622061

2206722062
static SDValue performIntrinsicCombine(SDNode *N,
@@ -22073,9 +22068,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
2207322068
default:
2207422069
break;
2207522070
case Intrinsic::experimental_vector_partial_reduce_add: {
22076-
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
22071+
if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
2207722072
return Dot;
22078-
if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
22073+
if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
2207922074
return WideAdd;
2208022075
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
2208122076
N->getOperand(1), N->getOperand(2));

0 commit comments

Comments
 (0)