Skip to content

Commit 45b526a

Browse files
committed
[LV] Honor uniform-after-vectorization in setVectorizedCallDecision.
The legacy cost model always computes the cost for uniforms as cost of VF = 1, but VPWidenCallRecipes would be created, as setVectorizedCallDecisions would not consider uniform calls. Fix setVectorizedCallDecision to set to Scalarize, if the call is uniform-after-vectorization. This fixes a bug in VPlan construction uncovered by the VPlan-based cost model. Fixes #111040.
1 parent b3e0bd3 commit 45b526a

File tree

2 files changed

+124
-3
lines changed

2 files changed

+124
-3
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6196,11 +6196,12 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
61966196
getScalarizationOverhead(CI, VF, CostKind);
61976197

61986198
ScalarCost = ScalarCallCost * VF.getKnownMinValue() + ScalarizationCost;
6199-
// Honor ForcedScalars decision.
6199+
// Honor ForcedScalars and UniformAfterVectorization decisions.
62006200
// TODO: For calls, it might still be more profitable to widen. Use
62016201
// VPlan-based cost model to compare different options.
6202-
if (VF.isVector() && ForcedScalar != ForcedScalars.end() &&
6203-
ForcedScalar->second.contains(CI)) {
6202+
if (VF.isVector() && ((ForcedScalar != ForcedScalars.end() &&
6203+
ForcedScalar->second.contains(CI)) ||
6204+
isUniformAfterVectorization(CI, VF))) {
62046205
setCallWideningDecision(CI, VF, CM_Scalarize, nullptr,
62056206
Intrinsic::not_intrinsic, std::nullopt,
62066207
ScalarCost);
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -p loop-vectorize -S %s | FileCheck %s
3+
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
5+
target triple = "x86_64-unknown-linux-gnu"
6+
7+
; Test for https://github.com/llvm/llvm-project/issues/111040
8+
define void @smax_call_uniform(ptr %dst, i64 %x) {
9+
; CHECK-LABEL: define void @smax_call_uniform(
10+
; CHECK-SAME: ptr [[DST:%.*]], i64 [[X:%.*]]) {
11+
; CHECK-NEXT: [[ENTRY:.*]]:
12+
; CHECK-NEXT: [[C:%.*]] = icmp ult i8 -68, -69
13+
; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i64 [[X]], 0
14+
; CHECK-NEXT: br i1 true, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
15+
; CHECK: [[VECTOR_PH]]:
16+
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x i1> poison, i1 [[C]], i64 0
17+
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <2 x i1> [[BROADCAST_SPLATINSERT]], <2 x i1> poison, <2 x i32> zeroinitializer
18+
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
19+
; CHECK: [[VECTOR_BODY]]:
20+
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[PRED_UREM_CONTINUE6:.*]] ]
21+
; CHECK-NEXT: [[TMP0:%.*]] = xor <2 x i1> [[BROADCAST_SPLAT]], <i1 true, i1 true>
22+
; CHECK-NEXT: [[TMP1:%.*]] = xor <2 x i1> [[BROADCAST_SPLAT]], <i1 true, i1 true>
23+
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i1> [[TMP0]], i32 0
24+
; CHECK-NEXT: br i1 [[TMP2]], label %[[PRED_UREM_IF:.*]], label %[[PRED_UREM_CONTINUE:.*]]
25+
; CHECK: [[PRED_UREM_IF]]:
26+
; CHECK-NEXT: [[TMP3:%.*]] = urem i64 [[MUL]], [[X]]
27+
; CHECK-NEXT: br label %[[PRED_UREM_CONTINUE]]
28+
; CHECK: [[PRED_UREM_CONTINUE]]:
29+
; CHECK-NEXT: [[TMP4:%.*]] = phi i64 [ poison, %[[VECTOR_BODY]] ], [ [[TMP3]], %[[PRED_UREM_IF]] ]
30+
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x i1> [[TMP0]], i32 1
31+
; CHECK-NEXT: br i1 [[TMP5]], label %[[PRED_UREM_IF1:.*]], label %[[PRED_UREM_CONTINUE2:.*]]
32+
; CHECK: [[PRED_UREM_IF1]]:
33+
; CHECK-NEXT: [[TMP6:%.*]] = urem i64 [[MUL]], [[X]]
34+
; CHECK-NEXT: br label %[[PRED_UREM_CONTINUE2]]
35+
; CHECK: [[PRED_UREM_CONTINUE2]]:
36+
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x i1> [[TMP1]], i32 0
37+
; CHECK-NEXT: br i1 [[TMP7]], label %[[PRED_UREM_IF3:.*]], label %[[PRED_UREM_CONTINUE4:.*]]
38+
; CHECK: [[PRED_UREM_IF3]]:
39+
; CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[MUL]], [[X]]
40+
; CHECK-NEXT: br label %[[PRED_UREM_CONTINUE4]]
41+
; CHECK: [[PRED_UREM_CONTINUE4]]:
42+
; CHECK-NEXT: [[TMP9:%.*]] = phi i64 [ poison, %[[PRED_UREM_CONTINUE2]] ], [ [[TMP8]], %[[PRED_UREM_IF3]] ]
43+
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x i1> [[TMP1]], i32 1
44+
; CHECK-NEXT: br i1 [[TMP10]], label %[[PRED_UREM_IF5:.*]], label %[[PRED_UREM_CONTINUE6]]
45+
; CHECK: [[PRED_UREM_IF5]]:
46+
; CHECK-NEXT: [[TMP11:%.*]] = urem i64 [[MUL]], [[X]]
47+
; CHECK-NEXT: br label %[[PRED_UREM_CONTINUE6]]
48+
; CHECK: [[PRED_UREM_CONTINUE6]]:
49+
; CHECK-NEXT: [[TMP12:%.*]] = tail call i64 @llvm.smax.i64(i64 [[TMP4]], i64 0)
50+
; CHECK-NEXT: [[TMP13:%.*]] = tail call i64 @llvm.smax.i64(i64 [[TMP9]], i64 0)
51+
; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x i1> [[TMP0]], i32 0
52+
; CHECK-NEXT: [[PREDPHI:%.*]] = select i1 [[TMP14]], i64 [[TMP12]], i64 1
53+
; CHECK-NEXT: [[TMP15:%.*]] = extractelement <2 x i1> [[TMP1]], i32 0
54+
; CHECK-NEXT: [[PREDPHI7:%.*]] = select i1 [[TMP15]], i64 [[TMP13]], i64 1
55+
; CHECK-NEXT: [[TMP16:%.*]] = add i64 [[PREDPHI]], 1
56+
; CHECK-NEXT: [[TMP17:%.*]] = add i64 [[PREDPHI7]], 1
57+
; CHECK-NEXT: [[TMP18:%.*]] = getelementptr i64, ptr [[DST]], i64 [[TMP16]]
58+
; CHECK-NEXT: [[TMP19:%.*]] = getelementptr i64, ptr [[DST]], i64 [[TMP17]]
59+
; CHECK-NEXT: store i64 0, ptr [[TMP18]], align 8
60+
; CHECK-NEXT: store i64 0, ptr [[TMP19]], align 8
61+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
62+
; CHECK-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0
63+
; CHECK-NEXT: br i1 [[TMP20]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
64+
; CHECK: [[MIDDLE_BLOCK]]:
65+
; CHECK-NEXT: br i1 true, label %[[EXIT:.*]], label %[[SCALAR_PH]]
66+
; CHECK: [[SCALAR_PH]]:
67+
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 0, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
68+
; CHECK-NEXT: br label %[[LOOP_HEADER:.*]]
69+
; CHECK: [[LOOP_HEADER]]:
70+
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP_LATCH:.*]] ]
71+
; CHECK-NEXT: br i1 [[C]], label %[[LOOP_LATCH]], label %[[ELSE:.*]]
72+
; CHECK: [[ELSE]]:
73+
; CHECK-NEXT: [[REM:%.*]] = urem i64 [[MUL]], [[X]]
74+
; CHECK-NEXT: [[SMAX:%.*]] = tail call i64 @llvm.smax.i64(i64 [[REM]], i64 0)
75+
; CHECK-NEXT: br label %[[LOOP_LATCH]]
76+
; CHECK: [[LOOP_LATCH]]:
77+
; CHECK-NEXT: [[P:%.*]] = phi i64 [ 1, %[[LOOP_HEADER]] ], [ [[SMAX]], %[[ELSE]] ]
78+
; CHECK-NEXT: [[ADD:%.*]] = add i64 [[P]], 1
79+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i64, ptr [[DST]], i64 [[ADD]]
80+
; CHECK-NEXT: store i64 0, ptr [[GEP]], align 8
81+
; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
82+
; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 0
83+
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP_HEADER]], !llvm.loop [[LOOP3:![0-9]+]]
84+
; CHECK: [[EXIT]]:
85+
; CHECK-NEXT: ret void
86+
;
87+
entry:
88+
%c = icmp ult i8 -68, -69
89+
%mul = mul nsw nuw i64 %x, 0
90+
br label %loop.header
91+
92+
loop.header:
93+
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop.latch ]
94+
br i1 %c, label %loop.latch, label %else
95+
96+
else:
97+
%rem = urem i64 %mul, %x
98+
%smax = tail call i64 @llvm.smax.i64(i64 %rem, i64 0)
99+
br label %loop.latch
100+
101+
loop.latch:
102+
%p = phi i64 [ 1, %loop.header ], [ %smax, %else ]
103+
%add = add i64 %p, 1
104+
%gep = getelementptr i64, ptr %dst, i64 %add
105+
store i64 0, ptr %gep, align 8
106+
%iv.next = add i64 %iv, 1
107+
%ec = icmp eq i64 %iv.next, 0
108+
br i1 %ec, label %exit, label %loop.header
109+
110+
exit:
111+
ret void
112+
}
113+
114+
declare i64 @llvm.smax.i64(i64, i64)
115+
;.
116+
; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
117+
; CHECK: [[META1]] = !{!"llvm.loop.isvectorized", i32 1}
118+
; CHECK: [[META2]] = !{!"llvm.loop.unroll.runtime.disable"}
119+
; CHECK: [[LOOP3]] = distinct !{[[LOOP3]], [[META2]], [[META1]]}
120+
;.

0 commit comments

Comments
 (0)