diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index cfd82a342433f..f1e7a3f4421e8 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -12249,9 +12249,8 @@ void SelectionDAGBuilder::visitVectorSplice(const CallInst &I) { // VECTOR_SHUFFLE doesn't support a scalable mask so use a dedicated node. if (VT.isScalableVector()) { - MVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout()); setValue(&I, DAG.getNode(ISD::VECTOR_SPLICE, DL, VT, V1, V2, - DAG.getConstant(Imm, DL, IdxVT))); + DAG.getVectorIdxConstant(Imm, DL))); return; } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 2af679e0755b5..a1931bc1e4936 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1048,9 +1048,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND, - ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG, - ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR, - ISD::INSERT_SUBVECTOR, ISD::STORE, ISD::BUILD_VECTOR}); + ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS, + ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR, + ISD::STORE, ISD::BUILD_VECTOR}); setTargetDAGCombine(ISD::TRUNCATE); setTargetDAGCombine(ISD::LOAD); @@ -1580,6 +1580,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::MLOAD, VT, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); setOperationAction(ISD::SPLAT_VECTOR, VT, Legal); + setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); if (!Subtarget->isLittleEndian()) setOperationAction(ISD::BITCAST, VT, Expand); @@ -10102,10 +10103,9 @@ SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op, Op.getOperand(1)); } - // This will select to an EXT instruction, which has a maximum immediate - // value of 255, hence 2048-bits is the maximum value we can lower. - if (IdxVal >= 0 && - IdxVal < int64_t(2048 / Ty.getVectorElementType().getSizeInBits())) + // We can select to an EXT instruction when indexing the first 256 bytes. + unsigned BlockSize = AArch64::SVEBitsPerBlock / Ty.getVectorMinNumElements(); + if (IdxVal >= 0 && (IdxVal * BlockSize / 8) < 256) return Op; return SDValue(); @@ -24237,28 +24237,6 @@ performInsertVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { return performPostLD1Combine(N, DCI, true); } -static SDValue performSVESpliceCombine(SDNode *N, SelectionDAG &DAG) { - EVT Ty = N->getValueType(0); - if (Ty.isInteger()) - return SDValue(); - - EVT IntTy = Ty.changeVectorElementTypeToInteger(); - EVT ExtIntTy = getPackedSVEVectorVT(IntTy.getVectorElementCount()); - if (ExtIntTy.getVectorElementType().getScalarSizeInBits() < - IntTy.getVectorElementType().getScalarSizeInBits()) - return SDValue(); - - SDLoc DL(N); - SDValue LHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(0)), - DL, ExtIntTy); - SDValue RHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(1)), - DL, ExtIntTy); - SDValue Idx = N->getOperand(2); - SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, ExtIntTy, LHS, RHS, Idx); - SDValue Trunc = DAG.getAnyExtOrTrunc(Splice, DL, IntTy); - return DAG.getBitcast(Ty, Trunc); -} - static SDValue performFPExtendCombine(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { @@ -24643,8 +24621,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, case ISD::MGATHER: case ISD::MSCATTER: return performMaskedGatherScatterCombine(N, DCI, DAG); - case ISD::VECTOR_SPLICE: - return performSVESpliceCombine(N, DAG); case ISD::FP_EXTEND: return performFPExtendCombine(N, DAG, DCI, Subtarget); case AArch64ISD::BRCOND: diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 62e68de1359f7..64e545aa26b45 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -1994,14 +1994,21 @@ let Predicates = [HasSVEorSME] in { (LASTB_VPZ_D (PTRUE_D 31), ZPR:$Z1), dsub))>; // Splice with lane bigger or equal to 0 - def : Pat<(nxv16i8 (vector_splice (nxv16i8 ZPR:$Z1), (nxv16i8 ZPR:$Z2), (i64 (sve_ext_imm_0_255 i32:$index)))), - (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; - def : Pat<(nxv8i16 (vector_splice (nxv8i16 ZPR:$Z1), (nxv8i16 ZPR:$Z2), (i64 (sve_ext_imm_0_127 i32:$index)))), - (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; - def : Pat<(nxv4i32 (vector_splice (nxv4i32 ZPR:$Z1), (nxv4i32 ZPR:$Z2), (i64 (sve_ext_imm_0_63 i32:$index)))), - (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; - def : Pat<(nxv2i64 (vector_splice (nxv2i64 ZPR:$Z1), (nxv2i64 ZPR:$Z2), (i64 (sve_ext_imm_0_31 i32:$index)))), - (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; + foreach VT = [nxv16i8] in + def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_255 i32:$index)))), + (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; + + foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in + def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_127 i32:$index)))), + (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; + + foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in + def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_63 i32:$index)))), + (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; + + foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in + def : Pat<(VT (vector_splice (VT ZPR:$Z1), (VT ZPR:$Z2), (i64 (sve_ext_imm_0_31 i32:$index)))), + (EXT_ZZI ZPR:$Z1, ZPR:$Z2, imm0_255:$index)>; defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs", SETUGE, SETULE>; defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi", SETUGT, SETULT>; diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index 69c3238c7d614..fc7d3cdda4acd 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -7060,16 +7060,17 @@ multiclass sve_int_perm_splice { def _S : sve_int_perm_splice<0b10, asm, ZPR32>; def _D : sve_int_perm_splice<0b11, asm, ZPR64>; - def : SVE_3_Op_Pat(NAME # _B)>; - def : SVE_3_Op_Pat(NAME # _H)>; - def : SVE_3_Op_Pat(NAME # _S)>; - def : SVE_3_Op_Pat(NAME # _D)>; + foreach VT = [nxv16i8] in + def : SVE_3_Op_Pat(NAME # _B)>; - def : SVE_3_Op_Pat(NAME # _H)>; - def : SVE_3_Op_Pat(NAME # _S)>; - def : SVE_3_Op_Pat(NAME # _D)>; + foreach VT = [nxv8i16, nxv8f16, nxv8bf16] in + def : SVE_3_Op_Pat(NAME # _H)>; - def : SVE_3_Op_Pat(NAME # _H)>; + foreach VT = [nxv4i32, nxv4f16, nxv4f32, nxv4bf16] in + def : SVE_3_Op_Pat(NAME # _S)>; + + foreach VT = [nxv2i64, nxv2f16, nxv2f32, nxv2f64, nxv2bf16] in + def : SVE_3_Op_Pat(NAME # _D)>; } class sve2_int_perm_splice_cons sz8_64, string asm, diff --git a/llvm/test/CodeGen/AArch64/named-vector-shuffles-sve.ll b/llvm/test/CodeGen/AArch64/named-vector-shuffles-sve.ll index f5763cd61033b..d1171bc312473 100644 --- a/llvm/test/CodeGen/AArch64/named-vector-shuffles-sve.ll +++ b/llvm/test/CodeGen/AArch64/named-vector-shuffles-sve.ll @@ -692,6 +692,104 @@ define @splice_nxv2f64_neg3( %a, %res } +define @splice_nxv2bf16_neg_idx( %a, %b) #0 { +; CHECK-LABEL: splice_nxv2bf16_neg_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d, vl1 +; CHECK-NEXT: rev p0.d, p0.d +; CHECK-NEXT: splice z0.d, p0, z0.d, z1.d +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv2bf16( %a, %b, i32 -1) + ret %res +} + +define @splice_nxv2bf16_neg2_idx( %a, %b) #0 { +; CHECK-LABEL: splice_nxv2bf16_neg2_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d, vl2 +; CHECK-NEXT: rev p0.d, p0.d +; CHECK-NEXT: splice z0.d, p0, z0.d, z1.d +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv2bf16( %a, %b, i32 -2) + ret %res +} + +define @splice_nxv2bf16_first_idx( %a, %b) #0 { +; CHECK-LABEL: splice_nxv2bf16_first_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ext z0.b, z0.b, z1.b, #8 +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv2bf16( %a, %b, i32 1) + ret %res +} + +define @splice_nxv2bf16_last_idx( %a, %b) vscale_range(16,16) #0 { +; CHECK-LABEL: splice_nxv2bf16_last_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ext z0.b, z0.b, z1.b, #248 +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv2bf16( %a, %b, i32 31) + ret %res +} + +define @splice_nxv4bf16_neg_idx( %a, %b) #0 { +; CHECK-LABEL: splice_nxv4bf16_neg_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s, vl1 +; CHECK-NEXT: rev p0.s, p0.s +; CHECK-NEXT: splice z0.s, p0, z0.s, z1.s +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv4bf16( %a, %b, i32 -1) + ret %res +} + +define @splice_nxv4bf16_neg3_idx( %a, %b) #0 { +; CHECK-LABEL: splice_nxv4bf16_neg3_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s, vl3 +; CHECK-NEXT: rev p0.s, p0.s +; CHECK-NEXT: splice z0.s, p0, z0.s, z1.s +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv4bf16( %a, %b, i32 -3) + ret %res +} + +define @splice_nxv4bf16_first_idx( %a, %b) #0 { +; CHECK-LABEL: splice_nxv4bf16_first_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ext z0.b, z0.b, z1.b, #4 +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv4bf16( %a, %b, i32 1) + ret %res +} + +define @splice_nxv4bf16_last_idx( %a, %b) vscale_range(16,16) #0 { +; CHECK-LABEL: splice_nxv4bf16_last_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ext z0.b, z0.b, z1.b, #252 +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv4bf16( %a, %b, i32 63) + ret %res +} + +define @splice_nxv8bf16_first_idx( %a, %b) #0 { +; CHECK-LABEL: splice_nxv8bf16_first_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ext z0.b, z0.b, z1.b, #2 +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv8bf16( %a, %b, i32 1) + ret %res +} + +define @splice_nxv8bf16_last_idx( %a, %b) vscale_range(16,16) #0 { +; CHECK-LABEL: splice_nxv8bf16_last_idx: +; CHECK: // %bb.0: +; CHECK-NEXT: ext z0.b, z0.b, z1.b, #254 +; CHECK-NEXT: ret + %res = call @llvm.vector.splice.nxv8bf16( %a, %b, i32 127) + ret %res +} + ; Ensure predicate based splice is promoted to use ZPRs. define @splice_nxv2i1( %a, %b) #0 { ; CHECK-LABEL: splice_nxv2i1: @@ -834,12 +932,14 @@ declare @llvm.vector.splice.nxv2i1(, @llvm.vector.splice.nxv4i1(, , i32) declare @llvm.vector.splice.nxv8i1(, , i32) declare @llvm.vector.splice.nxv16i1(, , i32) + declare @llvm.vector.splice.nxv2i8(, , i32) declare @llvm.vector.splice.nxv16i8(, , i32) declare @llvm.vector.splice.nxv8i16(, , i32) declare @llvm.vector.splice.nxv4i32(, , i32) declare @llvm.vector.splice.nxv8i32(, , i32) declare @llvm.vector.splice.nxv2i64(, , i32) + declare @llvm.vector.splice.nxv2f16(, , i32) declare @llvm.vector.splice.nxv4f16(, , i32) declare @llvm.vector.splice.nxv8f16(, , i32) @@ -848,4 +948,8 @@ declare @llvm.vector.splice.nxv4f32(, < declare @llvm.vector.splice.nxv16f32(, , i32) declare @llvm.vector.splice.nxv2f64(, , i32) +declare @llvm.vector.splice.nxv2bf16(, , i32) +declare @llvm.vector.splice.nxv4bf16(, , i32) +declare @llvm.vector.splice.nxv8bf16(, , i32) + attributes #0 = { nounwind "target-features"="+sve" }