Skip to content

Commit f08a1bf

Browse files
authored
Fix simplify of invoke (rust-lang#552)
1 parent 2a56819 commit f08a1bf

File tree

2 files changed

+120
-6
lines changed

2 files changed

+120
-6
lines changed

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -726,13 +726,14 @@ Function *CreateMPIWrapper(Function *F) {
726726
#endif
727727
return W;
728728
}
729+
template <typename T>
729730
static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
730731
DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(NewF);
731-
SmallVector<CallInst *, 4> Todo;
732-
SmallVector<CallInst *, 0> OMPBounds;
732+
SmallVector<T *, 4> Todo;
733+
SmallVector<T *, 0> OMPBounds;
733734
for (auto &BB : NewF) {
734735
for (auto &I : BB) {
735-
if (auto CI = dyn_cast<CallInst>(&I)) {
736+
if (auto CI = dyn_cast<T>(&I)) {
736737
Function *Fn = CI->getCalledFunction();
737738
if (Fn == nullptr)
738739
continue;
@@ -751,6 +752,8 @@ static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
751752
}
752753
}
753754
}
755+
if (Todo.size() == 0 && OMPBounds.size() == 0)
756+
return;
754757
for (auto CI : Todo) {
755758
IRBuilder<> B(CI);
756759
Value *arg[] = {CI->getArgOperand(0)};
@@ -802,7 +805,11 @@ static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
802805
}
803806
}
804807
}
805-
B.SetInsertPoint(res->getNextNode());
808+
if (auto II = dyn_cast<InvokeInst>(res)) {
809+
B.SetInsertPoint(II->getNormalDest()->getFirstNonPHI());
810+
} else {
811+
B.SetInsertPoint(res->getNextNode());
812+
}
806813
B.CreateStore(res, storePointer);
807814
}
808815
for (auto Bound : OMPBounds) {
@@ -818,7 +825,11 @@ static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
818825
B.CreateStore(B.CreateLoad(AI), AI2);
819826
#endif
820827
Bound->setArgOperand(i, AI2);
821-
B.SetInsertPoint(Bound->getNextNode());
828+
if (auto II = dyn_cast<InvokeInst>(Bound)) {
829+
B.SetInsertPoint(II->getNormalDest()->getFirstNonPHI());
830+
} else {
831+
B.SetInsertPoint(Bound->getNextNode());
832+
}
822833
#if LLVM_VERSION_MAJOR > 7
823834
B.CreateStore(B.CreateLoad(AI2->getAllocatedType(), AI2), AI);
824835
#else
@@ -1191,7 +1202,8 @@ Function *PreProcessCache::preprocessForClone(Function *F,
11911202
ConstantFoldTerminator(BE);
11921203
}
11931204

1194-
SimplifyMPIQueries(*NewF, FAM);
1205+
SimplifyMPIQueries<CallInst>(*NewF, FAM);
1206+
SimplifyMPIQueries<InvokeInst>(*NewF, FAM);
11951207

11961208
if (EnzymeLowerGlobals) {
11971209
std::vector<CallInst *> Calls;
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
; RUN: if [ %llvmver -ge 11 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=1 -mem2reg -instsimplify -adce -simplifycfg -S | FileCheck %s; fi
2+
3+
source_filename = "/home/ubuntu/LULESH-MPI-RAJA/lulesh-v2.0/RAJA/lulesh.cpp"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
5+
target triple = "x86_64-unknown-linux-gnu"
6+
7+
%struct.ident_t = type { i32, i32, i32, i32, i8* }
8+
@0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
9+
@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @0, i32 0, i32 0) }, align 8
10+
@2 = private unnamed_addr constant %struct.ident_t { i32 0, i32 514, i32 0, i32 22, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @0, i32 0, i32 0) }, align 8
11+
@3 = private unnamed_addr constant %struct.ident_t { i32 0, i32 66, i32 0, i32 22, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @0, i32 0, i32 0) }, align 8
12+
13+
; Function Attrs: argmemonly nofree nosync nounwind willreturn
14+
declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture)
15+
16+
; Function Attrs: argmemonly nofree nosync nounwind willreturn
17+
declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture)
18+
19+
define void @caller(i8* %call18, i8* %call27) {
20+
entry:
21+
call void @_Z17__enzyme_autodiffPvS_S_(i8* bitcast (void (i64**, double*, i64)* @_ZL16LagrangeLeapFrogP6Domain to i8*), i8* %call18, i8* %call18, i8* %call27, i64 10)
22+
ret void
23+
}
24+
25+
declare i32 @__gxx_personality_v0(...)
26+
27+
declare void @_Z17__enzyme_autodiffPvS_S_(i8*, i8*, i8*, i8*, i64)
28+
29+
; Function Attrs: inlinehint nounwind uwtable
30+
define internal void @_ZL16LagrangeLeapFrogP6Domain(i64** noalias %i12p, double* noalias %i13, i64 %a.val3) {
31+
entry:
32+
%i12 = load i64*, i64** %i12p, align 8
33+
call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* nonnull @1, i32 3, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i64, i64*, double*)* @.omp_outlined. to void (i32*, i32*, ...)*), i64 %a.val3, i64* nonnull %i12, double* %i13)
34+
ret void
35+
}
36+
37+
; Function Attrs: alwaysinline norecurse nounwind uwtable
38+
define internal void @.omp_outlined.(i32* noalias nocapture noundef readnone %.global_tid., i32* noalias nocapture noundef readnone %.bound_tid., i64 %.val3, i64* noalias %i12, double* noalias %i13) personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
39+
entry:
40+
%.omp.lb.i.i = alloca i64, align 8
41+
%.omp.ub.i.i = alloca i64, align 8
42+
%.omp.stride.i.i = alloca i64, align 8
43+
%.omp.is_last.i.i = alloca i32, align 4
44+
%i4 = tail call i32 @__kmpc_global_thread_num(%struct.ident_t* nonnull @1)
45+
%sub9.i.i = add nsw i64 %.val3, -1
46+
store i64 0, i64* %.omp.lb.i.i, align 8
47+
store i64 %sub9.i.i, i64* %.omp.ub.i.i, align 8
48+
store i64 1, i64* %.omp.stride.i.i, align 8
49+
store i32 0, i32* %.omp.is_last.i.i, align 4
50+
invoke void @__kmpc_for_static_init_8(%struct.ident_t* nonnull @2, i32 %i4, i32 34, i32* nonnull %.omp.is_last.i.i, i64* nonnull %.omp.lb.i.i, i64* nonnull %.omp.ub.i.i, i64* nonnull %.omp.stride.i.i, i64 1, i64 1)
51+
to label %.noexc unwind label %terminate.lpad
52+
53+
.noexc: ; preds = %entry
54+
%i9 = load i64, i64* %.omp.ub.i.i, align 8
55+
%cmp11.i.i = icmp sgt i64 %i9, %sub9.i.i
56+
%cond.i.i = select i1 %cmp11.i.i, i64 %sub9.i.i, i64 %i9
57+
store i64 %cond.i.i, i64* %.omp.ub.i.i, align 8
58+
%i10 = load i64, i64* %.omp.lb.i.i, align 8
59+
%cmp12.not3.i.i = icmp sgt i64 %i10, %cond.i.i
60+
br i1 %cmp12.not3.i.i, label %omp.loop.exit.i.i, label %omp.inner.for.body.lr.ph.i.i
61+
62+
omp.inner.for.body.lr.ph.i.i: ; preds = %.noexc
63+
br label %omp.inner.for.body.i.i
64+
65+
omp.inner.for.body.i.i: ; preds = %omp.inner.for.inc.i.i, %omp.inner.for.body.lr.ph.i.i
66+
%.omp.iv.04.i.i = phi i64 [ %i10, %omp.inner.for.body.lr.ph.i.i ], [ %add15.i.i, %omp.inner.for.inc.i.i ]
67+
%sub.i.i.i.i = load i64, i64* %i12, align 8
68+
br label %for.body.i.i.i
69+
70+
for.body.i.i.i: ; preds = %for.body.i.i.i, %omp.inner.for.body.i.i
71+
%i.03.i.i.i = phi i64 [ %inc.i.i.i, %for.body.i.i.i ], [ 0, %omp.inner.for.body.i.i ]
72+
%inc.i.i.i = add nuw nsw i64 %i.03.i.i.i, 1
73+
%exitcond.not.i.i.i = icmp eq i64 %inc.i.i.i, %sub.i.i.i.i
74+
br i1 %exitcond.not.i.i.i, label %omp.inner.for.inc.i.i, label %for.body.i.i.i
75+
76+
omp.inner.for.inc.i.i: ; preds = %for.body.i.i.i
77+
%add.ptr.i.i.i.i.i = getelementptr inbounds double, double* %i13, i64 %.omp.iv.04.i.i
78+
store double 1.000000e+00, double* %add.ptr.i.i.i.i.i, align 8
79+
%add15.i.i = add i64 %.omp.iv.04.i.i, 1
80+
%exitcond.not.i.i = icmp eq i64 %.omp.iv.04.i.i, %cond.i.i
81+
br i1 %exitcond.not.i.i, label %omp.loop.exit.i.i, label %omp.inner.for.body.i.i
82+
83+
omp.loop.exit.i.i: ; preds = %omp.inner.for.inc.i.i, %.noexc
84+
call void @__kmpc_for_static_fini(%struct.ident_t* nonnull @2, i32 %i4)
85+
ret void
86+
87+
terminate.lpad: ; preds = %entry
88+
%i16 = landingpad { i8*, i32 }
89+
catch i8* null
90+
unreachable
91+
}
92+
93+
; Function Attrs: nounwind
94+
declare void @__kmpc_fork_call(%struct.ident_t* nocapture readonly, i32, void (i32*, i32*, ...)* nocapture readonly, ...)
95+
96+
declare i32 @__kmpc_global_thread_num(%struct.ident_t* nocapture readonly)
97+
98+
declare void @__kmpc_for_static_init_8(%struct.ident_t*, i32, i32, i32*, i64*, i64*, i64*, i64, i64)
99+
100+
declare void @__kmpc_for_static_fini(%struct.ident_t* nocapture readonly, i32)
101+
102+
; CHECK: diffe.omp_outlined

0 commit comments

Comments
 (0)