Skip to content

[NFC][AArch64][SVE] Rename variables in partial reduction lowering functions #120589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 20, 2024

Conversation

JamesChesterman
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Dec 19, 2024

@llvm/pr-subscribers-backend-aarch64

Author: James Chesterman (JamesChesterman)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/120589.diff

1 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+50-52)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..290f349c77809f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21739,73 +21739,71 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   SDLoc DL(N);
 
   // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = N->getOperand(1);
-  auto MulOp = N->getOperand(2);
-  if (MulOp->getOpcode() != ISD::MUL)
+  auto A = N->getOperand(1);
+  auto B = N->getOperand(2);
+  if (B->getOpcode() != ISD::MUL)
     return SDValue();
 
-  auto ExtA = MulOp->getOperand(0);
-  auto ExtB = MulOp->getOperand(1);
+  auto ExtMulOp1 = B->getOperand(0);
+  auto ExtMulOp2 = B->getOperand(1);
 
-  if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
-      !ISD::isExtOpcode(ExtB->getOpcode()))
+  if (!ISD::isExtOpcode(ExtMulOp1->getOpcode()) ||
+      !ISD::isExtOpcode(ExtMulOp2->getOpcode()))
     return SDValue();
-  bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+  bool MulOp1IsSigned = ExtMulOp1->getOpcode() == ISD::SIGN_EXTEND;
+  bool MulOp2IsSigned = ExtMulOp2->getOpcode() == ISD::SIGN_EXTEND;
 
-  auto A = ExtA->getOperand(0);
-  auto B = ExtB->getOperand(0);
-  if (A.getValueType() != B.getValueType())
+  auto MulOp1 = ExtMulOp1->getOperand(0);
+  auto MulOp2 = ExtMulOp2->getOperand(0);
+  if (MulOp1.getValueType() != MulOp2.getValueType())
     return SDValue();
 
-  EVT ReducedType = N->getValueType(0);
-  EVT MulSrcType = A.getValueType();
+  EVT AVT = N->getValueType(0);
+  EVT MulSrcVT = MulOp1.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
-      !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
-      !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
-      !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
-      !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
-      !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+  if (!(AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+      !(AVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+      !(AVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+      !(AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+      !(AVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+      !(AVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
     return SDValue();
 
   // If the extensions are mixed, we should lower it to a usdot instead
   unsigned Opcode = 0;
-  if (AIsSigned != BIsSigned) {
+  if (MulOp1IsSigned != MulOp2IsSigned) {
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
     bool Scalable = N->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
-    if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
+    if (Scalable && AVT != MVT::nxv4i32 && AVT != MVT::nxv4i64)
       return SDValue();
 
     Opcode = AArch64ISD::USDOT;
     // USDOT expects the signed operand to be last
-    if (!BIsSigned)
-      std::swap(A, B);
-  } else if (AIsSigned)
+    if (!MulOp2IsSigned)
+      std::swap(MulOp1, MulOp2);
+  } else if (MulOp1IsSigned)
     Opcode = AArch64ISD::SDOT;
   else
     Opcode = AArch64ISD::UDOT;
 
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
   // product followed by a zero / sign extension
-  if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
-      (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
-    EVT ReducedTypeI32 =
-        (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+  if ((AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+      (AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+    EVT AVTI32 = (AVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
 
-    auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
-                              DAG.getConstant(0, DL, ReducedTypeI32), A, B);
-    auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
-    return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
-                       Extended);
+    auto DotI32 = DAG.getNode(Opcode, DL, AVTI32,
+                              DAG.getConstant(0, DL, AVTI32), MulOp1, MulOp2);
+    auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AVT);
+    return DAG.getNode(ISD::ADD, DL, A.getValueType(), A, Extended);
   }
 
-  return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+  return DAG.getNode(Opcode, DL, AVT, A, MulOp1, MulOp2);
 }
 
 SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -21822,32 +21820,32 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
 
   SDLoc DL(N);
 
-  auto Acc = N->getOperand(1);
-  auto ExtInput = N->getOperand(2);
+  auto A = N->getOperand(1);
+  auto ExtB = N->getOperand(2);
 
-  EVT AccVT = Acc.getValueType();
-  EVT AccElemVT = AccVT.getVectorElementType();
+  EVT AVT = A.getValueType();
+  EVT AElemVT = AVT.getVectorElementType();
 
-  if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
+  if (ExtB.getValueType().getVectorElementType() != AElemVT)
     return SDValue();
 
-  unsigned ExtInputOpcode = ExtInput->getOpcode();
-  if (!ISD::isExtOpcode(ExtInputOpcode))
+  unsigned ExtBOpcode = ExtB->getOpcode();
+  if (!ISD::isExtOpcode(ExtBOpcode))
     return SDValue();
 
-  auto Input = ExtInput->getOperand(0);
-  EVT InputVT = Input.getValueType();
+  auto B = ExtB->getOperand(0);
+  EVT BVT = B.getValueType();
 
-  if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
-      !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
-      !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+  if (!(BVT == MVT::nxv4i32 && AVT == MVT::nxv2i64) &&
+      !(BVT == MVT::nxv8i16 && AVT == MVT::nxv4i32) &&
+      !(BVT == MVT::nxv16i8 && AVT == MVT::nxv8i16))
     return SDValue();
 
-  bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
-  auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
-  auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
-  auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
-  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
+  bool BIsSigned = ExtBOpcode == ISD::SIGN_EXTEND;
+  auto BottomOpcode = BIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  auto TopOpcode = BIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+  auto BottomNode = DAG.getNode(BottomOpcode, DL, AVT, A, B);
+  return DAG.getNode(TopOpcode, DL, AVT, BottomNode, B);
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,

Copy link

github-actions bot commented Dec 20, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@JamesChesterman JamesChesterman merged commit 8dc23ef into llvm:main Dec 20, 2024
8 checks passed
@JamesChesterman JamesChesterman deleted the nfc-rename-pr-vars branch February 7, 2025 11:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants