diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 9d729d448502d..c3da09f025122 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -1896,6 +1896,18 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) { DAG.getConstant(0, getCurSDLoc(), MVT::nxv16i1)); } + if (VT.isRISCVVectorTuple()) { + assert(C->isNullValue() && "Can only zero this target type!"); + return NodeMap[V] = DAG.getNode( + ISD::BITCAST, getCurSDLoc(), VT, + DAG.getNode( + ISD::SPLAT_VECTOR, getCurSDLoc(), + EVT::getVectorVT(*DAG.getContext(), MVT::i8, + VT.getSizeInBits().getKnownMinValue() / 8, + true), + DAG.getConstant(0, getCurSDLoc(), MVT::getIntegerVT(8)))); + } + VectorType *VecTy = cast(V->getType()); // Now that we know the number and type of the elements, get that number of diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp index ac6b8b4c19700..ffa80faf6e249 100644 --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -990,7 +990,7 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) { Ty->getIntParameter(0); return TargetTypeInfo( ScalableVectorType::get(Type::getInt8Ty(C), TotalNumElts), - TargetExtType::CanBeLocal); + TargetExtType::CanBeLocal, TargetExtType::HasZeroInit); } // DirectX resources diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 329b42d621cee..1e66245f41bfa 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -18054,6 +18054,20 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); EVT SrcVT = N0.getValueType(); + if (VT.isRISCVVectorTuple() && N0->getOpcode() == ISD::SPLAT_VECTOR) { + unsigned NF = VT.getRISCVVectorTupleNumFields(); + unsigned NumScalElts = VT.getSizeInBits().getKnownMinValue() / (NF * 8); + SDValue EltVal = DAG.getConstant(0, DL, Subtarget.getXLenVT()); + MVT ScalTy = MVT::getScalableVectorVT(MVT::getIntegerVT(8), NumScalElts); + + SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, ScalTy, EltVal); + + SDValue Result = DAG.getUNDEF(VT); + for (unsigned i = 0; i < NF; ++i) + Result = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VT, Result, Splat, + DAG.getVectorIdxConstant(i, DL)); + return Result; + } // If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer // type, widen both sides to avoid a trip through memory. if ((SrcVT == MVT::v1i1 || SrcVT == MVT::v2i1 || SrcVT == MVT::v4i1) && diff --git a/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll b/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll new file mode 100644 index 0000000000000..fb1104e0a3b80 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/vector-tuple-zeroinitializer.ll @@ -0,0 +1,52 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc -mtriple=riscv32 -mattr=+v \ +; RUN: -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=CHECK +; RUN: llc -mtriple=riscv64 -mattr=+v \ +; RUN: -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=CHECK + +define target("riscv.vector.tuple", , 2) @test_tuple_zero_power_of_2() { +; CHECK-LABEL: test_tuple_zero_power_of_2: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma +; CHECK-NEXT: vmv.v.i v8, 0 +; CHECK-NEXT: vmv.v.i v10, 0 +; CHECK-NEXT: ret +entry: + ret target("riscv.vector.tuple", , 2) zeroinitializer +} + +define target("riscv.vector.tuple", , 3) @test_tuple_zero_non_power_of_2() { +; CHECK-LABEL: test_tuple_zero_non_power_of_2: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma +; CHECK-NEXT: vmv.v.i v8, 0 +; CHECK-NEXT: vmv.v.i v10, 0 +; CHECK-NEXT: vmv.v.i v12, 0 +; CHECK-NEXT: ret +entry: + ret target("riscv.vector.tuple", , 3) zeroinitializer +} + +define target("riscv.vector.tuple", , 2) @test_tuple_zero_insert1( %a) { +; CHECK-LABEL: test_tuple_zero_insert1: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma +; CHECK-NEXT: vmv.v.i v10, 0 +; CHECK-NEXT: ret +entry: + %1 = call target("riscv.vector.tuple", , 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", , 2) zeroinitializer, %a, i32 0) + ret target("riscv.vector.tuple", , 2) %1 +} + +define target("riscv.vector.tuple", , 2) @test_tuple_zero_insert2( %a) { +; CHECK-LABEL: test_tuple_zero_insert2: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli a0, zero, e8, m2, ta, ma +; CHECK-NEXT: vmv.v.i v6, 0 +; CHECK-NEXT: vmv2r.v v10, v8 +; CHECK-NEXT: vmv2r.v v8, v6 +; CHECK-NEXT: ret +entry: + %1 = call target("riscv.vector.tuple", , 2) @llvm.riscv.tuple.insert.triscv.vector.tuple_nxv16i8_2t.nxv4i32(target("riscv.vector.tuple", , 2) zeroinitializer, %a, i32 1) + ret target("riscv.vector.tuple", , 2) %1 +}