Skip to content

Commit 0894dea

Browse files
authored
Fix batch bug (rust-lang#731)
* Fix vector/scalar analysis * add test * fix test * no clang format
1 parent 8b8cd30 commit 0894dea

File tree

4 files changed

+97
-19
lines changed

4 files changed

+97
-19
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4612,7 +4612,7 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width,
46124612

46134613
if (Instruction *cur_inst = dyn_cast<Instruction>(cur)) {
46144614
if (!isa<CallInst>(cur_inst) && !cur_inst->mayReadOrWriteMemory()) {
4615-
for (auto &op : todo_inst->operands())
4615+
for (auto &op : cur_inst->operands())
46164616
toCheck.insert(op);
46174617
continue;
46184618
}

enzyme/test/Enzyme/BatchMode/intsum.ll

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ declare [4 x float] @__enzyme_batch(...)
4141

4242
; CHECK: do.body: ; preds = %do.body, %entry
4343
; CHECK-NEXT: %i = phi i64 [ %inc, %do.body ], [ 0, %entry ]
44-
; CHECK-NEXT: %intsum = phi i32 [ 0, %entry ], [ %intadd, %do.body ]
44+
; CHECK-NEXT: %intsum0 = phi i32 [ 0, %entry ], [ %intadd0, %do.body ]
45+
; CHECK-NEXT: %intsum = phi i32 [ 0, %entry ], [ %intadd1, %do.body ]
46+
; CHECK-NEXT: %intsum1 = phi i32 [ 0, %entry ], [ %intadd2, %do.body ]
47+
; CHECK-NEXT: %intsum2 = phi i32 [ 0, %entry ], [ %intadd3, %do.body ]
4548
; CHECK-NEXT: %arrayidx0 = getelementptr inbounds float, float* %unwrap.array0, i64 %i
4649
; CHECK-NEXT: %arrayidx1 = getelementptr inbounds float, float* %unwrap.array1, i64 %i
4750
; CHECK-NEXT: %arrayidx2 = getelementptr inbounds float, float* %unwrap.array2, i64 %i
@@ -50,21 +53,30 @@ declare [4 x float] @__enzyme_batch(...)
5053
; CHECK-NEXT: %loaded1 = load float, float* %arrayidx1
5154
; CHECK-NEXT: %loaded2 = load float, float* %arrayidx2
5255
; CHECK-NEXT: %loaded3 = load float, float* %arrayidx3
53-
; CHECK-NEXT: %fltload = bitcast i32 %intsum to float
54-
; CHECK-NEXT: %add0 = fadd float %fltload, %loaded0
55-
; CHECK-NEXT: %add1 = fadd float %fltload, %loaded1
56-
; CHECK-NEXT: %add2 = fadd float %fltload, %loaded2
57-
; CHECK-NEXT: %add3 = fadd float %fltload, %loaded3
58-
; CHECK-NEXT: %intadd = bitcast float %add0 to i32
56+
; CHECK-NEXT: %fltload0 = bitcast i32 %intsum0 to float
57+
; CHECK-NEXT: %fltload1 = bitcast i32 %intsum to float
58+
; CHECK-NEXT: %fltload2 = bitcast i32 %intsum1 to float
59+
; CHECK-NEXT: %fltload3 = bitcast i32 %intsum2 to float
60+
; CHECK-NEXT: %add0 = fadd float %fltload0, %loaded0
61+
; CHECK-NEXT: %add1 = fadd float %fltload1, %loaded1
62+
; CHECK-NEXT: %add2 = fadd float %fltload2, %loaded2
63+
; CHECK-NEXT: %add3 = fadd float %fltload3, %loaded3
64+
; CHECK-NEXT: %intadd0 = bitcast float %add0 to i32
65+
; CHECK-NEXT: %intadd1 = bitcast float %add1 to i32
66+
; CHECK-NEXT: %intadd2 = bitcast float %add2 to i32
67+
; CHECK-NEXT: %intadd3 = bitcast float %add3 to i32
5968
; CHECK-NEXT: %inc = add nuw nsw i64 %i, 1
6069
; CHECK-NEXT: %cmp = icmp eq i64 %inc, 5
6170
; CHECK-NEXT: br i1 %cmp, label %do.end, label %do.body
6271

6372
; CHECK: do.end: ; preds = %do.body
64-
; CHECK-NEXT: %lcssa = phi float [ %add0, %do.body ]
65-
; CHECK-NEXT: %mrv = insertvalue [4 x float] undef, float %lcssa, 0
66-
; CHECK-NEXT: %mrv1 = insertvalue [4 x float] %mrv, float %lcssa, 1
67-
; CHECK-NEXT: %mrv2 = insertvalue [4 x float] %mrv1, float %lcssa, 2
68-
; CHECK-NEXT: %mrv3 = insertvalue [4 x float] %mrv2, float %lcssa, 3
69-
; CHECK-NEXT: ret [4 x float] %mrv3
73+
; CHECK-NEXT: %lcssa0 = phi float [ %add0, %do.body ]
74+
; CHECK-NEXT: %lcssa = phi float [ %add1, %do.body ]
75+
; CHECK-NEXT: %lcssa3 = phi float [ %add2, %do.body ]
76+
; CHECK-NEXT: %lcssa4 = phi float [ %add3, %do.body ]
77+
; CHECK-NEXT: %mrv = insertvalue [4 x float] undef, float %lcssa0, 0
78+
; CHECK-NEXT: %mrv5 = insertvalue [4 x float] %mrv, float %lcssa, 1
79+
; CHECK-NEXT: %mrv6 = insertvalue [4 x float] %mrv5, float %lcssa3, 2
80+
; CHECK-NEXT: %mrv7 = insertvalue [4 x float] %mrv6, float %lcssa4, 3
81+
; CHECK-NEXT: ret [4 x float] %mrv7
7082
; CHECK-NEXT: }

enzyme/test/Enzyme/BatchMode/square-scalar-add.ll

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,13 @@ entry:
2626
; CHECK-NEXT: %mul1 = fmul double %unwrap.x1, %unwrap.x1
2727
; CHECK-NEXT: %mul2 = fmul double %unwrap.x2, %unwrap.x2
2828
; CHECK-NEXT: %mul3 = fmul double %unwrap.x3, %unwrap.x3
29-
; CHECK-NEXT: %add = fadd double %mul0, %y
30-
; CHECK-NEXT: %mrv = insertvalue [4 x double] undef, double %add, 0
31-
; CHECK-NEXT: %mrv1 = insertvalue [4 x double] %mrv, double %add, 1
32-
; CHECK-NEXT: %mrv2 = insertvalue [4 x double] %mrv1, double %add, 2
33-
; CHECK-NEXT: %mrv3 = insertvalue [4 x double] %mrv2, double %add, 3
29+
; CHECK-NEXT: %add0 = fadd double %mul0, %y
30+
; CHECK-NEXT: %add1 = fadd double %mul1, %y
31+
; CHECK-NEXT: %add2 = fadd double %mul2, %y
32+
; CHECK-NEXT: %add3 = fadd double %mul3, %y
33+
; CHECK-NEXT: %mrv = insertvalue [4 x double] undef, double %add0, 0
34+
; CHECK-NEXT: %mrv1 = insertvalue [4 x double] %mrv, double %add1, 1
35+
; CHECK-NEXT: %mrv2 = insertvalue [4 x double] %mrv1, double %add2, 2
36+
; CHECK-NEXT: %mrv3 = insertvalue [4 x double] %mrv2, double %add3, 3
3437
; CHECK-NEXT: ret [4 x double] %mrv3
3538
; CHECK-NEXT: }
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: %clang -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
2+
// RUN: %clang -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
3+
// RUN: %clang -O1 -g %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
4+
// RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
5+
// RUN: %clang -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
6+
// RUN: %clang -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
7+
// RUN: %clang -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
8+
// RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
9+
// RUN: %clang -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
10+
11+
#include <stdio.h>
12+
13+
#include "test_utils.h"
14+
15+
extern void __enzyme_batch(void *, int, int, char *, char *, char *, char *);
16+
extern int enzyme_dup;
17+
extern int enzyme_width;
18+
19+
#pragma pack(1)
20+
struct Foo {
21+
int arr[3];
22+
double x;
23+
float y;
24+
double res;
25+
};
26+
27+
void f(char *foo) {
28+
double *xptr = (double *)(foo + sizeof(int[3]));
29+
float *yptr = (float *)(foo + sizeof(int[3]) + sizeof(double));
30+
double *resptr =
31+
(double *)(foo + sizeof(int[3]) + sizeof(double) + sizeof(float));
32+
double x = *xptr;
33+
float y = *yptr;
34+
*resptr = x * y;
35+
}
36+
37+
void df(char *dfoo1, char *dfoo2, char *dfoo3, char *dfoo4) {
38+
__enzyme_batch((void *)f, enzyme_width, 4, dfoo1, dfoo2, dfoo3, dfoo4);
39+
}
40+
41+
int main() {
42+
Foo foo1;
43+
foo1.x = 10;
44+
foo1.y = 9.0;
45+
Foo foo2;
46+
foo2.x = 99.0;
47+
foo2.y = 7.0;
48+
Foo foo3;
49+
foo3.x = 1.1;
50+
foo3.y = 9.0;
51+
Foo foo4;
52+
foo4.x = 3.14;
53+
foo4.y = 0.1;
54+
55+
double expected[4] = {90.0, 693.0, 9.9, 0.314};
56+
57+
df((char *)&foo1, (char *)&foo2, (char *)&foo3, (char *)&foo4);
58+
59+
APPROX_EQ(foo1.res, expected[0], 1e-9);
60+
APPROX_EQ(foo2.res, expected[1], 1e-9);
61+
APPROX_EQ(foo3.res, expected[2], 1e-9);
62+
APPROX_EQ(foo4.res, expected[3], 1e-8);
63+
}

0 commit comments

Comments
 (0)