Skip to content

Commit 169a45a

Browse files
committed
[RISCV][VLOPT] Compute demanded VLs up front. NFC
This replaces the worklist by first computing what VL is demanded by each instruction's users. checkUsers essentially already did this, so it has been renamed to computeDemandedVL. The demanded VLs are stored in a DenseMap, and then we can just do a single forward pass of tryReduceVL where we check if a candidate's demanded VL is less than its VLOp. This means the pass should now have linear complexity, and it allows us to relax the restriction on tied operands more easily, as in #124066. Note that in order to avoid std::optional inside the DenseMap, I've removed the std::optionals and replaced them with VLMAX or 0 constant operands.
1 parent b2647ff commit 169a45a

File tree

5 files changed

+380
-85
lines changed

5 files changed

+380
-85
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4232,6 +4232,8 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
42324232

42334233
/// Given two VL operands, do we know that LHS <= RHS?
42344234
bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
4235+
if (LHS.isImm() && LHS.getImm() == 0)
4236+
return true;
42354237
if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
42364238
LHS.getReg() == RHS.getReg())
42374239
return true;

llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp

Lines changed: 71 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace {
3333
class RISCVVLOptimizer : public MachineFunctionPass {
3434
const MachineRegisterInfo *MRI;
3535
const MachineDominatorTree *MDT;
36+
const TargetInstrInfo *TII;
3637

3738
public:
3839
static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5051
StringRef getPassName() const override { return PASS_NAME; }
5152

5253
private:
53-
std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp);
54-
/// Returns the largest common VL MachineOperand that may be used to optimize
55-
/// MI. Returns std::nullopt if it failed to find a suitable VL.
56-
std::optional<MachineOperand> checkUsers(MachineInstr &MI);
54+
MachineOperand getMinimumVLForUser(MachineOperand &UserOp);
55+
/// Computes the VL of \p MI that is actually used by its users.
56+
MachineOperand computeDemandedVL(const MachineInstr &MI);
5757
bool tryReduceVL(MachineInstr &MI);
5858
bool isCandidate(const MachineInstr &MI) const;
59+
60+
/// For a given instruction, records what elements of it are demanded by
61+
/// downstream users.
62+
DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
5963
};
6064

6165
} // end anonymous namespace
@@ -1202,15 +1206,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
12021206
return true;
12031207
}
12041208

1205-
std::optional<MachineOperand>
1206-
RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1209+
MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12071210
const MachineInstr &UserMI = *UserOp.getParent();
12081211
const MCInstrDesc &Desc = UserMI.getDesc();
12091212

12101213
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
12111214
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
12121215
" use VLMAX\n");
1213-
return std::nullopt;
1216+
return MachineOperand::CreateImm(RISCV::VLMaxSentinel);
12141217
}
12151218

12161219
// Instructions like reductions may use a vector register as a scalar
@@ -1230,46 +1233,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12301233
// Looking for an immediate or a register VL that isn't X0.
12311234
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
12321235
"Did not expect X0 VL");
1236+
1237+
// If we know the demanded VL of UserMI, then we can reduce the VL it
1238+
// requires.
1239+
if (DemandedVLs.contains(&UserMI)) {
1240+
// We can only shrink the demanded VL if the elementwise result doesn't
1241+
// depend on VL (i.e. not vredsum/viota etc.)
1242+
// Also conservatively restrict to supported instructions for now.
1243+
// TODO: Can we remove the isSupportedInstr check?
1244+
if (!RISCVII::elementsDependOnVL(
1245+
TII->get(RISCV::getRVVMCOpcode(UserMI.getOpcode())).TSFlags) &&
1246+
isSupportedInstr(UserMI)) {
1247+
const MachineOperand &DemandedVL = DemandedVLs.at(&UserMI);
1248+
if (RISCV::isVLKnownLE(DemandedVL, VLOp))
1249+
return DemandedVL;
1250+
}
1251+
}
1252+
12331253
return VLOp;
12341254
}
12351255

1236-
std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1237-
// FIXME: Avoid visiting each user for each time we visit something on the
1238-
// worklist, combined with an extra visit from the outer loop. Restructure
1239-
// along lines of an instcombine style worklist which integrates the outer
1240-
// pass.
1241-
std::optional<MachineOperand> CommonVL;
1256+
MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
1257+
const MachineOperand &VLMAX = MachineOperand::CreateImm(RISCV::VLMaxSentinel);
1258+
MachineOperand DemandedVL = MachineOperand::CreateImm(0);
1259+
12421260
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) {
12431261
const MachineInstr &UserMI = *UserOp.getParent();
12441262
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
12451263
if (mayReadPastVL(UserMI)) {
12461264
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
1247-
return std::nullopt;
1265+
return VLMAX;
12481266
}
12491267

12501268
// Tied operands might pass through.
12511269
if (UserOp.isTied()) {
12521270
LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n");
1253-
return std::nullopt;
1271+
return VLMAX;
12541272
}
12551273

1256-
auto VLOp = getMinimumVLForUser(UserOp);
1257-
if (!VLOp)
1258-
return std::nullopt;
1274+
const MachineOperand &VLOp = getMinimumVLForUser(UserOp);
12591275

12601276
// Use the largest VL among all the users. If we cannot determine this
12611277
// statically, then we cannot optimize the VL.
1262-
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
1263-
CommonVL = *VLOp;
1264-
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
1265-
} else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
1278+
if (RISCV::isVLKnownLE(DemandedVL, VLOp)) {
1279+
DemandedVL = VLOp;
1280+
LLVM_DEBUG(dbgs() << " Demanded VL is: " << VLOp << "\n");
1281+
} else if (!RISCV::isVLKnownLE(VLOp, DemandedVL)) {
12661282
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
1267-
return std::nullopt;
1283+
return VLMAX;
12681284
}
12691285

12701286
if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
12711287
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
1272-
return std::nullopt;
1288+
return VLMAX;
12731289
}
12741290

12751291
std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
@@ -1279,7 +1295,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12791295
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
12801296
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
12811297
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1282-
return std::nullopt;
1298+
return VLMAX;
12831299
}
12841300

12851301
// If the operand is used as a scalar operand, then the EEW must be
@@ -1294,53 +1310,51 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
12941310
<< " Abort due to incompatible information for EMUL or EEW.\n");
12951311
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
12961312
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1297-
return std::nullopt;
1313+
return VLMAX;
12981314
}
12991315
}
13001316

1301-
return CommonVL;
1317+
return DemandedVL;
13021318
}
13031319

13041320
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
13051321
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
13061322

1307-
auto CommonVL = checkUsers(MI);
1308-
if (!CommonVL)
1309-
return false;
1323+
const MachineOperand &CommonVL = DemandedVLs.at(&MI);
13101324

1311-
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
1325+
assert((CommonVL.isImm() || CommonVL.getReg().isVirtual()) &&
13121326
"Expected VL to be an Imm or virtual Reg");
13131327

13141328
unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
13151329
MachineOperand &VLOp = MI.getOperand(VLOpNum);
13161330

1317-
if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
1318-
LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
1331+
if (!RISCV::isVLKnownLE(CommonVL, VLOp)) {
1332+
LLVM_DEBUG(dbgs() << " Abort due to DemandedVL not <= VLOp.\n");
13191333
return false;
13201334
}
13211335

1322-
if (CommonVL->isIdenticalTo(VLOp)) {
1336+
if (CommonVL.isIdenticalTo(VLOp)) {
13231337
LLVM_DEBUG(
1324-
dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
1338+
dbgs()
1339+
<< " Abort due to DemandedVL == VLOp, no point in reducing.\n");
13251340
return false;
13261341
}
13271342

1328-
if (CommonVL->isImm()) {
1343+
if (CommonVL.isImm()) {
13291344
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
1330-
<< CommonVL->getImm() << " for " << MI << "\n");
1331-
VLOp.ChangeToImmediate(CommonVL->getImm());
1345+
<< CommonVL.getImm() << " for " << MI << "\n");
1346+
VLOp.ChangeToImmediate(CommonVL.getImm());
13321347
return true;
13331348
}
1334-
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
1349+
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL.getReg());
13351350
if (!MDT->dominates(VLMI, &MI))
13361351
return false;
1337-
LLVM_DEBUG(
1338-
dbgs() << " Reduce VL from " << VLOp << " to "
1339-
<< printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
1340-
<< " for " << MI << "\n");
1352+
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
1353+
<< printReg(CommonVL.getReg(), MRI->getTargetRegisterInfo())
1354+
<< " for " << MI << "\n");
13411355

13421356
// All our checks passed. We can reduce VL.
1343-
VLOp.ChangeToRegister(CommonVL->getReg(), false);
1357+
VLOp.ChangeToRegister(CommonVL.getReg(), false);
13441358
return true;
13451359
}
13461360

@@ -1355,52 +1369,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
13551369
if (!ST.hasVInstructions())
13561370
return false;
13571371

1358-
SetVector<MachineInstr *> Worklist;
1359-
auto PushOperands = [this, &Worklist](MachineInstr &MI,
1360-
bool IgnoreSameBlock) {
1361-
for (auto &Op : MI.operands()) {
1362-
if (!Op.isReg() || !Op.isUse() || !Op.getReg().isVirtual() ||
1363-
!isVectorRegClass(Op.getReg(), MRI))
1364-
continue;
1365-
1366-
MachineInstr *DefMI = MRI->getVRegDef(Op.getReg());
1367-
if (!isCandidate(*DefMI))
1368-
continue;
1369-
1370-
if (IgnoreSameBlock && DefMI->getParent() == MI.getParent())
1371-
continue;
1372-
1373-
Worklist.insert(DefMI);
1374-
}
1375-
};
1372+
TII = ST.getInstrInfo();
13761373

1377-
// Do a first pass eagerly rewriting in roughly reverse instruction
1378-
// order, populate the worklist with any instructions we might need to
1379-
// revisit. We avoid adding definitions to the worklist if they're
1380-
// in the same block - we're about to visit them anyways.
13811374
bool MadeChange = false;
13821375
for (MachineBasicBlock &MBB : MF) {
13831376
// Avoid unreachable blocks as they have degenerate dominance
13841377
if (!MDT->isReachableFromEntry(&MBB))
13851378
continue;
13861379

1387-
for (auto &MI : make_range(MBB.rbegin(), MBB.rend())) {
1380+
// For each instruction that defines a vector, compute what VL its
1381+
// downstream users demand.
1382+
for (const auto &MI : reverse(MBB)) {
1383+
if (!isCandidate(MI))
1384+
continue;
1385+
DemandedVLs.insert({&MI, computeDemandedVL(MI)});
1386+
}
1387+
1388+
// Then go through and see if we can reduce the VL of any instructions to
1389+
// only what's demanded.
1390+
for (auto &MI : MBB) {
13881391
if (!isCandidate(MI))
13891392
continue;
13901393
if (!tryReduceVL(MI))
13911394
continue;
13921395
MadeChange = true;
1393-
PushOperands(MI, /*IgnoreSameBlock*/ true);
13941396
}
1395-
}
13961397

1397-
while (!Worklist.empty()) {
1398-
assert(MadeChange);
1399-
MachineInstr &MI = *Worklist.pop_back_val();
1400-
assert(isCandidate(MI));
1401-
if (!tryReduceVL(MI))
1402-
continue;
1403-
PushOperands(MI, /*IgnoreSameBlock*/ false);
1398+
DemandedVLs.clear();
14041399
}
14051400

14061401
return MadeChange;

0 commit comments

Comments
 (0)