Skip to content

Commit 4ae0a13

Browse files
committed
[InstCombine] add assert in SimplifyDemandedVectorElts and improve readability; NFC
1 parent d4e006e commit 4ae0a13

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,28 +1247,31 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
12471247
break;
12481248
}
12491249
case Instruction::ShuffleVector: {
1250-
ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
1251-
unsigned LHSVWidth =
1252-
Shuffle->getOperand(0)->getType()->getVectorNumElements();
1253-
APInt LeftDemanded(LHSVWidth, 0), RightDemanded(LHSVWidth, 0);
1250+
auto *Shuffle = cast<ShuffleVectorInst>(I);
1251+
assert(Shuffle->getOperand(0)->getType() ==
1252+
Shuffle->getOperand(1)->getType() &&
1253+
"Expected shuffle operands to have same type");
1254+
unsigned OpWidth =
1255+
Shuffle->getOperand(0)->getType()->getVectorNumElements();
1256+
APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
12541257
for (unsigned i = 0; i < VWidth; i++) {
12551258
if (DemandedElts[i]) {
12561259
unsigned MaskVal = Shuffle->getMaskValue(i);
12571260
if (MaskVal != -1u) {
1258-
assert(MaskVal < LHSVWidth * 2 &&
1261+
assert(MaskVal < OpWidth * 2 &&
12591262
"shufflevector mask index out of range!");
1260-
if (MaskVal < LHSVWidth)
1263+
if (MaskVal < OpWidth)
12611264
LeftDemanded.setBit(MaskVal);
12621265
else
1263-
RightDemanded.setBit(MaskVal - LHSVWidth);
1266+
RightDemanded.setBit(MaskVal - OpWidth);
12641267
}
12651268
}
12661269
}
12671270

1268-
APInt LHSUndefElts(LHSVWidth, 0);
1271+
APInt LHSUndefElts(OpWidth, 0);
12691272
simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);
12701273

1271-
APInt RHSUndefElts(LHSVWidth, 0);
1274+
APInt RHSUndefElts(OpWidth, 0);
12721275
simplifyAndSetOp(I, 1, RightDemanded, RHSUndefElts);
12731276

12741277
bool NewUndefElts = false;
@@ -1283,23 +1286,23 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
12831286
} else if (!DemandedElts[i]) {
12841287
NewUndefElts = true;
12851288
UndefElts.setBit(i);
1286-
} else if (MaskVal < LHSVWidth) {
1289+
} else if (MaskVal < OpWidth) {
12871290
if (LHSUndefElts[MaskVal]) {
12881291
NewUndefElts = true;
12891292
UndefElts.setBit(i);
12901293
} else {
1291-
LHSIdx = LHSIdx == -1u ? i : LHSVWidth;
1292-
LHSValIdx = LHSValIdx == -1u ? MaskVal : LHSVWidth;
1294+
LHSIdx = LHSIdx == -1u ? i : OpWidth;
1295+
LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth;
12931296
LHSUniform = LHSUniform && (MaskVal == i);
12941297
}
12951298
} else {
1296-
if (RHSUndefElts[MaskVal - LHSVWidth]) {
1299+
if (RHSUndefElts[MaskVal - OpWidth]) {
12971300
NewUndefElts = true;
12981301
UndefElts.setBit(i);
12991302
} else {
1300-
RHSIdx = RHSIdx == -1u ? i : LHSVWidth;
1301-
RHSValIdx = RHSValIdx == -1u ? MaskVal - LHSVWidth : LHSVWidth;
1302-
RHSUniform = RHSUniform && (MaskVal - LHSVWidth == i);
1303+
RHSIdx = RHSIdx == -1u ? i : OpWidth;
1304+
RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth;
1305+
RHSUniform = RHSUniform && (MaskVal - OpWidth == i);
13031306
}
13041307
}
13051308
}
@@ -1308,20 +1311,20 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
13081311
// this constant vector to single insertelement instruction.
13091312
// shufflevector V, C, <v1, v2, .., ci, .., vm> ->
13101313
// insertelement V, C[ci], ci-n
1311-
if (LHSVWidth == Shuffle->getType()->getNumElements()) {
1314+
if (OpWidth == Shuffle->getType()->getNumElements()) {
13121315
Value *Op = nullptr;
13131316
Constant *Value = nullptr;
13141317
unsigned Idx = -1u;
13151318

13161319
// Find constant vector with the single element in shuffle (LHS or RHS).
1317-
if (LHSIdx < LHSVWidth && RHSUniform) {
1320+
if (LHSIdx < OpWidth && RHSUniform) {
13181321
if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) {
13191322
Op = Shuffle->getOperand(1);
13201323
Value = CV->getOperand(LHSValIdx);
13211324
Idx = LHSIdx;
13221325
}
13231326
}
1324-
if (RHSIdx < LHSVWidth && LHSUniform) {
1327+
if (RHSIdx < OpWidth && LHSUniform) {
13251328
if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) {
13261329
Op = Shuffle->getOperand(0);
13271330
Value = CV->getOperand(RHSValIdx);

0 commit comments

Comments
 (0)