Skip to content

Commit ffe4181

Browse files
authored
[Support] Add KnownBits::abds signed absolute difference and rename absdiff -> abdu (#84897)
When I created KnownBits::absdiff, I totally missed that we already have ISD::ABDS/ABDU nodes, and we use this term in other places/targets as well. I've added the KnownBits::abds implementation and renamed KnownBits::absdiff to KnownBits::abdu. Followup to #84791
1 parent 4e3310a commit ffe4181

File tree

4 files changed

+65
-11
lines changed

4 files changed

+65
-11
lines changed

llvm/include/llvm/Support/KnownBits.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,11 @@ struct KnownBits {
390390
/// Compute known bits for smin(LHS, RHS).
391391
static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
392392

393-
/// Compute known bits for absdiff(LHS, RHS).
394-
static KnownBits absdiff(const KnownBits &LHS, const KnownBits &RHS);
393+
/// Compute known bits for abdu(LHS, RHS).
394+
static KnownBits abdu(const KnownBits &LHS, const KnownBits &RHS);
395+
396+
/// Compute known bits for abds(LHS, RHS).
397+
static KnownBits abds(const KnownBits &LHS, const KnownBits &RHS);
395398

396399
/// Compute known bits for shl(LHS, RHS).
397400
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.

llvm/lib/Support/KnownBits.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
231231
return Flip(umax(Flip(LHS), Flip(RHS)));
232232
}
233233

234-
KnownBits KnownBits::absdiff(const KnownBits &LHS, const KnownBits &RHS) {
235-
// absdiff(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
234+
KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
235+
// abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
236236
KnownBits UMaxValue = umax(LHS, RHS);
237237
KnownBits UMinValue = umin(LHS, RHS);
238238
KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
@@ -250,6 +250,25 @@ KnownBits KnownBits::absdiff(const KnownBits &LHS, const KnownBits &RHS) {
250250
return KnownAbsDiff;
251251
}
252252

253+
KnownBits KnownBits::abds(const KnownBits &LHS, const KnownBits &RHS) {
254+
// abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
255+
KnownBits SMaxValue = smax(LHS, RHS);
256+
KnownBits SMinValue = smin(LHS, RHS);
257+
KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
258+
/*NUW=*/false, SMaxValue, SMinValue);
259+
260+
// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
261+
KnownBits Diff0 =
262+
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
263+
KnownBits Diff1 =
264+
computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
265+
KnownBits SubDiff = Diff0.intersectWith(Diff1);
266+
267+
KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
268+
assert(!KnownAbsDiff.hasConflict() && "Bad Output");
269+
return KnownAbsDiff;
270+
}
271+
253272
static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
254273
if (isPowerOf2_32(BitWidth))
255274
return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36753,7 +36753,7 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
3675336753
APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
3675436754
Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
3675536755
Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
36756-
Known = KnownBits::absdiff(Known, Known2).zext(16);
36756+
Known = KnownBits::abdu(Known, Known2).zext(16);
3675736757
// Known = (((D0 + D1) + (D2 + D3)) + ((D4 + D5) + (D6 + D7)))
3675836758
Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
3675936759
Known, Known);

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,18 +294,18 @@ TEST(KnownBitsTest, SignBitUnknown) {
294294
EXPECT_TRUE(Known.isSignUnknown());
295295
}
296296

297-
TEST(KnownBitsTest, AbsDiffSpecialCase) {
298-
// There are 2 implementation of absdiff - both are currently needed to cover
297+
TEST(KnownBitsTest, ABDUSpecialCase) {
298+
// There are 2 implementations of abdu - both are currently needed to cover
299299
// extra cases.
300300
KnownBits LHS, RHS, Res;
301301

302-
// absdiff(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
302+
// abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
303303
// Actual: false (Inputs = 1011, 101?, Computed = 000?, Exact = 000?)
304304
LHS.One = APInt(4, 0b1011);
305305
RHS.One = APInt(4, 0b1010);
306306
LHS.Zero = APInt(4, 0b0100);
307307
RHS.Zero = APInt(4, 0b0100);
308-
Res = KnownBits::absdiff(LHS, RHS);
308+
Res = KnownBits::abdu(LHS, RHS);
309309
EXPECT_EQ(0b0000ul, Res.One.getZExtValue());
310310
EXPECT_EQ(0b1110ul, Res.Zero.getZExtValue());
311311

@@ -315,11 +315,37 @@ TEST(KnownBitsTest, AbsDiffSpecialCase) {
315315
RHS.One = APInt(4, 0b1000);
316316
LHS.Zero = APInt(4, 0b0000);
317317
RHS.Zero = APInt(4, 0b0111);
318-
Res = KnownBits::absdiff(LHS, RHS);
318+
Res = KnownBits::abdu(LHS, RHS);
319319
EXPECT_EQ(0b0001ul, Res.One.getZExtValue());
320320
EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
321321
}
322322

323+
TEST(KnownBitsTest, ABDSSpecialCase) {
324+
// There are 2 implementations of abds - both are currently needed to cover
325+
// extra cases.
326+
KnownBits LHS, RHS, Res;
327+
328+
// abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
329+
// Actual: false (Inputs = 1011, 10??, Computed = ????, Exact = 00??)
330+
LHS.One = APInt(4, 0b1011);
331+
RHS.One = APInt(4, 0b1000);
332+
LHS.Zero = APInt(4, 0b0100);
333+
RHS.Zero = APInt(4, 0b0100);
334+
Res = KnownBits::abds(LHS, RHS);
335+
EXPECT_EQ(0, Res.One.getSExtValue());
336+
EXPECT_EQ(-4, Res.Zero.getSExtValue());
337+
338+
// find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
339+
// Actual: false (Inputs = ???1, 1000, Computed = ???1, Exact = 0??1)
340+
LHS.One = APInt(4, 0b0001);
341+
RHS.One = APInt(4, 0b1000);
342+
LHS.Zero = APInt(4, 0b0000);
343+
RHS.Zero = APInt(4, 0b0111);
344+
Res = KnownBits::abds(LHS, RHS);
345+
EXPECT_EQ(1, Res.One.getSExtValue());
346+
EXPECT_EQ(0, Res.Zero.getSExtValue());
347+
}
348+
323349
TEST(KnownBitsTest, BinaryExhaustive) {
324350
testBinaryOpExhaustive(
325351
[](const KnownBits &Known1, const KnownBits &Known2) {
@@ -359,10 +385,16 @@ TEST(KnownBitsTest, BinaryExhaustive) {
359385
[](const APInt &N1, const APInt &N2) { return APIntOps::smin(N1, N2); });
360386
testBinaryOpExhaustive(
361387
[](const KnownBits &Known1, const KnownBits &Known2) {
362-
return KnownBits::absdiff(Known1, Known2);
388+
return KnownBits::abdu(Known1, Known2);
363389
},
364390
[](const APInt &N1, const APInt &N2) { return APIntOps::abdu(N1, N2); },
365391
checkCorrectnessOnlyBinary);
392+
testBinaryOpExhaustive(
393+
[](const KnownBits &Known1, const KnownBits &Known2) {
394+
return KnownBits::abds(Known1, Known2);
395+
},
396+
[](const APInt &N1, const APInt &N2) { return APIntOps::abds(N1, N2); },
397+
checkCorrectnessOnlyBinary);
366398
testBinaryOpExhaustive(
367399
[](const KnownBits &Known1, const KnownBits &Known2) {
368400
return KnownBits::udiv(Known1, Known2);

0 commit comments

Comments
 (0)