Skip to content

Commit e54c162

Browse files
committed
Undo renames + put back std::optional
1 parent 519c90b commit e54c162

File tree

3 files changed

+49
-50
lines changed

3 files changed

+49
-50
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4235,8 +4235,6 @@ unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
42354235

42364236
/// Given two VL operands, do we know that LHS <= RHS?
42374237
bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
4238-
if (LHS.isImm() && LHS.getImm() == 0)
4239-
return true;
42404238
if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
42414239
LHS.getReg() == RHS.getReg())
42424240
return true;

llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ class RISCVVLOptimizer : public MachineFunctionPass {
5151
StringRef getPassName() const override { return PASS_NAME; }
5252

5353
private:
54-
MachineOperand getMinimumVLForUser(MachineOperand &UserOp);
55-
/// Computes the minimum demanded VL of \p MI, i.e. the minimum VL that's used
56-
/// by its users downstream.
57-
/// Returns 0 if MI has no users.
58-
MachineOperand computeDemandedVL(const MachineInstr &MI);
54+
std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp);
55+
/// Returns the largest common VL MachineOperand that may be used to optimize
56+
/// MI. Returns std::nullopt if it failed to find a suitable VL.
57+
std::optional<MachineOperand> checkUsers(MachineInstr &MI);
5958
bool tryReduceVL(MachineInstr &MI);
6059
bool isCandidate(const MachineInstr &MI) const;
6160

@@ -1179,14 +1178,15 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
11791178
return true;
11801179
}
11811180

1182-
MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1181+
std::optional<MachineOperand>
1182+
RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
11831183
const MachineInstr &UserMI = *UserOp.getParent();
11841184
const MCInstrDesc &Desc = UserMI.getDesc();
11851185

11861186
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
11871187
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
11881188
" use VLMAX\n");
1189-
return MachineOperand::CreateImm(RISCV::VLMaxSentinel);
1189+
return std::nullopt;
11901190
}
11911191

11921192
// Instructions like reductions may use a vector register as a scalar
@@ -1223,40 +1223,39 @@ MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
12231223
return VLOp;
12241224
}
12251225

1226-
MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
1227-
const MachineOperand &VLMAX = MachineOperand::CreateImm(RISCV::VLMaxSentinel);
1228-
MachineOperand DemandedVL = MachineOperand::CreateImm(0);
1229-
1226+
std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1227+
std::optional<MachineOperand> CommonVL;
12301228
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) {
12311229
const MachineInstr &UserMI = *UserOp.getParent();
12321230
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
12331231
if (mayReadPastVL(UserMI)) {
12341232
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
1235-
return VLMAX;
1233+
return std::nullopt;
12361234
}
12371235

12381236
// If used as a passthru, elements past VL will be read.
12391237
if (UserOp.isTied()) {
12401238
LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n");
1241-
return VLMAX;
1239+
return std::nullopt;
12421240
}
12431241

1244-
const MachineOperand &VLOp = getMinimumVLForUser(UserOp);
1245-
1246-
// The minimum demanded VL is the largest VL read amongst all the users. If
1247-
// we cannot determine this statically, then we cannot optimize the VL.
1248-
if (RISCV::isVLKnownLE(DemandedVL, VLOp)) {
1249-
DemandedVL = VLOp;
1250-
LLVM_DEBUG(dbgs() << " Demanded VL is: " << VLOp << "\n");
1251-
} else if (!RISCV::isVLKnownLE(VLOp, DemandedVL)) {
1252-
LLVM_DEBUG(
1253-
dbgs() << " Abort because cannot determine the demanded VL\n");
1254-
return VLMAX;
1242+
auto VLOp = getMinimumVLForUser(UserOp);
1243+
if (!VLOp)
1244+
return std::nullopt;
1245+
1246+
// Use the largest VL among all the users. If we cannot determine this
1247+
// statically, then we cannot optimize the VL.
1248+
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
1249+
CommonVL = *VLOp;
1250+
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
1251+
} else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
1252+
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
1253+
return std::nullopt;
12551254
}
12561255

12571256
if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
12581257
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
1259-
return VLMAX;
1258+
return std::nullopt;
12601259
}
12611260

12621261
std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
@@ -1266,7 +1265,7 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
12661265
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
12671266
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
12681267
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1269-
return VLMAX;
1268+
return std::nullopt;
12701269
}
12711270

12721271
// If the operand is used as a scalar operand, then the EEW must be
@@ -1281,52 +1280,53 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
12811280
<< " Abort due to incompatible information for EMUL or EEW.\n");
12821281
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
12831282
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
1284-
return VLMAX;
1283+
return std::nullopt;
12851284
}
12861285
}
12871286

1288-
return DemandedVL;
1287+
return CommonVL;
12891288
}
12901289

12911290
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
12921291
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");
12931292

1294-
const MachineOperand &DemandedVL = DemandedVLs.at(&MI);
1293+
if (!DemandedVLs.contains(&MI))
1294+
return false;
1295+
auto CommonVL = std::make_optional(DemandedVLs.at(&MI));
12951296

1296-
assert((DemandedVL.isImm() || DemandedVL.getReg().isVirtual()) &&
1297+
assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
12971298
"Expected VL to be an Imm or virtual Reg");
12981299

12991300
unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
13001301
MachineOperand &VLOp = MI.getOperand(VLOpNum);
13011302

1302-
if (!RISCV::isVLKnownLE(DemandedVL, VLOp)) {
1303-
LLVM_DEBUG(dbgs() << " Abort due to DemandedVL not <= VLOp.\n");
1303+
if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
1304+
LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
13041305
return false;
13051306
}
13061307

1307-
if (DemandedVL.isIdenticalTo(VLOp)) {
1308+
if (CommonVL->isIdenticalTo(VLOp)) {
13081309
LLVM_DEBUG(
1309-
dbgs()
1310-
<< " Abort due to DemandedVL == VLOp, no point in reducing.\n");
1310+
dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
13111311
return false;
13121312
}
13131313

1314-
if (DemandedVL.isImm()) {
1314+
if (CommonVL->isImm()) {
13151315
LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
1316-
<< DemandedVL.getImm() << " for " << MI << "\n");
1317-
VLOp.ChangeToImmediate(DemandedVL.getImm());
1316+
<< CommonVL->getImm() << " for " << MI << "\n");
1317+
VLOp.ChangeToImmediate(CommonVL->getImm());
13181318
return true;
13191319
}
1320-
const MachineInstr *VLMI = MRI->getVRegDef(DemandedVL.getReg());
1320+
const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
13211321
if (!MDT->dominates(VLMI, &MI))
13221322
return false;
13231323
LLVM_DEBUG(
13241324
dbgs() << " Reduce VL from " << VLOp << " to "
1325-
<< printReg(DemandedVL.getReg(), MRI->getTargetRegisterInfo())
1325+
<< printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
13261326
<< " for " << MI << "\n");
13271327

13281328
// All our checks passed. We can reduce VL.
1329-
VLOp.ChangeToRegister(DemandedVL.getReg(), false);
1329+
VLOp.ChangeToRegister(CommonVL->getReg(), false);
13301330
return true;
13311331
}
13321332

@@ -1351,10 +1351,11 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
13511351

13521352
// For each instruction that defines a vector, compute what VL its
13531353
// downstream users demand.
1354-
for (const auto &MI : reverse(MBB)) {
1354+
for (MachineInstr &MI : reverse(MBB)) {
13551355
if (!isCandidate(MI))
13561356
continue;
1357-
DemandedVLs.insert({&MI, computeDemandedVL(MI)});
1357+
if (auto DemandedVL = checkUsers(MI))
1358+
DemandedVLs.insert({&MI, *DemandedVL});
13581359
}
13591360

13601361
// Then go through and see if we can reduce the VL of any instructions to

llvm/test/CodeGen/RISCV/rvv/vlopt-same-vl.ll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@
55

66
; GitHub Issue #123862 provided a case where the riscv-vl-optimizer pass was
77
; very slow. It was found that that case benefited greatly from aborting due
8-
; to DemandedVL == VLOp. Adding the case provided in the issue would show up
8+
; to CommonVL == VLOp. Adding the case provided in the issue would show up
99
; as a long-running test instead of a test failure. We would likely have a hard
1010
; time figuring out whether that case had regressed. So instead, we check this output
1111
; which was responsible for speeding it up.
1212

1313
define <vscale x 4 x i32> @same_vl_imm(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
14-
; CHECK: Demanded VL is: 4
15-
; CHECK: Abort due to DemandedVL == VLOp, no point in reducing.
14+
; CHECK: User VL is: 4
15+
; CHECK: Abort due to CommonVL == VLOp, no point in reducing.
1616
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, i64 4)
1717
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, i64 4)
1818
ret <vscale x 4 x i32> %w
1919
}
2020

2121
define <vscale x 4 x i32> @same_vl_reg(<vscale x 4 x i32> %passthru, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, i64 %vl) {
22-
; CHECK: Demanded VL is: %3:gprnox0
23-
; CHECK: Abort due to DemandedVL == VLOp, no point in reducing.
22+
; CHECK: User VL is: %3:gprnox0
23+
; CHECK: Abort due to CommonVL == VLOp, no point in reducing.
2424
%v = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %a, <vscale x 4 x i32> %b, i64 %vl)
2525
%w = call <vscale x 4 x i32> @llvm.riscv.vadd.nxv4i32.nxv4i32(<vscale x 4 x i32> poison, <vscale x 4 x i32> %v, <vscale x 4 x i32> %a, i64 %vl)
2626
ret <vscale x 4 x i32> %w

0 commit comments

Comments
 (0)