@@ -268,43 +268,60 @@ struct CastedValue {
268
268
unsigned ZExtBits = 0 ;
269
269
unsigned SExtBits = 0 ;
270
270
unsigned TruncBits = 0 ;
271
+ // / Whether trunc(V) is non-negative.
272
+ bool IsNonNegative = false ;
271
273
272
274
explicit CastedValue (const Value *V) : V(V) {}
273
275
explicit CastedValue (const Value *V, unsigned ZExtBits, unsigned SExtBits,
274
- unsigned TruncBits)
275
- : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
276
+ unsigned TruncBits, bool IsNonNegative)
277
+ : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits),
278
+ IsNonNegative(IsNonNegative) {}
276
279
277
280
unsigned getBitWidth () const {
278
281
return V->getType ()->getPrimitiveSizeInBits () - TruncBits + ZExtBits +
279
282
SExtBits;
280
283
}
281
284
282
- CastedValue withValue (const Value *NewV) const {
283
- return CastedValue (NewV, ZExtBits, SExtBits, TruncBits);
285
+ CastedValue withValue (const Value *NewV, bool PreserveNonNeg) const {
286
+ return CastedValue (NewV, ZExtBits, SExtBits, TruncBits,
287
+ IsNonNegative && PreserveNonNeg);
284
288
}
285
289
286
290
// / Replace V with zext(NewV)
287
- CastedValue withZExtOfValue (const Value *NewV) const {
291
+ CastedValue withZExtOfValue (const Value *NewV, bool ZExtNonNegative ) const {
288
292
unsigned ExtendBy = V->getType ()->getPrimitiveSizeInBits () -
289
293
NewV->getType ()->getPrimitiveSizeInBits ();
290
294
if (ExtendBy <= TruncBits)
291
- return CastedValue (NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
295
+ // zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
296
+ // The nneg can be preserved on the outer zext here.
297
+ return CastedValue (NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
298
+ IsNonNegative);
292
299
293
300
// zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
294
301
ExtendBy -= TruncBits;
295
- return CastedValue (NewV, ZExtBits + SExtBits + ExtendBy, 0 , 0 );
302
+ // zext<nneg>(zext(NewV)) == zext(NewV)
303
+ // zext(zext<nneg>(NewV)) == zext<nneg>(NewV)
304
+ // The nneg can be preserved from the inner zext here but must be dropped
305
+ // from the outer.
306
+ return CastedValue (NewV, ZExtBits + SExtBits + ExtendBy, 0 , 0 ,
307
+ ZExtNonNegative);
296
308
}
297
309
298
310
// / Replace V with sext(NewV)
299
311
CastedValue withSExtOfValue (const Value *NewV) const {
300
312
unsigned ExtendBy = V->getType ()->getPrimitiveSizeInBits () -
301
313
NewV->getType ()->getPrimitiveSizeInBits ();
302
314
if (ExtendBy <= TruncBits)
303
- return CastedValue (NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
315
+ // zext<nneg>(trunc(sext(NewV))) == zext<nneg>(trunc(NewV))
316
+ // The nneg can be preserved on the outer zext here
317
+ return CastedValue (NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
318
+ IsNonNegative);
304
319
305
320
// zext(sext(sext(NewV)))
306
321
ExtendBy -= TruncBits;
307
- return CastedValue (NewV, ZExtBits, SExtBits + ExtendBy, 0 );
322
+ // zext<nneg>(sext(sext(NewV))) = zext<nneg>(sext(NewV))
323
+ // The nneg can be preserved on the outer zext here
324
+ return CastedValue (NewV, ZExtBits, SExtBits + ExtendBy, 0 , IsNonNegative);
308
325
}
309
326
310
327
APInt evaluateWith (APInt N) const {
@@ -333,8 +350,15 @@ struct CastedValue {
333
350
}
334
351
335
352
bool hasSameCastsAs (const CastedValue &Other) const {
336
- return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
337
- TruncBits == Other.TruncBits ;
353
+ if (ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
354
+ TruncBits == Other.TruncBits )
355
+ return true ;
356
+ // If either CastedValue has a nneg zext then the sext/zext bits are
357
+ // interchangable for that value.
358
+ if (IsNonNegative || Other.IsNonNegative )
359
+ return (ZExtBits + SExtBits == Other.ZExtBits + Other.SExtBits &&
360
+ TruncBits == Other.TruncBits );
361
+ return false ;
338
362
}
339
363
};
340
364
@@ -410,21 +434,21 @@ static LinearExpression GetLinearExpression(
410
434
411
435
[[fallthrough]];
412
436
case Instruction::Add: {
413
- E = GetLinearExpression (Val.withValue (BOp->getOperand (0 )), DL,
437
+ E = GetLinearExpression (Val.withValue (BOp->getOperand (0 ), false ), DL,
414
438
Depth + 1 , AC, DT);
415
439
E.Offset += RHS;
416
440
E.IsNSW &= NSW;
417
441
break ;
418
442
}
419
443
case Instruction::Sub: {
420
- E = GetLinearExpression (Val.withValue (BOp->getOperand (0 )), DL,
444
+ E = GetLinearExpression (Val.withValue (BOp->getOperand (0 ), false ), DL,
421
445
Depth + 1 , AC, DT);
422
446
E.Offset -= RHS;
423
447
E.IsNSW &= NSW;
424
448
break ;
425
449
}
426
450
case Instruction::Mul:
427
- E = GetLinearExpression (Val.withValue (BOp->getOperand (0 )), DL,
451
+ E = GetLinearExpression (Val.withValue (BOp->getOperand (0 ), false ), DL,
428
452
Depth + 1 , AC, DT)
429
453
.mul (RHS, NSW);
430
454
break ;
@@ -437,7 +461,7 @@ static LinearExpression GetLinearExpression(
437
461
if (RHS.getLimitedValue () > Val.getBitWidth ())
438
462
return Val;
439
463
440
- E = GetLinearExpression (Val.withValue (BOp->getOperand (0 )), DL,
464
+ E = GetLinearExpression (Val.withValue (BOp->getOperand (0 ), NSW ), DL,
441
465
Depth + 1 , AC, DT);
442
466
E.Offset <<= RHS.getLimitedValue ();
443
467
E.Scale <<= RHS.getLimitedValue ();
@@ -448,10 +472,10 @@ static LinearExpression GetLinearExpression(
448
472
}
449
473
}
450
474
451
- if (isa <ZExtInst>(Val.V ))
475
+ if (const auto *ZExt = dyn_cast <ZExtInst>(Val.V ))
452
476
return GetLinearExpression (
453
- Val.withZExtOfValue (cast<CastInst>(Val. V )-> getOperand ( 0 )) ,
454
- DL, Depth + 1 , AC, DT);
477
+ Val.withZExtOfValue (ZExt-> getOperand ( 0 ), ZExt-> hasNonNeg ()), DL ,
478
+ Depth + 1 , AC, DT);
455
479
456
480
if (isa<SExtInst>(Val.V ))
457
481
return GetLinearExpression (
@@ -673,7 +697,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
673
697
unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0 ;
674
698
unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0 ;
675
699
LinearExpression LE = GetLinearExpression (
676
- CastedValue (Index, 0 , SExtBits, TruncBits), DL, 0 , AC, DT);
700
+ CastedValue (Index, 0 , SExtBits, TruncBits, false ), DL, 0 , AC, DT);
677
701
678
702
// Scale by the type size.
679
703
unsigned TypeSize = AllocTypeSize.getFixedValue ();
0 commit comments