Skip to content

Commit c95ec85

Browse files
committed
[RISCV][VLOPT] Compute demanded VLs up front. NFC
This replaces the worklist by instead computing what VL is demanded by each instruction's users first. checkUsers essentially already did this, so it's been renamed to computeDemandedVL. The demanded VLs are stored in a DenseMap, and then we can just do a single forward pass of tryReduceVL where we check if a candidate's demanded VL is less than its VLOp. This means the pass should now be in linear complexity, and allows us to relax the restriction on tied operands in more easily as in llvm#124066. Note that in order to avoid std::optional inside the DenseMap, I've removed the std::optionals and replaced them with VLMAX or 0 constant operands.
1 parent e0781aa commit c95ec85

File tree

5 files changed

+376
-77
lines changed

5 files changed

+376
-77
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4235,6 +4235,8 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
42354235

42364236
/// Given two VL operands, do we know that LHS <= RHS?
42374237
bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
4238+
if (LHS.isImm() && LHS.getImm() == 0)
4239+
return true;
42384240
if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
42394241
LHS.getReg() == RHS.getReg())
42404242
return true;

llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp

Lines changed: 69 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace {
3333
class RISCVVLOptimizer : public MachineFunctionPass {
3434
const MachineRegisterInfo *MRI;
3535
const MachineDominatorTree *MDT;
36+
const TargetInstrInfo *TII;
3637

3738
public:
3839
static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5051
StringRef getPassName() const override { return PASS_NAME; }
5152

5253
private:
53-
std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp);
54-
/// Returns the largest common VL MachineOperand that may be used to optimize
55-
/// MI. Returns std::nullopt if it failed to find a suitable VL.
56-
std::optional<MachineOperand> checkUsers(MachineInstr &MI);
54+
MachineOperand getMinimumVLForUser(MachineOperand &UserOp);
55+
/// Computes the VL of \p MI that is actually used by its users.
56+
MachineOperand computeDemandedVL(const MachineInstr &MI);
5757
bool tryReduceVL(MachineInstr &MI);
5858
bool isCandidate(const MachineInstr &MI) const;
59+
60+
/// For a given instruction, records what elements of it are demanded by
61+
/// downstream users.
62+
DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
5963
};
6064

6165
} // end anonymous namespace
@@ -1173,15 +1177,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
11731177
return true;
11741178
}
11751179

1176-
std::optional<MachineOperand>
1177-
RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1180+
MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
11781181
const MachineInstr &UserMI = *UserOp.getParent();
11791182
const MCInstrDesc &Desc = UserMI.getDesc();
11801183

11811184
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
11821185
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
11831186
" use VLMAX\n");
1184-
return std::nullopt;
1187+
return MachineOperand::CreateImm(RISCV::VLMaxSentinel);
11851188
}
11861189

11871190
// Instructions like reductions may use a vector register as a scalar
@@ -1201,46 +1204,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12011204
// Looking for an immediate or a register VL that isn't X0.
12021205
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
12031206
"Did not expect X0 VL");
1207+
1208+
// If we know the demanded VL of UserMI, then we can reduce the VL it
1209+
// requires.
1210+
if (DemandedVLs.contains(&UserMI)) {
1211+
// We can only shrink the demanded VL if the elementwise result doesn't
1212+
// depend on VL (i.e. not vredsum/viota etc.)
1213+
// Also conservatively restrict to supported instructions for now.
1214+
// TODO: Can we remove the isSupportedInstr check?
1215+
if (!RISCVII::elementsDependOnVL(
1216+
TII->get(RISCV::getRVVMCOpcode(UserMI.getOpcode())).TSFlags) &&
1217+
isSupportedInstr(UserMI)) {
1218+
const MachineOperand &DemandedVL = DemandedVLs.at(&UserMI);
1219+
if (RISCV::isVLKnownLE(DemandedVL, VLOp))
1220+
return DemandedVL;
1221+
}
1222+
}
1223+
12041224
return VLOp;
12051225
}
12061226

1207-
std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1208-
// FIXME: Avoid visiting each user for each time we visit something on the
1209-
// worklist, combined with an extra visit from the outer loop. Restructure
1210-
// along lines of an instcombine style worklist which integrates the outer
1211-
// pass.
1212-
std::optional<MachineOperand> CommonVL;
1227+
MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
1228+
const MachineOperand &VLMAX = MachineOperand::CreateImm(RISCV::VLMaxSentinel);
1229+
MachineOperand DemandedVL = MachineOperand::CreateImm(0);
1230+
12131231
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) {
12141232
const MachineInstr &UserMI = *UserOp.getParent();
12151233
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
12161234
if (mayReadPastVL(UserMI)) {
12171235
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
1218-
return std::nullopt;
1236+
return VLMAX;
12191237
}
12201238

12211239
// If used as a passthru, elements past VL will be read.
12221240
if (UserOp.isTied()) {
12231241
LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n");
1224-
return std::nullopt;
1242+
return VLMAX;
12251243
}
12261244

1227-
auto VLOp = getMinimumVLForUser(UserOp);
1228-
if (!VLOp)
1229-
return std::nullopt;
1245+
const MachineOperand &VLOp = getMinimumVLForUser(UserOp);
12301246

12311247
// Use the largest VL among all the users. If we cannot determine this
12321248
// statically, then we cannot optimize the VL.
1233-
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
1234-
CommonVL = *VLOp;
1235-
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
1236-
} else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
1249+
if (RISCV::isVLKnownLE(DemandedVL, VLOp)) {
1250+
DemandedVL = VLOp;
1251+
LLVM_DEBUG(dbgs() << " Demanded VL is: " << VLOp << "\n");
1252+
} else if (!RISCV::isVLKnownLE(VLOp, DemandedVL)) {
12371253
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
1238-
return std::nullopt;
1254+
return VLMAX;
12391255
}
12401256

12411257
if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
12421258
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
1243-
return std::nullopt;
1259+
return VLMAX;
12441260
}
12451261

12461262
std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
@@ -1250,7 +1266,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12501266
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
12511267
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
12521268
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1253-
return std::nullopt;
1269+
return VLMAX;
12541270
}
12551271

12561272
// If the operand is used as a scalar operand, then the EEW must be
@@ -1265,11 +1281,11 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12651281
<< " Abort due to incompatible information for EMUL or EEW.\n");
12661282
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
12671283
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1268-
return std::nullopt;
1284+
return VLMAX;
12691285
}
12701286
}
12711287

1272-
return CommonVL;
1288+
return DemandedVL;
12731289
}
12741290

12751291
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
@@ -1285,40 +1301,40 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
12851301
return false;
12861302
}
12871303

1288-
auto CommonVL = checkUsers(MI);
1304+
auto CommonVL = DemandedVLs[&MI];
12891305
if (!CommonVL)
12901306
return false;
12911307

1292-
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
1308+
assert((CommonVL.isImm() || CommonVL.getReg().isVirtual()) &&
12931309
"Expected VL to be an Imm or virtual Reg");
12941310

12951311
if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
12961312
LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
12971313
return false;
12981314
}
12991315

1300-
if (CommonVL->isIdenticalTo(VLOp)) {
1316+
if (CommonVL.isIdenticalTo(VLOp)) {
13011317
LLVM_DEBUG(
1302-
dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
1318+
dbgs()
1319+
<< " Abort due to DemandedVL == VLOp, no point in reducing.\n");
13031320
return false;
13041321
}
13051322

1306-
if (CommonVL->isImm()) {
1323+
if (CommonVL.isImm()) {
13071324
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
1308-
<< CommonVL->getImm() << " for " << MI << "\n");
1309-
VLOp.ChangeToImmediate(CommonVL->getImm());
1325+
<< CommonVL.getImm() << " for " << MI << "\n");
1326+
VLOp.ChangeToImmediate(CommonVL.getImm());
13101327
return true;
13111328
}
1312-
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
1329+
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL.getReg());
13131330
if (!MDT->dominates(VLMI, &MI))
13141331
return false;
1315-
LLVM_DEBUG(
1316-
dbgs() << " Reduce VL from " << VLOp << " to "
1317-
<< printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
1318-
<< " for " << MI << "\n");
1332+
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
1333+
<< printReg(CommonVL.getReg(), MRI->getTargetRegisterInfo())
1334+
<< " for " << MI << "\n");
13191335

13201336
// All our checks passed. We can reduce VL.
1321-
VLOp.ChangeToRegister(CommonVL->getReg(), false);
1337+
VLOp.ChangeToRegister(CommonVL.getReg(), false);
13221338
return true;
13231339
}
13241340

@@ -1333,52 +1349,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
13331349
if (!ST.hasVInstructions())
13341350
return false;
13351351

1336-
SetVector<MachineInstr *> Worklist;
1337-
auto PushOperands = [this, &Worklist](MachineInstr &MI,
1338-
bool IgnoreSameBlock) {
1339-
for (auto &Op : MI.operands()) {
1340-
if (!Op.isReg() || !Op.isUse() || !Op.getReg().isVirtual() ||
1341-
!isVectorRegClass(Op.getReg(), MRI))
1342-
continue;
1352+
TII = ST.getInstrInfo();
13431353

1344-
MachineInstr *DefMI = MRI->getVRegDef(Op.getReg());
1345-
if (!isCandidate(*DefMI))
1346-
continue;
1347-
1348-
if (IgnoreSameBlock && DefMI->getParent() == MI.getParent())
1349-
continue;
1350-
1351-
Worklist.insert(DefMI);
1352-
}
1353-
};
1354-
1355-
// Do a first pass eagerly rewriting in roughly reverse instruction
1356-
// order, populate the worklist with any instructions we might need to
1357-
// revisit. We avoid adding definitions to the worklist if they're
1358-
// in the same block - we're about to visit them anyways.
13591354
bool MadeChange = false;
13601355
for (MachineBasicBlock &MBB : MF) {
13611356
// Avoid unreachable blocks as they have degenerate dominance
13621357
if (!MDT->isReachableFromEntry(&MBB))
13631358
continue;
13641359

1365-
for (auto &MI : reverse(MBB)) {
1360+
// For each instruction that defines a vector, compute what VL its
1361+
// downstream users demand.
1362+
for (const auto &MI : reverse(MBB)) {
1363+
if (!isCandidate(MI))
1364+
continue;
1365+
DemandedVLs.insert({&MI, computeDemandedVL(MI)});
1366+
}
1367+
1368+
// Then go through and see if we can reduce the VL of any instructions to
1369+
// only what's demanded.
1370+
for (auto &MI : MBB) {
13661371
if (!isCandidate(MI))
13671372
continue;
13681373
if (!tryReduceVL(MI))
13691374
continue;
13701375
MadeChange = true;
1371-
PushOperands(MI, /*IgnoreSameBlock*/ true);
13721376
}
1373-
}
13741377

1375-
while (!Worklist.empty()) {
1376-
assert(MadeChange);
1377-
MachineInstr &MI = *Worklist.pop_back_val();
1378-
assert(isCandidate(MI));
1379-
if (!tryReduceVL(MI))
1380-
continue;
1381-
PushOperands(MI, /*IgnoreSameBlock*/ false);
1378+
DemandedVLs.clear();
13821379
}
13831380

13841381
return MadeChange;

0 commit comments

Comments
 (0)