Skip to content

Commit 23306ab

Browse files
authored
Fix recursive map inversion (rust-lang#795)
1 parent faa1c45 commit 23306ab

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9031,7 +9031,6 @@ class AdjointGenerator
90319031
auto ifound = gutils->invertedPointers.find(orig);
90329032
assert(ifound != gutils->invertedPointers.end());
90339033
auto placeholder = cast<PHINode>(&*ifound->second);
9034-
gutils->invertedPointers.erase(ifound);
90359034

90369035
if (subretType == DIFFE_TYPE::DUP_ARG) {
90379036
Value *shadow = placeholder;
@@ -9070,6 +9069,7 @@ class AdjointGenerator
90709069
if (Mode == DerivativeMode::ReverseModeGradient)
90719070
needsReplacement = false;
90729071
}
9072+
gutils->invertedPointers.erase((const Value *)orig);
90739073
gutils->invertedPointers.insert(std::make_pair(
90749074
(const Value *)orig, InvertedPointerVH(gutils, shadow)));
90759075
if (needsReplacement) {
@@ -9078,6 +9078,7 @@ class AdjointGenerator
90789078
gutils->erase(placeholder);
90799079
}
90809080
} else {
9081+
gutils->invertedPointers.erase((const Value *)orig);
90819082
gutils->erase(placeholder);
90829083
}
90839084
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
source_filename = "map.cpp"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
5+
target triple = "x86_64-unknown-linux-gnu"
6+
7+
declare double* @_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base(double*)
8+
9+
define double @f(double* %a2) {
10+
entry:
11+
br label %for.body
12+
13+
for.body: ; preds = %entry, %for.body
14+
%q = phi double [ %add, %for.body ], [ 0.000000e+00, %entry ]
15+
%iter = phi double* [ %call.i, %for.body ], [ %a2, %entry ]
16+
%a4 = load double, double* %iter, align 8
17+
%add = fadd double %q, %a4
18+
%call.i = tail call double* @_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base(double* %iter)
19+
%cmp.i.not = icmp eq double* %call.i, null
20+
br i1 %cmp.i.not, label %for.end, label %for.body
21+
22+
for.end: ; preds = %for.body, %entry
23+
%q.0.lcssa = phi double [ %add, %for.body ]
24+
ret double %q.0.lcssa
25+
}
26+
27+
define void @caller() {
28+
entry:
29+
call void (...) @_Z17__enzyme_autodiffPviS_S_(i8* bitcast (double (double *)* @f to i8*), metadata !"enzyme_dup", i8* null, i8* null)
30+
ret void
31+
}
32+
33+
declare void @_Z17__enzyme_autodiffPviS_S_(...)
34+
35+
36+
; CHECK: define internal void @diffef(double* %a2, double* %"a2'", double %differeturn)
37+
; CHECK-NEXT: entry:
38+
; CHECK-NEXT: br label %for.body
39+
40+
; CHECK: for.body: ; preds = %__enzyme_exponentialallocation.exit, %entry
41+
; CHECK-NEXT: %_cache.0 = phi double** [ null, %entry ], [ %12, %__enzyme_exponentialallocation.exit ]
42+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %__enzyme_exponentialallocation.exit ], [ 0, %entry ]
43+
; CHECK-NEXT: %0 = phi double* [ %14, %__enzyme_exponentialallocation.exit ], [ %"a2'", %entry ]
44+
; CHECK-NEXT: %iter = phi double* [ %call.i, %__enzyme_exponentialallocation.exit ], [ %a2, %entry ]
45+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
46+
; CHECK-NEXT: %1 = bitcast double** %_cache.0 to i8*
47+
; CHECK-NEXT: %2 = and i64 %iv.next, 1
48+
; CHECK-NEXT: %3 = icmp ne i64 %2, 0
49+
; CHECK-NEXT: %4 = call i64 @llvm.ctpop.i64(i64 %iv.next)
50+
; CHECK-NEXT: %5 = icmp ult i64 %4, 3
51+
; CHECK-NEXT: %6 = and i1 %5, %3
52+
; CHECK-NEXT: br i1 %6, label %grow.i, label %__enzyme_exponentialallocation.exit
53+
54+
; CHECK: grow.i: ; preds = %for.body
55+
; CHECK-NEXT: %7 = call i64 @llvm.ctlz.i64(i64 %iv.next, i1 true)
56+
; CHECK-NEXT: %8 = sub nuw nsw i64 64, %7
57+
; CHECK-NEXT: %9 = shl i64 8, %8
58+
; CHECK-NEXT: %10 = call i8* @realloc(i8* %1, i64 %9)
59+
; CHECK-NEXT: br label %__enzyme_exponentialallocation.exit
60+
61+
; CHECK: __enzyme_exponentialallocation.exit: ; preds = %for.body, %grow.i
62+
; CHECK-NEXT: %11 = phi i8* [ %10, %grow.i ], [ %1, %for.body ]
63+
; CHECK-NEXT: %12 = bitcast i8* %11 to double**
64+
; CHECK-NEXT: %13 = getelementptr inbounds double*, double** %12, i64 %iv
65+
; CHECK-NEXT: store double* %0, double** %13, align 8, !invariant.group !0
66+
; CHECK-NEXT: %14 = call double* @_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base(double* %0)
67+
; CHECK-NEXT: %call.i = tail call double* @_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base(double* %iter)
68+
; CHECK-NEXT: %cmp.i.not = icmp eq double* %call.i, null
69+
; CHECK-NEXT: br i1 %cmp.i.not, label %invertfor.body, label %for.body
70+
71+
; CHECK: invertentry: ; preds = %invertfor.body
72+
; CHECK-NEXT: tail call void @free(i8* nonnull %11)
73+
; CHECK-NEXT: ret void
74+
75+
; CHECK: invertfor.body: ; preds = %__enzyme_exponentialallocation.exit, %incinvertfor.body
76+
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %21, %incinvertfor.body ], [ %iv, %__enzyme_exponentialallocation.exit ]
77+
; CHECK-NEXT: %15 = getelementptr inbounds double*, double** %12, i64 %"iv'ac.0"
78+
; CHECK-NEXT: %16 = load double*, double** %15, align 8, !invariant.group !0
79+
; CHECK-NEXT: %17 = load double, double* %16, align 8
80+
; CHECK-NEXT: %18 = fadd fast double %17, %differeturn
81+
; CHECK-NEXT: store double %18, double* %16, align 8
82+
; CHECK-NEXT: %19 = icmp eq i64 %"iv'ac.0", 0
83+
; CHECK-NEXT: %20 = select {{(fast )?}}i1 %19, double 0.000000e+00, double %differeturn
84+
; CHECK-NEXT: br i1 %19, label %invertentry, label %incinvertfor.body
85+
86+
; CHECK: incinvertfor.body: ; preds = %invertfor.body
87+
; CHECK-NEXT: %21 = add nsw i64 %"iv'ac.0", -1
88+
; CHECK-NEXT: br label %invertfor.body
89+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)