Skip to content

Commit 306b9c7

Browse files
authored
[SCEV] Handle more add/addrec mixes in computeConstantDifference() (#101999)
computeConstantDifference() can currently look through addrecs with identical steps, and then through adds with identical operands (apart from constants). However, it fails to handle minor variations, such as two nested add recs, or an outer add with an inner addrec (rather than the other way around). This patch supports these cases by adding a loop over the simplifications, limited to a small number of iterations. The motivation is the same as in #101339, to make computeConstantDifference() powerful enough to replace existing uses of `dyn_cast<SCEVConstant>(getMinusSCEV())` with it. Though as the IR test diff shows, other callers may also benefit.
1 parent 334a366 commit 306b9c7

File tree

3 files changed

+105
-87
lines changed

3 files changed

+105
-87
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 80 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -11951,62 +11951,94 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
1195111951
// We avoid subtracting expressions here because this function is usually
1195211952
// fairly deep in the call stack (i.e. is called many times).
1195311953

11954-
// X - X = 0.
1195511954
unsigned BW = getTypeSizeInBits(More->getType());
11956-
if (More == Less)
11957-
return APInt(BW, 0);
11958-
11959-
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11960-
const auto *LAR = cast<SCEVAddRecExpr>(Less);
11961-
const auto *MAR = cast<SCEVAddRecExpr>(More);
11962-
11963-
if (LAR->getLoop() != MAR->getLoop())
11964-
return std::nullopt;
11965-
11966-
// We look at affine expressions only; not for correctness but to keep
11967-
// getStepRecurrence cheap.
11968-
if (!LAR->isAffine() || !MAR->isAffine())
11969-
return std::nullopt;
11955+
APInt Diff(BW, 0);
11956+
// Try various simplifications to reduce the difference to a constant. Limit
11957+
// the number of allowed simplifications to keep compile-time low.
11958+
for (unsigned I = 0; I < 4; ++I) {
11959+
if (More == Less)
11960+
return Diff;
11961+
11962+
// Reduce addrecs with identical steps to their start value.
11963+
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11964+
const auto *LAR = cast<SCEVAddRecExpr>(Less);
11965+
const auto *MAR = cast<SCEVAddRecExpr>(More);
11966+
11967+
if (LAR->getLoop() != MAR->getLoop())
11968+
return std::nullopt;
11969+
11970+
// We look at affine expressions only; not for correctness but to keep
11971+
// getStepRecurrence cheap.
11972+
if (!LAR->isAffine() || !MAR->isAffine())
11973+
return std::nullopt;
11974+
11975+
if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11976+
return std::nullopt;
11977+
11978+
Less = LAR->getStart();
11979+
More = MAR->getStart();
11980+
continue;
11981+
}
1197011982

11971-
if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11983+
// Try to cancel out common factors in two add expressions.
11984+
SmallDenseMap<const SCEV *, int, 8> Multiplicity;
11985+
auto Add = [&](const SCEV *S, int Mul) {
11986+
if (auto *C = dyn_cast<SCEVConstant>(S)) {
11987+
if (Mul == 1) {
11988+
Diff += C->getAPInt();
11989+
} else {
11990+
assert(Mul == -1);
11991+
Diff -= C->getAPInt();
11992+
}
11993+
} else
11994+
Multiplicity[S] += Mul;
11995+
};
11996+
auto Decompose = [&](const SCEV *S, int Mul) {
11997+
if (isa<SCEVAddExpr>(S)) {
11998+
for (const SCEV *Op : S->operands())
11999+
Add(Op, Mul);
12000+
} else
12001+
Add(S, Mul);
12002+
};
12003+
Decompose(More, 1);
12004+
Decompose(Less, -1);
12005+
12006+
// Check whether all the non-constants cancel out, or reduce to new
12007+
// More/Less values.
12008+
const SCEV *NewMore = nullptr, *NewLess = nullptr;
12009+
for (const auto [S, Mul] : Multiplicity) {
12010+
if (Mul == 0)
12011+
continue;
12012+
if (Mul == 1) {
12013+
if (NewMore)
12014+
return std::nullopt;
12015+
NewMore = S;
12016+
} else if (Mul == -1) {
12017+
if (NewLess)
12018+
return std::nullopt;
12019+
NewLess = S;
12020+
} else
12021+
return std::nullopt;
12022+
}
12023+
12024+
// Values stayed the same, no point in trying further.
12025+
if (NewMore == More || NewLess == Less)
1197212026
return std::nullopt;
1197312027

11974-
Less = LAR->getStart();
11975-
More = MAR->getStart();
12028+
More = NewMore;
12029+
Less = NewLess;
1197612030

11977-
// fall through
11978-
}
12031+
// Reduced to constant.
12032+
if (!More && !Less)
12033+
return Diff;
1197912034

11980-
// Try to cancel out common factors in two add expressions.
11981-
SmallDenseMap<const SCEV *, int, 8> Multiplicity;
11982-
APInt Diff(BW, 0);
11983-
auto Add = [&](const SCEV *S, int Mul) {
11984-
if (auto *C = dyn_cast<SCEVConstant>(S)) {
11985-
if (Mul == 1) {
11986-
Diff += C->getAPInt();
11987-
} else {
11988-
assert(Mul == -1);
11989-
Diff -= C->getAPInt();
11990-
}
11991-
} else
11992-
Multiplicity[S] += Mul;
11993-
};
11994-
auto Decompose = [&](const SCEV *S, int Mul) {
11995-
if (isa<SCEVAddExpr>(S)) {
11996-
for (const SCEV *Op : S->operands())
11997-
Add(Op, Mul);
11998-
} else
11999-
Add(S, Mul);
12000-
};
12001-
Decompose(More, 1);
12002-
Decompose(Less, -1);
12003-
12004-
// Check whether all the non-constants cancel out.
12005-
for (const auto &[_, Mul] : Multiplicity)
12006-
if (Mul != 0)
12035+
// Left with variable on only one side, bail out.
12036+
if (!More || !Less)
1200712037
return std::nullopt;
12038+
}
1200812039

12009-
return Diff;
12040+
// Did not reduce to constant.
12041+
return std::nullopt;
1201012042
}
1201112043

1201212044
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(

llvm/test/Transforms/LoopVectorize/skeleton-lcssa-crash.ll

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ define i16 @test(ptr %arg, i64 %N) {
2929
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 2
3030
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_MEMCHECK:%.*]]
3131
; CHECK: vector.memcheck:
32-
; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[L_2_LCSSA]], i64 2
33-
; CHECK-NEXT: [[UGLYGEP5:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 2
32+
; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[L_2_LCSSA]], i64 2
33+
; CHECK-NEXT: [[SCEVGEP5:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 2
3434
; CHECK-NEXT: [[TMP1:%.*]] = shl i64 [[N]], 1
3535
; CHECK-NEXT: [[TMP2:%.*]] = add i64 [[TMP1]], 4
36-
; CHECK-NEXT: [[UGLYGEP6:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 [[TMP2]]
37-
; CHECK-NEXT: [[BOUND0:%.*]] = icmp ult ptr [[L_2_LCSSA]], [[UGLYGEP6]]
38-
; CHECK-NEXT: [[BOUND1:%.*]] = icmp ult ptr [[UGLYGEP5]], [[UGLYGEP]]
36+
; CHECK-NEXT: [[SCEVGEP6:%.*]] = getelementptr i8, ptr [[L_1_LCSSA]], i64 [[TMP2]]
37+
; CHECK-NEXT: [[BOUND0:%.*]] = icmp ult ptr [[L_2_LCSSA]], [[SCEVGEP6]]
38+
; CHECK-NEXT: [[BOUND1:%.*]] = icmp ult ptr [[SCEVGEP5]], [[SCEVGEP]]
3939
; CHECK-NEXT: [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
4040
; CHECK-NEXT: br i1 [[FOUND_CONFLICT]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
4141
; CHECK: vector.ph:
@@ -48,10 +48,10 @@ define i16 @test(ptr %arg, i64 %N) {
4848
; CHECK-NEXT: [[TMP4:%.*]] = add nuw nsw i64 [[TMP3]], 1
4949
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i16, ptr [[L_1]], i64 [[TMP4]]
5050
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds i16, ptr [[TMP5]], i32 0
51-
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <2 x i16>, ptr [[TMP6]], align 2, !alias.scope !0
51+
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <2 x i16>, ptr [[TMP6]], align 2, !alias.scope [[META0:![0-9]+]]
5252
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i16, ptr [[L_2]], i64 0
5353
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x i16> [[WIDE_LOAD]], i32 1
54-
; CHECK-NEXT: store i16 [[TMP8]], ptr [[TMP7]], align 2, !alias.scope !3, !noalias !0
54+
; CHECK-NEXT: store i16 [[TMP8]], ptr [[TMP7]], align 2, !alias.scope [[META3:![0-9]+]], !noalias [[META0]]
5555
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
5656
; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
5757
; CHECK-NEXT: br i1 [[TMP9]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
@@ -74,7 +74,7 @@ define i16 @test(ptr %arg, i64 %N) {
7474
; CHECK-NEXT: [[LOOP_L_1:%.*]] = load i16, ptr [[GEP_1]], align 2
7575
; CHECK-NEXT: [[GEP_2:%.*]] = getelementptr inbounds i16, ptr [[L_2_LCSSA]], i64 0
7676
; CHECK-NEXT: store i16 [[LOOP_L_1]], ptr [[GEP_2]], align 2
77-
; CHECK-NEXT: br i1 [[C_5]], label [[LOOP_3]], label [[EXIT_LOOPEXIT]], !llvm.loop [[LOOP7:![0-9]+]]
77+
; CHECK-NEXT: br i1 [[C_5]], label [[LOOP_3]], label [[EXIT_LOOPEXIT]], !llvm.loop [[LOOP8:![0-9]+]]
7878
; CHECK: exit.loopexit:
7979
; CHECK-NEXT: br label [[EXIT:%.*]]
8080
; CHECK: exit.loopexit1:
@@ -138,31 +138,17 @@ define void @test2(ptr %dst) {
138138
; CHECK-NEXT: [[INDVAR_NEXT]] = add i32 [[INDVAR]], 1
139139
; CHECK-NEXT: br i1 [[C_1]], label [[LOOP_2]], label [[LOOP_3_PH:%.*]]
140140
; CHECK: loop.3.ph:
141-
; CHECK-NEXT: [[INDVAR_LCSSA1:%.*]] = phi i32 [ [[INDVAR]], [[LOOP_2]] ]
142141
; CHECK-NEXT: [[INDVAR_LCSSA:%.*]] = phi i32 [ [[INDVAR]], [[LOOP_2]] ]
143142
; CHECK-NEXT: [[IV_1_LCSSA:%.*]] = phi i64 [ [[IV_1]], [[LOOP_2]] ]
144143
; CHECK-NEXT: [[TMP0:%.*]] = and i64 [[IV_1_LCSSA]], 4294967295
145-
; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[INDVAR_LCSSA1]], -1
144+
; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[INDVAR_LCSSA]], -1
146145
; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP1]], 1000
147-
; CHECK-NEXT: [[SMIN2:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP2]], i32 1)
148-
; CHECK-NEXT: [[TMP3:%.*]] = sub i32 [[TMP2]], [[SMIN2]]
146+
; CHECK-NEXT: [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP2]], i32 1)
147+
; CHECK-NEXT: [[TMP3:%.*]] = sub i32 [[TMP2]], [[SMIN]]
149148
; CHECK-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
150149
; CHECK-NEXT: [[TMP5:%.*]] = add nuw nsw i64 [[TMP4]], 1
151150
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP5]], 2
152-
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
153-
; CHECK: vector.scevcheck:
154-
; CHECK-NEXT: [[TMP6:%.*]] = mul i32 [[INDVAR_LCSSA]], -1
155-
; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], 1000
156-
; CHECK-NEXT: [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP7]], i32 1)
157-
; CHECK-NEXT: [[TMP8:%.*]] = sub i32 [[TMP7]], [[SMIN]]
158-
; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP6]], 999
159-
; CHECK-NEXT: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 1, i32 [[TMP8]])
160-
; CHECK-NEXT: [[MUL_RESULT:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
161-
; CHECK-NEXT: [[MUL_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
162-
; CHECK-NEXT: [[TMP10:%.*]] = sub i32 [[TMP9]], [[MUL_RESULT]]
163-
; CHECK-NEXT: [[TMP11:%.*]] = icmp ugt i32 [[TMP10]], [[TMP9]]
164-
; CHECK-NEXT: [[TMP12:%.*]] = or i1 [[TMP11]], [[MUL_OVERFLOW]]
165-
; CHECK-NEXT: br i1 [[TMP12]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
151+
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
166152
; CHECK: vector.ph:
167153
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP5]], 2
168154
; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP5]], [[N_MOD_VF]]
@@ -171,21 +157,21 @@ define void @test2(ptr %dst) {
171157
; CHECK: vector.body:
172158
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
173159
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = sub i64 [[TMP0]], [[INDEX]]
174-
; CHECK-NEXT: [[TMP13:%.*]] = add i64 [[OFFSET_IDX]], 0
175-
; CHECK-NEXT: [[TMP14:%.*]] = add nsw i64 [[TMP13]], -1
176-
; CHECK-NEXT: [[TMP15:%.*]] = and i64 [[TMP14]], 4294967295
177-
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP15]]
178-
; CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds i32, ptr [[TMP16]], i32 0
179-
; CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[TMP17]], i32 -1
180-
; CHECK-NEXT: store <2 x i32> zeroinitializer, ptr [[TMP18]], align 4
160+
; CHECK-NEXT: [[TMP6:%.*]] = add i64 [[OFFSET_IDX]], 0
161+
; CHECK-NEXT: [[TMP7:%.*]] = add nsw i64 [[TMP6]], -1
162+
; CHECK-NEXT: [[TMP8:%.*]] = and i64 [[TMP7]], 4294967295
163+
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP8]]
164+
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[TMP9]], i32 0
165+
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i32, ptr [[TMP10]], i32 -1
166+
; CHECK-NEXT: store <2 x i32> zeroinitializer, ptr [[TMP11]], align 4
181167
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
182-
; CHECK-NEXT: [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
183-
; CHECK-NEXT: br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
168+
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
169+
; CHECK-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP9:![0-9]+]]
184170
; CHECK: middle.block:
185171
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP5]], [[N_VEC]]
186172
; CHECK-NEXT: br i1 [[CMP_N]], label [[LOOP_1_LATCH:%.*]], label [[SCALAR_PH]]
187173
; CHECK: scalar.ph:
188-
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ [[TMP0]], [[LOOP_3_PH]] ], [ [[TMP0]], [[VECTOR_SCEVCHECK]] ]
174+
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[IND_END]], [[MIDDLE_BLOCK]] ], [ [[TMP0]], [[LOOP_3_PH]] ]
189175
; CHECK-NEXT: br label [[LOOP_3:%.*]]
190176
; CHECK: loop.3:
191177
; CHECK-NEXT: [[IV_2:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV_2_NEXT:%.*]], [[LOOP_3]] ]
@@ -195,7 +181,7 @@ define void @test2(ptr %dst) {
195181
; CHECK-NEXT: store i32 0, ptr [[GEP_DST]], align 4
196182
; CHECK-NEXT: [[IV_2_TRUNC:%.*]] = trunc i64 [[IV_2]] to i32
197183
; CHECK-NEXT: [[EC:%.*]] = icmp sgt i32 [[IV_2_TRUNC]], 1
198-
; CHECK-NEXT: br i1 [[EC]], label [[LOOP_3]], label [[LOOP_1_LATCH]], !llvm.loop [[LOOP9:![0-9]+]]
184+
; CHECK-NEXT: br i1 [[EC]], label [[LOOP_3]], label [[LOOP_1_LATCH]], !llvm.loop [[LOOP10:![0-9]+]]
199185
; CHECK: loop.1.latch:
200186
; CHECK-NEXT: [[C_2:%.*]] = call i1 @cond()
201187
; CHECK-NEXT: br i1 [[C_2]], label [[EXIT:%.*]], label [[LOOP_1_HEADER]]

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,8 +1202,8 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
12021202
EXPECT_EQ(diff(ScevIV, ScevIVNext), -1);
12031203
EXPECT_EQ(diff(ScevIVNext, ScevIV), 1);
12041204
EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
1205-
EXPECT_EQ(diff(ScevIV2P3, ScevIV2), std::nullopt); // TODO
1206-
EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), std::nullopt); // TODO
1205+
EXPECT_EQ(diff(ScevIV2P3, ScevIV2), 3);
1206+
EXPECT_EQ(diff(ScevIV2PVar, ScevIV2PVarP3), -3);
12071207
EXPECT_EQ(diff(ScevV0, ScevIV), std::nullopt);
12081208
EXPECT_EQ(diff(ScevIVNext, ScevV3), std::nullopt);
12091209
EXPECT_EQ(diff(ScevYY, ScevV3), std::nullopt);

0 commit comments

Comments
 (0)