Commit 6bac414

[RISCV][GISEL] Legalize G_INSERT_SUBVECTOR (#108859)
This code is heavily based on the SelectionDAG lowerINSERT_SUBVECTOR code.
1 parent 22e21bc commit 6bac414

7 files changed, +881 / -4 lines changed

llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h

Lines changed: 12 additions & 0 deletions
@@ -811,6 +811,18 @@ class GExtractSubvector : public GenericMachineInstr {
   }
 };
 
+/// Represents an insert subvector.
+class GInsertSubvector : public GenericMachineInstr {
+public:
+  Register getBigVec() const { return getOperand(1).getReg(); }
+  Register getSubVec() const { return getOperand(2).getReg(); }
+  uint64_t getIndexImm() const { return getOperand(3).getImm(); }
+
+  static bool classof(const MachineInstr *MI) {
+    return MI->getOpcode() == TargetOpcode::G_INSERT_SUBVECTOR;
+  }
+};
+
 /// Represents a freeze.
 class GFreeze : public GenericMachineInstr {
 public:

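Note (not part of the diff): the accessors above just name operands 1-3 of G_INSERT_SUBVECTOR, and classof hooks the wrapper into LLVM's usual cast/dyn_cast machinery. A minimal, hypothetical consumer might look like the sketch below; the surrounding MI/MRI values are assumed to come from whatever pass is running.

#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include <cstdint>

// Sketch only: inspect a G_INSERT_SUBVECTOR through the new wrapper class.
void inspectInsertSubvector(llvm::MachineInstr &MI,
                            llvm::MachineRegisterInfo &MRI) {
  using namespace llvm;
  if (auto *IS = dyn_cast<GInsertSubvector>(&MI)) {
    Register Dst = IS->getReg(0);     // result vector
    Register Big = IS->getBigVec();   // operand 1: vector inserted into
    Register Sub = IS->getSubVec();   // operand 2: subvector being inserted
    uint64_t Idx = IS->getIndexImm(); // operand 3: starting element index
    LLT BigTy = MRI.getType(Big);     // e.g. <vscale x 16 x i1>
    (void)Dst; (void)Sub; (void)Idx; (void)BigTy;
  }
}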
llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h

Lines changed: 2 additions & 0 deletions
@@ -380,6 +380,8 @@ class LegalizerHelper {
                                          LLT CastTy);
   LegalizeResult bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
                                          LLT CastTy);
+  LegalizeResult bitcastInsertSubvector(MachineInstr &MI, unsigned TypeIdx,
+                                        LLT CastTy);
 
   LegalizeResult lowerConstant(MachineInstr &MI);
   LegalizeResult lowerFConstant(MachineInstr &MI);

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 100 additions & 0 deletions
@@ -3276,6 +3276,33 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
     Observer.changedInstr(MI);
     return Legalized;
   }
+  case TargetOpcode::G_INSERT_SUBVECTOR: {
+    if (TypeIdx != 0)
+      return UnableToLegalize;
+
+    GInsertSubvector &IS = cast<GInsertSubvector>(MI);
+    Register BigVec = IS.getBigVec();
+    Register SubVec = IS.getSubVec();
+
+    LLT SubVecTy = MRI.getType(SubVec);
+    LLT SubVecWideTy = SubVecTy.changeElementType(WideTy.getElementType());
+
+    // Widen the G_INSERT_SUBVECTOR
+    auto BigZExt = MIRBuilder.buildZExt(WideTy, BigVec);
+    auto SubZExt = MIRBuilder.buildZExt(SubVecWideTy, SubVec);
+    auto WideInsert = MIRBuilder.buildInsertSubvector(WideTy, BigZExt, SubZExt,
+                                                      IS.getIndexImm());
+
+    // Truncate back down
+    auto SplatZero = MIRBuilder.buildSplatVector(
+        WideTy, MIRBuilder.buildConstant(WideTy.getElementType(), 0));
+    MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, IS.getReg(0), WideInsert,
+                         SplatZero);
+
+    MI.eraseFromParent();
+
+    return Legalized;
+  }
   }
 }
 
@@ -3725,6 +3752,77 @@ LegalizerHelper::bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
   return Legalized;
 }
 
+/// This attempts to bitcast G_INSERT_SUBVECTOR to CastTy.
+///
+///  <vscale x 16 x i1> = G_INSERT_SUBVECTOR <vscale x 16 x i1>,
+///                                          <vscale x 8 x i1>,
+///                                          N
+///
+/// ===>
+///
+///  <vscale x 2 x i8> = G_BITCAST <vscale x 16 x i1>
+///  <vscale x 1 x i8> = G_BITCAST <vscale x 8 x i1>
+///  <vscale x 2 x i8> = G_INSERT_SUBVECTOR <vscale x 2 x i8>,
+///                                         <vscale x 1 x i8>, N / 8
+///  <vscale x 16 x i1> = G_BITCAST <vscale x 2 x i8>
+LegalizerHelper::LegalizeResult
+LegalizerHelper::bitcastInsertSubvector(MachineInstr &MI, unsigned TypeIdx,
+                                        LLT CastTy) {
+  auto ES = cast<GInsertSubvector>(&MI);
+
+  if (!CastTy.isVector())
+    return UnableToLegalize;
+
+  if (TypeIdx != 0)
+    return UnableToLegalize;
+
+  Register Dst = ES->getReg(0);
+  Register BigVec = ES->getBigVec();
+  Register SubVec = ES->getSubVec();
+  uint64_t Idx = ES->getIndexImm();
+
+  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
+
+  LLT DstTy = MRI.getType(Dst);
+  LLT BigVecTy = MRI.getType(BigVec);
+  LLT SubVecTy = MRI.getType(SubVec);
+
+  if (DstTy == CastTy)
+    return Legalized;
+
+  if (DstTy.getSizeInBits() != CastTy.getSizeInBits())
+    return UnableToLegalize;
+
+  ElementCount DstTyEC = DstTy.getElementCount();
+  ElementCount BigVecTyEC = BigVecTy.getElementCount();
+  ElementCount SubVecTyEC = SubVecTy.getElementCount();
+  auto DstTyMinElts = DstTyEC.getKnownMinValue();
+  auto BigVecTyMinElts = BigVecTyEC.getKnownMinValue();
+  auto SubVecTyMinElts = SubVecTyEC.getKnownMinValue();
+
+  unsigned CastEltSize = CastTy.getElementType().getSizeInBits();
+  unsigned DstEltSize = DstTy.getElementType().getSizeInBits();
+  if (CastEltSize < DstEltSize)
+    return UnableToLegalize;
+
+  auto AdjustAmt = CastEltSize / DstEltSize;
+  if (Idx % AdjustAmt != 0 || DstTyMinElts % AdjustAmt != 0 ||
+      BigVecTyMinElts % AdjustAmt != 0 || SubVecTyMinElts % AdjustAmt != 0)
+    return UnableToLegalize;
+
+  Idx /= AdjustAmt;
+  BigVecTy = LLT::vector(BigVecTyEC.divideCoefficientBy(AdjustAmt), AdjustAmt);
+  SubVecTy = LLT::vector(SubVecTyEC.divideCoefficientBy(AdjustAmt), AdjustAmt);
+  auto CastBigVec = MIRBuilder.buildBitcast(BigVecTy, BigVec);
+  auto CastSubVec = MIRBuilder.buildBitcast(SubVecTy, SubVec);
+  auto PromotedIS =
+      MIRBuilder.buildInsertSubvector(CastTy, CastBigVec, CastSubVec, Idx);
+  MIRBuilder.buildBitcast(Dst, PromotedIS);
+
+  ES->eraseFromParent();
+  return Legalized;
+}
+
 LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
   // Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
   Register DstReg = LoadMI.getDstReg();
@@ -4033,6 +4131,8 @@ LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
     return bitcastConcatVector(MI, TypeIdx, CastTy);
   case TargetOpcode::G_EXTRACT_SUBVECTOR:
     return bitcastExtractSubvector(MI, TypeIdx, CastTy);
+  case TargetOpcode::G_INSERT_SUBVECTOR:
+    return bitcastInsertSubvector(MI, TypeIdx, CastTy);
   default:
     return UnableToLegalize;
   }

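Note (illustration, not from the patch): for the bitcast path above, take the documented example of inserting <vscale x 8 x i1> into <vscale x 16 x i1> with CastTy = <vscale x 2 x i8>. Then AdjustAmt = 8 / 1 = 8, so the insert index and all three element counts must be divisible by 8, and the rewritten insert works on i8 vectors at index Idx / 8. The widenScalar case in the first hunk handles i1 masks that cannot be bitcast this way: it zero-extends both vectors, inserts in the wide element type, and compares against a zero splat to truncate back to i1. The small helper below just restates the index arithmetic with those assumed numbers.

#include <cassert>
#include <cstdint>

// Illustrative arithmetic only (not LLVM code): the index adjustment
// bitcastInsertSubvector performs when retyping an i1 mask insert as i8.
uint64_t adjustMaskInsertIndex(uint64_t Idx) {
  const unsigned CastEltSize = 8;                      // i8 elements of CastTy
  const unsigned DstEltSize = 1;                       // i1 elements of DstTy
  const unsigned AdjustAmt = CastEltSize / DstEltSize; // 8
  assert(Idx % AdjustAmt == 0 && "index must land on an i8 element boundary");
  return Idx / AdjustAmt; // an i1 index of 8 becomes an i8 index of 1
}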
llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp

Lines changed: 138 additions & 4 deletions
@@ -615,6 +615,12 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
             all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
                 typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))));
 
+  getActionDefinitionsBuilder(G_INSERT_SUBVECTOR)
+      .customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
+                    typeIsLegalBoolVec(1, BoolVecTys, ST)))
+      .customIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
+                    typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
+
   getLegacyLegalizerInfo().computeTables();
 }
 
@@ -834,9 +840,7 @@ static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
 /// Gets the two common "VL" operands: an all-ones mask and the vector length.
 /// VecTy is a scalable vector type.
 static std::pair<MachineInstrBuilder, MachineInstrBuilder>
-buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
-                  MachineRegisterInfo &MRI) {
-  LLT VecTy = Dst.getLLTTy(MRI);
+buildDefaultVLOps(LLT VecTy, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
   assert(VecTy.isScalableVector() && "Expecting scalable container type");
   const RISCVSubtarget &STI = MIB.getMF().getSubtarget<RISCVSubtarget>();
   LLT XLenTy(STI.getXLenVT());
@@ -890,7 +894,7 @@ bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
   // Handle case of s64 element vectors on rv32
   if (XLenTy.getSizeInBits() == 32 &&
       VecTy.getElementType().getSizeInBits() == 64) {
-    auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
+    auto [_, VL] = buildDefaultVLOps(MRI.getType(Dst), MIB, MRI);
     buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
                              MRI);
     MI.eraseFromParent();
@@ -1025,6 +1029,134 @@ bool RISCVLegalizerInfo::legalizeExtractSubvector(MachineInstr &MI,
   return true;
 }
 
+bool RISCVLegalizerInfo::legalizeInsertSubvector(MachineInstr &MI,
+                                                 LegalizerHelper &Helper,
+                                                 MachineIRBuilder &MIB) const {
+  GInsertSubvector &IS = cast<GInsertSubvector>(MI);
+
+  MachineRegisterInfo &MRI = *MIB.getMRI();
+
+  Register Dst = IS.getReg(0);
+  Register BigVec = IS.getBigVec();
+  Register LitVec = IS.getSubVec();
+  uint64_t Idx = IS.getIndexImm();
+
+  LLT BigTy = MRI.getType(BigVec);
+  LLT LitTy = MRI.getType(LitVec);
+
+  if (Idx == 0 ||
+      MRI.getVRegDef(BigVec)->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)
+    return true;
+
+  // We don't have the ability to slide mask vectors up indexed by their i1
+  // elements; the smallest we can do is i8. Often we are able to bitcast to
+  // equivalent i8 vectors. Otherwise, we must zero extend to equivalent i8
+  // vectors and truncate down after the insert.
+  if (LitTy.getElementType() == LLT::scalar(1)) {
+    auto BigTyMinElts = BigTy.getElementCount().getKnownMinValue();
+    auto LitTyMinElts = LitTy.getElementCount().getKnownMinValue();
+    if (BigTyMinElts >= 8 && LitTyMinElts >= 8)
+      return Helper.bitcast(
+          IS, 0,
+          LLT::vector(BigTy.getElementCount().divideCoefficientBy(8), 8));
+
+    // We can't slide this mask vector up indexed by its i1 elements.
+    // This poses a problem when we wish to insert a scalable vector which
+    // can't be re-expressed as a larger type. Just choose the slow path and
+    // extend to a larger type, then truncate back down.
+    LLT ExtBigTy = BigTy.changeElementType(LLT::scalar(8));
+    return Helper.widenScalar(IS, 0, ExtBigTy);
+  }
+
+  const RISCVRegisterInfo *TRI = STI.getRegisterInfo();
+  unsigned SubRegIdx, RemIdx;
+  std::tie(SubRegIdx, RemIdx) =
+      RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+          getMVTForLLT(BigTy), getMVTForLLT(LitTy), Idx, TRI);
+
+  TypeSize VecRegSize = TypeSize::getScalable(RISCV::RVVBitsPerBlock);
+  assert(isPowerOf2_64(
+      STI.expandVScale(LitTy.getSizeInBits()).getKnownMinValue()));
+  bool ExactlyVecRegSized =
+      STI.expandVScale(LitTy.getSizeInBits())
+          .isKnownMultipleOf(STI.expandVScale(VecRegSize));
+
+  // If the Idx has been completely eliminated and this subvector's size is a
+  // vector register or a multiple thereof, or the surrounding elements are
+  // undef, then this is a subvector insert which naturally aligns to a vector
+  // register. These can easily be handled using subregister manipulation.
+  if (RemIdx == 0 && ExactlyVecRegSized)
+    return true;
+
+  // If the subvector is smaller than a vector register, then the insertion
+  // must preserve the undisturbed elements of the register. We do this by
+  // lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
+  // (which resolves to a subregister copy), performing a VSLIDEUP to place the
+  // subvector within the vector register, and an INSERT_SUBVECTOR of that
+  // LMUL=1 type back into the larger vector (resolving to another subregister
+  // operation). See below for how our VSLIDEUP works. We go via an LMUL=1 type
+  // to avoid allocating a large register group to hold our subvector.
+
+  // VSLIDEUP works by leaving elements 0<=i<OFFSET undisturbed, elements
+  // OFFSET<=i<VL set to the "subvector" and vl<=i<VLMAX set to the tail policy
+  // (in our case undisturbed). This means we can set up a subvector insertion
+  // where OFFSET is the insertion offset, and the VL is the OFFSET plus the
+  // size of the subvector.
+  const LLT XLenTy(STI.getXLenVT());
+  LLT InterLitTy = BigTy;
+  Register AlignedExtract = BigVec;
+  unsigned AlignedIdx = Idx - RemIdx;
+  if (TypeSize::isKnownGT(BigTy.getSizeInBits(),
+                          getLMUL1Ty(BigTy).getSizeInBits())) {
+    InterLitTy = getLMUL1Ty(BigTy);
+    // Extract a subvector equal to the nearest full vector register type. This
+    // should resolve to a G_EXTRACT on a subreg.
+    AlignedExtract =
+        MIB.buildExtractSubvector(InterLitTy, BigVec, AlignedIdx).getReg(0);
+  }
+
+  auto Insert = MIB.buildInsertSubvector(InterLitTy, MIB.buildUndef(InterLitTy),
+                                         LitVec, 0);
+
+  auto [Mask, _] = buildDefaultVLOps(BigTy, MIB, MRI);
+  auto VL = MIB.buildVScale(XLenTy, LitTy.getElementCount().getKnownMinValue());
+
+  // If we're inserting into the lowest elements, use a tail undisturbed
+  // vmv.v.v.
+  MachineInstrBuilder Inserted;
+  bool NeedInsertSubvec =
+      TypeSize::isKnownGT(BigTy.getSizeInBits(), InterLitTy.getSizeInBits());
+  Register InsertedDst =
+      NeedInsertSubvec ? MRI.createGenericVirtualRegister(InterLitTy) : Dst;
+  if (RemIdx == 0) {
+    Inserted = MIB.buildInstr(RISCV::G_VMV_V_V_VL, {InsertedDst},
+                              {AlignedExtract, Insert, VL});
+  } else {
+    auto SlideupAmt = MIB.buildVScale(XLenTy, RemIdx);
+    // Construct the vector length corresponding to RemIdx + length(LitTy).
+    VL = MIB.buildAdd(XLenTy, SlideupAmt, VL);
+    // Use tail agnostic policy if we're inserting over InterLitTy's tail.
+    ElementCount EndIndex =
+        ElementCount::getScalable(RemIdx) + LitTy.getElementCount();
+    uint64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
+    if (STI.expandVScale(EndIndex) ==
+        STI.expandVScale(InterLitTy.getElementCount()))
+      Policy = RISCVII::TAIL_AGNOSTIC;
+
+    Inserted =
+        MIB.buildInstr(RISCV::G_VSLIDEUP_VL, {InsertedDst},
+                       {AlignedExtract, Insert, SlideupAmt, Mask, VL, Policy});
+  }
+
+  // If required, insert this subvector back into the correct vector register.
+  // This should resolve to an INSERT_SUBREG instruction.
+  if (NeedInsertSubvec)
+    MIB.buildInsertSubvector(Dst, BigVec, Inserted, AlignedIdx);
+
+  MI.eraseFromParent();
+  return true;
+}
+
 bool RISCVLegalizerInfo::legalizeCustom(
     LegalizerHelper &Helper, MachineInstr &MI,
     LostDebugLocObserver &LocObserver) const {
@@ -1092,6 +1224,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
     return legalizeSplatVector(MI, MIRBuilder);
   case TargetOpcode::G_EXTRACT_SUBVECTOR:
     return legalizeExtractSubvector(MI, MIRBuilder);
+  case TargetOpcode::G_INSERT_SUBVECTOR:
+    return legalizeInsertSubvector(MI, Helper, MIRBuilder);
   case TargetOpcode::G_LOAD:
   case TargetOpcode::G_STORE:
     return legalizeLoadStore(MI, Helper, MIRBuilder);

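Note (illustration, not from the patch): when RemIdx is non-zero, the code above sets the slide-up VL to RemIdx plus the subvector's minimum element count (both scaled by vscale via G_VSCALE), and it may switch to a tail-agnostic policy only when that sum reaches the end of the LMUL=1 container InterLitTy. A small sketch with assumed element counts:

#include <cstdint>

// Sketch of the G_VSLIDEUP_VL operand bookkeeping in legalizeInsertSubvector.
// All values are "known minimum" element counts; the legalizer multiplies
// them by vscale at runtime. The example numbers below are assumed.
struct SlideupParams {
  uint64_t SlideAmt; // OFFSET: elements left undisturbed below the insert
  uint64_t VL;       // OFFSET + number of inserted elements
  bool TailAgnostic; // true when the insert reaches the container's tail
};

SlideupParams computeSlideupParams(uint64_t RemIdx, uint64_t LitMinElts,
                                   uint64_t InterLitMinElts) {
  SlideupParams P;
  P.SlideAmt = RemIdx;
  P.VL = RemIdx + LitMinElts;
  P.TailAgnostic = (RemIdx + LitMinElts == InterLitMinElts);
  return P;
}

// E.g. inserting a <vscale x 1 x i32> subvector at RemIdx == 1 into an LMUL=1
// <vscale x 2 x i32> container gives {SlideAmt = 1, VL = 2, TailAgnostic =
// true}, matching the TAIL_AGNOSTIC choice in the code above.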
llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {
   bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
   bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
   bool legalizeExtractSubvector(MachineInstr &MI, MachineIRBuilder &MIB) const;
+  bool legalizeInsertSubvector(MachineInstr &MI, LegalizerHelper &Helper,
+                               MachineIRBuilder &MIB) const;
   bool legalizeLoadStore(MachineInstr &MI, LegalizerHelper &Helper,
                          MachineIRBuilder &MIB) const;
 };

llvm/lib/Target/RISCV/RISCVInstrGISel.td

Lines changed: 17 additions & 0 deletions
@@ -67,3 +67,20 @@ def G_VSLIDEDOWN_VL : RISCVGenericInstruction {
 }
 def : GINodeEquiv<G_VSLIDEDOWN_VL, riscv_slidedown_vl>;
 
+// Pseudo equivalent to a RISCVISD::VMV_V_V_VL
+def G_VMV_V_V_VL : RISCVGenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$vec, type2:$vl);
+  let hasSideEffects = false;
+}
+def : GINodeEquiv<G_VMV_V_V_VL, riscv_vmv_v_v_vl>;
+
+// Pseudo equivalent to a RISCVISD::VSLIDEUP_VL
+def G_VSLIDEUP_VL : RISCVGenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type0:$merge, type0:$vec, type1:$idx, type2:$mask,
+                       type3:$vl, type4:$policy);
+  let hasSideEffects = false;
+}
+def : GINodeEquiv<G_VSLIDEUP_VL, riscv_slideup_vl>;
+

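Note (sketch, not from the patch): the operand order declared for G_VSLIDEUP_VL (merge, vec, idx, mask, vl, policy) matches the order the legalizer passes when it builds the instruction, and the GINodeEquiv lines let instruction selection reuse the existing riscv_slideup_vl / riscv_vmv_v_v_vl patterns for the new generic opcodes. A hedged call-site sketch, assuming the RISC-V backend's usual generated headers for the RISCV:: opcode enum and a hypothetical helper name:

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include <cstdint>

// Sketch only: registers and the policy immediate are assumed to exist.
llvm::MachineInstrBuilder buildSlideup(llvm::MachineIRBuilder &MIB,
                                       llvm::Register Dst, llvm::Register Merge,
                                       llvm::Register Vec, llvm::Register Idx,
                                       llvm::Register Mask, llvm::Register VL,
                                       uint64_t Policy) {
  // Operands follow the InOperandList declared above:
  // (merge, vec, idx, mask, vl, policy).
  return MIB.buildInstr(RISCV::G_VSLIDEUP_VL, {Dst},
                        {Merge, Vec, Idx, Mask, VL, Policy});
}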