X86: add some missing lowerings for shuffles on bf16 element type. #76076

Merged
bjacob merged 1 commit into llvm:main on Jan 5, 2024

Conversation

bjacob
Contributor

@bjacob bjacob commented Dec 20, 2023

Some shuffles with bf16 as the element type were hitting an llvm_unreachable. See the testcase added in this PR. The key to reproducing the crash was to chain two shuffles.

define <2 x bfloat> @shuffle_chained_v32bf16_v2bf16(<32 x bfloat> %a) {
  %s = shufflevector <32 x bfloat> %a, <32 x bfloat> zeroinitializer, <32 x i32> <i32 0, i32 16, i32 1, i32 17, i32 2, i32 18, i32 3, i32 19, i32 4, i32 20, i32 5, i32 21, i32 6, i32 22, i32 7, i32 23, i32 8, i32 24, i32 9, i32 25, i32 10, i32 26, i32 11, i32 27, i32 12, i32 28, i32 13, i32 29, i32 14, i32 30, i32 15, i32 31>
  %s2 = shufflevector <32 x bfloat> %s, <32 x bfloat> zeroinitializer, <2 x i32> <i32 0, i32 1>
  ret <2 x bfloat> %s2
}

This was hitting this UNREACHABLE:

Not a valid 512-bit x86 vector type!
UNREACHABLE executed at /home/benoit/iree/third_party/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp:17124!
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: /home/benoit/mlir-build/bin/llc -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f,+avx512vl,+avx512bw,+avx512bf16
1.      Running pass 'Function Pass Manager' on module '<stdin>'.
2.      Running pass 'X86 DAG->DAG Instruction Selection' on function '@shuffle_chained_v32bf16_v2bf16'
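
To make the testcase easier to read, here is a minimal standalone C++ sketch of what the two chained shuffles compute. It models shufflevector on raw 16-bit lane patterns only; the names shuffleVector and chainedShuffle are hypothetical and not part of LLVM, and undef/poison lanes are not modeled.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Models LLVM's shufflevector on raw 16-bit lanes: a mask index i < N selects
// a[i], and an index i >= N selects b[i - N]. Undef lanes are not modeled.
template <std::size_t M, std::size_t N>
std::array<std::uint16_t, M> shuffleVector(const std::array<std::uint16_t, N> &a,
                                           const std::array<std::uint16_t, N> &b,
                                           const std::array<int, M> &mask) {
  std::array<std::uint16_t, M> out{};
  for (std::size_t i = 0; i < M; ++i) {
    std::size_t idx = static_cast<std::size_t>(mask[i]);
    out[i] = idx < N ? a[idx] : b[idx - N];
  }
  return out;
}

// Reproduces the two chained shuffles from the testcase on bit patterns.
std::array<std::uint16_t, 2> chainedShuffle(const std::array<std::uint16_t, 32> &a) {
  const std::array<std::uint16_t, 32> zero{};
  std::array<int, 32> interleave{};
  for (int i = 0; i < 16; ++i) {
    interleave[2 * i] = i;          // lanes 0..15 of %a
    interleave[2 * i + 1] = 16 + i; // lanes 16..31 of %a
  }
  // First shuffle: interleave the low and high halves of %a.
  auto s = shuffleVector(a, zero, interleave);
  // Second shuffle: keep only the first two lanes of the interleaved vector.
  return shuffleVector(s, zero, std::array<int, 2>{0, 1});
}
```

Every mask index in the first shuffle is below 32, so the zeroinitializer operand is never selected, and the chained result is simply lanes 0 and 16 of the input.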

@llvmbot added the backend:X86 and llvm:SelectionDAG (SelectionDAGISel as well) labels Dec 20, 2023
@bjacob removed the llvm:SelectionDAG (SelectionDAGISel as well) label Dec 20, 2023
@llvmbot
Member

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-x86

Author: None (bjacob)

Changes

These were apparently just unimplemented.

Do you prefer to take this over or teach me how/where to add tests for this?


Full diff: https://github.com/llvm/llvm-project/pull/76076.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+11)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+55-14)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 0917d0e4eb3e26..4dd8bfc76395d6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -927,6 +927,17 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
           Chain = Result.getValue(1);
           break;
         }
+        if (SrcVT.getScalarType() == MVT::bf16) {
+          EVT ISrcVT = SrcVT.changeTypeToInteger();
+          EVT IDestVT = DestVT.changeTypeToInteger();
+          EVT ILoadVT = TLI.getRegisterType(IDestVT.getSimpleVT());
+
+          SDValue Result = DAG.getExtLoad(ISD::ZEXTLOAD, dl, ILoadVT, Chain,
+                                          Ptr, ISrcVT, LD->getMemOperand());
+          Value = DAG.getNode(ISD::BF16_TO_FP, dl, DestVT, Result);
+          Chain = Result.getValue(1);
+          break;
+        }
       }
 
       assert(!SrcVT.isVector() &&
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index db5e4fe84f410a..b7123256f57dd6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -12349,7 +12349,8 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
   MVT EltVT = VT.getVectorElementType();
   if (!((Subtarget.hasSSE3() && VT == MVT::v2f64) ||
         (Subtarget.hasAVX() && (EltVT == MVT::f64 || EltVT == MVT::f32)) ||
-        (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16))))
+        (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16)) ||
+        (Subtarget.hasBF16() && EltVT == MVT::bf16)))
     return SDValue();
 
   // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
@@ -12527,7 +12528,8 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
   // possibly narrower than VT. Then perform the broadcast.
   unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
   MVT CastVT = MVT::getVectorVT(VT.getVectorElementType(), NumSrcElts);
-  return DAG.getNode(Opcode, DL, VT, DAG.getBitcast(CastVT, V));
+  const auto &retval = DAG.getNode(Opcode, DL, VT, DAG.getBitcast(CastVT, V));
+  return retval;
 }
 
 // Check for whether we can use INSERTPS to perform the shuffle. We only use
@@ -13933,28 +13935,30 @@ static SDValue lowerV8F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                  const APInt &Zeroable, SDValue V1, SDValue V2,
                                  const X86Subtarget &Subtarget,
                                  SelectionDAG &DAG) {
-  assert(V1.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
-  assert(V2.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
+  assert((V1.getSimpleValueType() == MVT::v8f16 ||
+          V1.getSimpleValueType() == MVT::v8bf16) &&
+         "Bad operand type!");
+  assert(V1.getSimpleValueType() == V2.getSimpleValueType());
   assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
   int NumV2Elements = count_if(Mask, [](int M) { return M >= 8; });
-
-  if (Subtarget.hasFP16()) {
+  if ((V1.getSimpleValueType() == MVT::v8f16 && Subtarget.hasFP16()) ||
+      (V1.getSimpleValueType() == MVT::v8bf16 && Subtarget.hasBF16())) {
     if (NumV2Elements == 0) {
       // Check for being able to broadcast a single element.
-      if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2,
-                                                      Mask, Subtarget, DAG))
+      if (SDValue Broadcast = lowerShuffleAsBroadcast(
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Subtarget, DAG))
         return Broadcast;
     }
     if (NumV2Elements == 1 && Mask[0] >= 8)
       if (SDValue V = lowerShuffleAsElementInsertion(
-              DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG))
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Zeroable, Subtarget,
+              DAG))
         return V;
   }
-
-  V1 = DAG.getBitcast(MVT::v8i16, V1);
-  V2 = DAG.getBitcast(MVT::v8i16, V2);
-  return DAG.getBitcast(MVT::v8f16,
-                        DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
+  return DAG.getBitcast(
+      V1.getSimpleValueType(),
+      DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1),
+                           DAG.getBitcast(MVT::v8i16, V2), Mask));
 }
 
 // Lowers unary/binary shuffle as VPERMV/VPERMV3, for non-VLX targets,
@@ -14377,6 +14381,7 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
   case MVT::v8i16:
     return lowerV8I16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8f16:
+  case MVT::v8bf16:
     return lowerV8F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16i8:
     return lowerV16I8Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
@@ -16295,6 +16300,21 @@ static SDValue lowerV16I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                     Subtarget, DAG);
 }
 
+static SDValue lowerV16F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                  const APInt &Zeroable, SDValue V1, SDValue V2,
+                                  const X86Subtarget &Subtarget,
+                                  SelectionDAG &DAG) {
+  assert((V1.getSimpleValueType() == MVT::v16f16 ||
+          V1.getSimpleValueType() == MVT::v16bf16) &&
+         "Bad operand type!");
+  assert(V1.getSimpleValueType() == V2.getSimpleValueType() &&
+         "Bad operand type!");
+  return DAG.getBitcast(
+      V1.getSimpleValueType(),
+      lowerV16I16Shuffle(DL, Mask, Zeroable, DAG.getBitcast(MVT::v16i16, V1),
+                         DAG.getBitcast(MVT::v16i16, V2), Subtarget, DAG));
+}
+
 /// Handle lowering of 32-lane 8-bit integer shuffles.
 ///
 /// This routine is only called when we have AVX2 and thus a reasonable
@@ -16480,6 +16500,9 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
     return lowerV4I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8f32:
     return lowerV8F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
+  case MVT::v16f16:
+  case MVT::v16bf16:
+    return lowerV16F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8i32:
     return lowerV8I32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16i16:
@@ -16953,6 +16976,21 @@ static SDValue lowerV32I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
   return lowerShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, Subtarget, DAG);
 }
 
+static SDValue lowerV32F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
+                                  const APInt &Zeroable, SDValue V1, SDValue V2,
+                                  const X86Subtarget &Subtarget,
+                                  SelectionDAG &DAG) {
+  assert((V1.getSimpleValueType() == MVT::v32f16 ||
+          V1.getSimpleValueType() == MVT::v32bf16) &&
+         "Bad operand type!");
+  assert(V1.getSimpleValueType() == V2.getSimpleValueType() &&
+         "Bad operand type!");
+  return DAG.getBitcast(
+      V1.getSimpleValueType(),
+      lowerV32I16Shuffle(DL, Mask, Zeroable, DAG.getBitcast(MVT::v32i16, V1),
+                         DAG.getBitcast(MVT::v32i16, V2), Subtarget, DAG));
+}
+
 /// Handle lowering of 64-lane 8-bit integer shuffles.
 static SDValue lowerV64I8Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                  const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17112,6 +17150,9 @@ static SDValue lower512BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
     return lowerV8F64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16f32:
     return lowerV16F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
+  case MVT::v32f16:
+  case MVT::v32bf16:
+    return lowerV32F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v8i64:
     return lowerV8I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
   case MVT::v16i32:
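
As a reading aid for the LegalizeDAG.cpp hunk in this diff: it expands an extending load from bf16 into an integer zero-extending load followed by BF16_TO_FP. The sketch below models that expansion in standalone C++, assuming the destination type is f32. The helper names bf16BitsToFloat and extLoadBF16 are hypothetical and do not exist in LLVM.

```cpp
#include <cstdint>
#include <cstring>

// bf16 occupies the top 16 bits of an IEEE-754 binary32 value, so widening a
// bf16 bit pattern to f32 is a zero-extend followed by a 16-bit left shift.
float bf16BitsToFloat(std::uint16_t bits) {
  std::uint32_t wide = static_cast<std::uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &wide, sizeof f);
  return f;
}

// Models the expanded extending load: load the 16-bit pattern as an integer,
// then convert the integer to floating point.
float extLoadBF16(const void *ptr) {
  std::uint16_t bits;
  std::memcpy(&bits, ptr, sizeof bits);
  return bf16BitsToFloat(bits);
}
```

For example, the bf16 pattern 0x3F80 widens to 1.0f and 0xC000 widens to -2.0f, since both are just the upper halves of the corresponding f32 encodings.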

Contributor

@phoebewang phoebewang left a comment

Do you prefer to take this over or teach me how/where to add tests for this?

What problem are you trying to solve here? From the code, I think v16bf16 is well supported. I also tested some v8bf16/v32bf16 cases, and they generate correct code too.
So I think they are supported as well; the code just differs from the FP16 path.
If you know of any failing cases, you can add them as test cases.

Max191 added a commit to Max191/iree that referenced this pull request Dec 21, 2023
@bjacob
Contributor Author

bjacob commented Dec 22, 2023

> > Do you prefer to take this over or teach me how/where to add tests for this?
>
> What's the problem you want to solve here? From the code, I think v16bf16 is well supported. And I tested some v8bf16/v32bf16 cases and they can generate correct code too. So I think they are also supported, just have difference in the code with FP16. If you know some failure cases, you can add them as test cases.

Thanks for the review. Here is a minimized testcase:

target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

define i32 @run_initialize_dispatch_12_pack_bf16(<32 x bfloat> %0) #0 {
.preheader27.lr.ph:
  br label %.preheader26

.preheader26:                                     ; preds = %.preheader26, %.preheader27.lr.ph
  %1 = shufflevector <32 x bfloat> %0, <32 x bfloat> zeroinitializer, <32 x i32> <i32 0, i32 16, i32 1, i32 17, i32 2, i32 18, i32 3, i32 19, i32 4, i32 20, i32 5, i32 21, i32 6, i32 22, i32 7, i32 23, i32 8, i32 24, i32 9, i32 25, i32 10, i32 26, i32 11, i32 27, i32 12, i32 28, i32 13, i32 29, i32 14, i32 30, i32 15, i32 31>
  %2 = shufflevector <32 x bfloat> %1, <32 x bfloat> zeroinitializer, <2 x i32> <i32 0, i32 1>
  store <2 x bfloat> %2, ptr null, align 2
  br label %.preheader26
}

attributes #0 = { "target-cpu"="znver4" }

Repro:

llvm-project/bin/llc reduced.ll 
Not a valid 512-bit x86 vector type!
UNREACHABLE executed at /home/benoit/iree/third_party/llvm-project/llvm/lib/Target/X86/X86ISelLowering.cpp:17125!
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: llvm-project/bin/llc reduced.ll
1.      Running pass 'Function Pass Manager' on module 'reduced.ll'.
2.      Running pass 'X86 DAG->DAG Instruction Selection' on function '@run_initialize_dispatch_12_pack_bf16'

Going to go over all your review comments now (thanks again) and update the PR with that as a test.

@bjacob force-pushed the bf16-review branch 2 times, most recently from 730bf96 to 77e90b7, December 22, 2023 21:04
@bjacob changed the title from "X86: implement lowerings for shuffles on bf16 element type." to "X86: add some missing lowerings for shuffles on bf16 element type." Dec 22, 2023
@bjacob
Contributor Author

bjacob commented Dec 22, 2023

@phoebewang, thanks again for the great review. It's much nicer now following your guidance, and I added a testcase. PTAL any time (this can wait until 2024!).

@bjacob bjacob requested a review from phoebewang December 22, 2023 21:08
Comment on lines 13942 to 13953
+  if ((V1.getSimpleValueType() == MVT::v8f16 && Subtarget.hasFP16()) ||
+      (V1.getSimpleValueType() == MVT::v8bf16 && Subtarget.hasBF16())) {
     if (NumV2Elements == 0) {
       // Check for being able to broadcast a single element.
-      if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2,
-                                                      Mask, Subtarget, DAG))
+      if (SDValue Broadcast = lowerShuffleAsBroadcast(
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Subtarget, DAG))
         return Broadcast;
     }
     if (NumV2Elements == 1 && Mask[0] >= 8)
       if (SDValue V = lowerShuffleAsElementInsertion(
-              DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG))
+              DL, V1.getSimpleValueType(), V1, V2, Mask, Zeroable, Subtarget,
+              DAG))
Contributor

This is still not correct. We did this for FP16 only because it has a vmovw instruction; this doesn't apply to BF16.

Contributor Author

Ah, thanks - I see. The 16-bit vmovw instruction is part of the +avx512fp16 feature, breaking symmetry between fp16 and bf16 here.
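
The asymmetry can be sketched as a small decision function. This is a hypothetical standalone C++ model, not the actual X86ISelLowering.cpp code: EltKind, Strategy, Features and pickV8Lowering are illustrative names, and the real lowering still falls back to the integer shuffle when the FP16 special cases do not match.

```cpp
// Element type of a v8 half-width vector shuffle.
enum class EltKind { F16, BF16 };

// Which path the sketch picks for the shuffle.
enum class Strategy { TryFP16SpecialCasesFirst, IntegerShuffleOnly };

// Hypothetical stand-in for the relevant X86Subtarget feature queries.
struct Features {
  bool hasFP16 = false; // +avx512fp16, which provides vmovw
  bool hasBF16 = false; // +avx512bf16, which has no 16-bit move like vmovw
};

// Only f16 with AVX512-FP16 gets the broadcast/element-insertion attempts.
// bf16 always goes through the bitcast-to-v8i16 shuffle, so hasBF16 is
// deliberately not consulted here.
Strategy pickV8Lowering(EltKind elt, const Features &f) {
  if (elt == EltKind::F16 && f.hasFP16)
    return Strategy::TryFP16SpecialCasesFirst;
  return Strategy::IntegerShuffleOnly;
}
```

Under this model, enabling +avx512bf16 never changes the chosen path for a bf16 shuffle, which matches the review feedback that the FP16-only special cases should not be extended to BF16.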

Comment on lines 13956 to 13959
+  return DAG.getBitcast(
+      V1.getSimpleValueType(),
+      DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1),
+                           DAG.getBitcast(MVT::v8i16, V2), Mask));
Contributor

I think we can move this out of this function and handle it the same way as v16bf16/v32bf16. But I'm not sure we really need it: the test below passes without this change. https://godbolt.org/z/Koq16zGeM

Contributor Author

Got it - that works! Actually, I had tried something like this earlier, except that, as in the 256-bit and 512-bit cases, I was trying to do it for both bf16 and f16, and that was causing several test failures. What I was missing is the asymmetry here: lowerV8F16Shuffle has f16-specific logic to target vmovw, which only exists in avx512fp16. Once I made this bitcast only in the bf16 case, it just worked.

Now the PR is much simpler - thanks! PTAL.
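
The reason a bitcast is enough can be shown with a minimal standalone sketch: a shuffle only moves lanes, so permuting the 16-bit patterns as integers and casting back gives the same bits as a native bf16 shuffle would. The BF16 struct and shuffleViaI16 below are hypothetical illustrations, restricted to a single-input shuffle, and are not LLVM APIs.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical bf16 storage type: only the 16-bit pattern matters to a shuffle.
struct BF16 {
  std::uint16_t bits;
};
static_assert(sizeof(BF16) == sizeof(std::uint16_t), "BF16 must be 16 bits");

// Single-input shuffle done the integer way: bitcast to i16 lanes, permute,
// then bitcast back to bf16. No lane value is ever interpreted as a float.
template <std::size_t N>
std::array<BF16, N> shuffleViaI16(const std::array<BF16, N> &v,
                                  const std::array<int, N> &mask) {
  std::array<std::uint16_t, N> in{}, out{};
  std::memcpy(in.data(), v.data(), N * sizeof(std::uint16_t));
  for (std::size_t i = 0; i < N; ++i)
    out[i] = in[static_cast<std::size_t>(mask[i])];
  std::array<BF16, N> result{};
  std::memcpy(result.data(), out.data(), N * sizeof(std::uint16_t));
  return result;
}
```

Because the lanes are copied bit for bit, even NaN payloads and signed zeros survive the round trip unchanged.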

@bjacob
Contributor Author

bjacob commented Jan 3, 2024

Comment on lines 12356 to 12357
+        (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16)) ||
+        (Subtarget.hasBF16() && EltVT == MVT::bf16)))
Contributor

I'm not sure of the change here. Maybe leave it unchanged if it doesn't affect the test?

Contributor Author

Indeed. Done!

Contributor

@phoebewang phoebewang left a comment

LGTM, thanks!

@phoebewang
Contributor

Do you need help merging it?

@bjacob
Contributor Author

bjacob commented Jan 5, 2024

Thanks again for the great reviewing cycle!

@bjacob bjacob merged commit 054b5fc into llvm:main Jan 5, 2024
Labels
backend:X86 llvm:SelectionDAG SelectionDAGISel as well