-
Notifications
You must be signed in to change notification settings - Fork 13.6k
X86: add some missing lowerings for shuffles on bf16
element type.
#76076
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-x86 Author: None (bjacob) Changes: These were apparently just unimplemented. Do you prefer to take this over or teach me how/where to add tests for this? Full diff: https://github.com/llvm/llvm-project/pull/76076.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 0917d0e4eb3e26..4dd8bfc76395d6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -927,6 +927,17 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
Chain = Result.getValue(1);
break;
}
+ if (SrcVT.getScalarType() == MVT::bf16) {
+ EVT ISrcVT = SrcVT.changeTypeToInteger();
+ EVT IDestVT = DestVT.changeTypeToInteger();
+ EVT ILoadVT = TLI.getRegisterType(IDestVT.getSimpleVT());
+
+ SDValue Result = DAG.getExtLoad(ISD::ZEXTLOAD, dl, ILoadVT, Chain,
+ Ptr, ISrcVT, LD->getMemOperand());
+ Value = DAG.getNode(ISD::BF16_TO_FP, dl, DestVT, Result);
+ Chain = Result.getValue(1);
+ break;
+ }
}
assert(!SrcVT.isVector() &&
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index db5e4fe84f410a..b7123256f57dd6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -12349,7 +12349,8 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
MVT EltVT = VT.getVectorElementType();
if (!((Subtarget.hasSSE3() && VT == MVT::v2f64) ||
(Subtarget.hasAVX() && (EltVT == MVT::f64 || EltVT == MVT::f32)) ||
- (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16))))
+ (Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16)) ||
+ (Subtarget.hasBF16() && EltVT == MVT::bf16)))
return SDValue();
// With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
@@ -12527,7 +12528,8 @@ static SDValue lowerShuffleAsBroadcast(const SDLoc &DL, MVT VT, SDValue V1,
// possibly narrower than VT. Then perform the broadcast.
unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
MVT CastVT = MVT::getVectorVT(VT.getVectorElementType(), NumSrcElts);
- return DAG.getNode(Opcode, DL, VT, DAG.getBitcast(CastVT, V));
+ const auto &retval = DAG.getNode(Opcode, DL, VT, DAG.getBitcast(CastVT, V));
+ return retval;
}
// Check for whether we can use INSERTPS to perform the shuffle. We only use
@@ -13933,28 +13935,30 @@ static SDValue lowerV8F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable, SDValue V1, SDValue V2,
const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
- assert(V1.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
- assert(V2.getSimpleValueType() == MVT::v8f16 && "Bad operand type!");
+ assert((V1.getSimpleValueType() == MVT::v8f16 ||
+ V1.getSimpleValueType() == MVT::v8bf16) &&
+ "Bad operand type!");
+ assert(V2.getSimpleValueType() == V2.getSimpleValueType());
assert(Mask.size() == 8 && "Unexpected mask size for v8 shuffle!");
int NumV2Elements = count_if(Mask, [](int M) { return M >= 8; });
-
- if (Subtarget.hasFP16()) {
+ if ((V1.getSimpleValueType() == MVT::v8f16 && Subtarget.hasFP16()) ||
+ (V1.getSimpleValueType() == MVT::v8bf16 && Subtarget.hasBF16())) {
if (NumV2Elements == 0) {
// Check for being able to broadcast a single element.
- if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2,
- Mask, Subtarget, DAG))
+ if (SDValue Broadcast = lowerShuffleAsBroadcast(
+ DL, V1.getSimpleValueType(), V1, V2, Mask, Subtarget, DAG))
return Broadcast;
}
if (NumV2Elements == 1 && Mask[0] >= 8)
if (SDValue V = lowerShuffleAsElementInsertion(
- DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG))
+ DL, V1.getSimpleValueType(), V1, V2, Mask, Zeroable, Subtarget,
+ DAG))
return V;
}
-
- V1 = DAG.getBitcast(MVT::v8i16, V1);
- V2 = DAG.getBitcast(MVT::v8i16, V2);
- return DAG.getBitcast(MVT::v8f16,
- DAG.getVectorShuffle(MVT::v8i16, DL, V1, V2, Mask));
+ return DAG.getBitcast(
+ V1.getSimpleValueType(),
+ DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1),
+ DAG.getBitcast(MVT::v8i16, V2), Mask));
}
// Lowers unary/binary shuffle as VPERMV/VPERMV3, for non-VLX targets,
@@ -14377,6 +14381,7 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
case MVT::v8i16:
return lowerV8I16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v8f16:
+ case MVT::v8bf16:
return lowerV8F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v16i8:
return lowerV16I8Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
@@ -16295,6 +16300,21 @@ static SDValue lowerV16I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
Subtarget, DAG);
}
+static SDValue lowerV16F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
+ const APInt &Zeroable, SDValue V1, SDValue V2,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ assert((V1.getSimpleValueType() == MVT::v16f16 ||
+ V1.getSimpleValueType() == MVT::v16bf16) &&
+ "Bad operand type!");
+ assert(V1.getSimpleValueType() == V2.getSimpleValueType() &&
+ "Bad operand type!");
+ return DAG.getBitcast(
+ V1.getSimpleValueType(),
+ lowerV16I16Shuffle(DL, Mask, Zeroable, DAG.getBitcast(MVT::v16i16, V1),
+ DAG.getBitcast(MVT::v16i16, V2), Subtarget, DAG));
+}
+
/// Handle lowering of 32-lane 8-bit integer shuffles.
///
/// This routine is only called when we have AVX2 and thus a reasonable
@@ -16480,6 +16500,9 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
return lowerV4I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v8f32:
return lowerV8F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
+ case MVT::v8f16:
+ case MVT::v8bf16:
+ return lowerV16F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v8i32:
return lowerV8I32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v16i16:
@@ -16953,6 +16976,21 @@ static SDValue lowerV32I16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerShuffleWithPERMV(DL, MVT::v32i16, Mask, V1, V2, Subtarget, DAG);
}
+static SDValue lowerV32F16Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
+ const APInt &Zeroable, SDValue V1, SDValue V2,
+ const X86Subtarget &Subtarget,
+ SelectionDAG &DAG) {
+ assert((V1.getSimpleValueType() == MVT::v32f16 ||
+ V1.getSimpleValueType() == MVT::v32bf16) &&
+ "Bad operand type!");
+ assert(V1.getSimpleValueType() == V2.getSimpleValueType() &&
+ "Bad operand type!");
+ return DAG.getBitcast(
+ V1.getSimpleValueType(),
+ lowerV32I16Shuffle(DL, Mask, Zeroable, DAG.getBitcast(MVT::v32i16, V1),
+ DAG.getBitcast(MVT::v32i16, V2), Subtarget, DAG));
+}
+
/// Handle lowering of 64-lane 8-bit integer shuffles.
static SDValue lowerV64I8Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17112,6 +17150,9 @@ static SDValue lower512BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerV8F64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v16f32:
return lowerV16F32Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
+ case MVT::v32f16:
+ case MVT::v32bf16:
+ return lowerV32F16Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v8i64:
return lowerV8I64Shuffle(DL, Mask, Zeroable, V1, V2, Subtarget, DAG);
case MVT::v16i32:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you prefer to take this over or teach me how/where to add tests for this?
What's the problem you want to solve here? From the code, I think v16bf16
is well supported. And I tested some v8bf16/v32bf16
cases and they can generate correct code too.
So I think they are also supported; the code path just differs from the FP16 one.
If you know some failure cases, you can add them as test cases.
Thanks for the review. Here is a minimized testcase: target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
define i32 @run_initialize_dispatch_12_pack_bf16(<32 x bfloat> %0) #0 {
.preheader27.lr.ph:
br label %.preheader26
.preheader26: ; preds = %.preheader26, %.preheader27.lr.ph
%1 = shufflevector <32 x bfloat> %0, <32 x bfloat> zeroinitializer, <32 x i32> <i32 0, i32 16, i32 1, i32 17, i32 2, i32 18, i32 3, i32 19, i32 4, i32 20, i32 5, i32 21, i32 6, i32 22, i32 7, i32 23, i32 8, i32 24, i32 9, i32 25, i32 10, i32 26, i32 11, i32 27, i32 12, i32 28, i32 13, i32 29, i32 14, i32 30, i32 15, i32 31>
%2 = shufflevector <32 x bfloat> %1, <32 x bfloat> zeroinitializer, <2 x i32> <i32 0, i32 1>
store <2 x bfloat> %2, ptr null, align 2
br label %.preheader26
}
attributes #0 = { "target-cpu"="znver4" } Repro:
Going to go over all your review comments now (thanks again) and update the PR with that as a test. |
730bf96
to
77e90b7
Compare
bf16
element type.bf16
element type.
@phoebewang , thanks again for the great review. It's much nicer now following your guidance, and I added a testcase. PTAL any time (this can wait until 2024!). |
if ((V1.getSimpleValueType() == MVT::v8f16 && Subtarget.hasFP16()) || | ||
(V1.getSimpleValueType() == MVT::v8bf16 && Subtarget.hasBF16())) { | ||
if (NumV2Elements == 0) { | ||
// Check for being able to broadcast a single element. | ||
if (SDValue Broadcast = lowerShuffleAsBroadcast(DL, MVT::v8f16, V1, V2, | ||
Mask, Subtarget, DAG)) | ||
if (SDValue Broadcast = lowerShuffleAsBroadcast( | ||
DL, V1.getSimpleValueType(), V1, V2, Mask, Subtarget, DAG)) | ||
return Broadcast; | ||
} | ||
if (NumV2Elements == 1 && Mask[0] >= 8) | ||
if (SDValue V = lowerShuffleAsElementInsertion( | ||
DL, MVT::v8f16, V1, V2, Mask, Zeroable, Subtarget, DAG)) | ||
DL, V1.getSimpleValueType(), V1, V2, Mask, Zeroable, Subtarget, | ||
DAG)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is still not correct. We did this for FP16 only because it has a vmovw
instruction; this doesn't apply to BF16.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, thanks - I see. The 16-bit vmovw
instruction is part of the +avx512fp16
feature, breaking symmetry between fp16 and bf16 here.
return DAG.getBitcast( | ||
V1.getSimpleValueType(), | ||
DAG.getVectorShuffle(MVT::v8i16, DL, DAG.getBitcast(MVT::v8i16, V1), | ||
DAG.getBitcast(MVT::v8i16, V2), Mask)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can move this out of this function and do it similarly to v16bf16/v32bf16.
But I'm not sure if we really need it; the test below passes without this change: https://godbolt.org/z/Koq16zGeM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it - that works! Actually I had tried something like this earlier, except that, like in the 256bit and 512bit cases, I was trying to do it for both bf16 and f16, but that was causing several test failures. What I was missing is the asymmetry here due to f16-specific logic in lowerV8F16Shuffle to target vmovw, which only exists in avx512fp16. Once I made this bitcast only in the bf16 case, it just worked.
Now the PR is much simpler - thanks! PTAL.
The CI failure appears unrelated: |
(Subtarget.hasAVX2() && (VT.isInteger() || EltVT == MVT::f16)) || | ||
(Subtarget.hasBF16() && EltVT == MVT::bf16))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure of the change here. Maybe leave it unchanged if it doesn't affect the test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed. Done!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Do you need help to merge it for you? |
Thanks again for the great reviewing cycle! |
Some shuffles with bf16 as element type were running into an llvm_unreachable. See the testcase added in this PR. Key to reproducing was to chain two shuffles.
This was hitting this UNREACHABLE: