diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 35e54ebd5129f..a90ddf132c389 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -3741,9 +3741,11 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget, // type. This ensures they get CSE'd. But if the integer type is not // available, use a floating-point +0.0 instead. SDValue Vec; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (!Subtarget.hasSSE2() && VT.is128BitVector()) { Vec = DAG.getConstantFP(+0.0, dl, MVT::v4f32); - } else if (VT.isFloatingPoint()) { + } else if (VT.isFloatingPoint() && + TLI.isTypeLegal(VT.getVectorElementType())) { Vec = DAG.getConstantFP(+0.0, dl, VT); } else if (VT.getVectorElementType() == MVT::i1) { assert((Subtarget.hasBWI() || VT.getVectorNumElements() <= 16) && diff --git a/llvm/lib/Target/X86/X86InstrVecCompiler.td b/llvm/lib/Target/X86/X86InstrVecCompiler.td index 70bd77bba03ab..bbd19cf8d5b25 100644 --- a/llvm/lib/Target/X86/X86InstrVecCompiler.td +++ b/llvm/lib/Target/X86/X86InstrVecCompiler.td @@ -130,6 +130,9 @@ let Predicates = [HasAVX, NoVLX] in { defm : subvec_zero_lowering<"DQA", VR128, v32i8, v16i8, sub_xmm>; } +let Predicates = [HasAVXNECONVERT, NoVLX] in + defm : subvec_zero_lowering<"DQA", VR128, v16bf16, v8bf16, sub_xmm>; + let Predicates = [HasVLX] in { defm : subvec_zero_lowering<"APDZ128", VR128X, v4f64, v2f64, sub_xmm>; defm : subvec_zero_lowering<"APSZ128", VR128X, v8f32, v4f32, sub_xmm>; @@ -175,6 +178,12 @@ let Predicates = [HasFP16, HasVLX] in { defm : subvec_zero_lowering<"APSZ256", VR256X, v32f16, v16f16, sub_ymm>; } +let Predicates = [HasBF16, HasVLX] in { + defm : subvec_zero_lowering<"APSZ128", VR128X, v16bf16, v8bf16, sub_xmm>; + defm : subvec_zero_lowering<"APSZ128", VR128X, v32bf16, v8bf16, sub_xmm>; + defm : subvec_zero_lowering<"APSZ256", VR256X, v32bf16, v16bf16, sub_ymm>; +} + class maskzeroupper : PatLeaf<(vt RC:$src), [{ return isMaskZeroExtended(N); diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll index 674a0eacb0ca9..9c65310f79d7e 100644 --- a/llvm/test/CodeGen/X86/bfloat.ll +++ b/llvm/test/CodeGen/X86/bfloat.ll @@ -2529,3 +2529,17 @@ define <8 x bfloat> @extract_v32bf16_v8bf16(<32 x bfloat> %x) { %a = shufflevector <32 x bfloat> %x, <32 x bfloat> undef, <8 x i32> ret <8 x bfloat> %a } + +define <16 x bfloat> @concat_zero_v8bf16(<8 x bfloat> %x, <8 x bfloat> %y) { +; SSE2-LABEL: concat_zero_v8bf16: +; SSE2: # %bb.0: +; SSE2-NEXT: xorps %xmm1, %xmm1 +; SSE2-NEXT: retq +; +; AVX-LABEL: concat_zero_v8bf16: +; AVX: # %bb.0: +; AVX-NEXT: vmovaps %xmm0, %xmm0 +; AVX-NEXT: retq + %a = shufflevector <8 x bfloat> %x, <8 x bfloat> zeroinitializer, <16 x i32> + ret <16 x bfloat> %a +}