@@ -101,7 +101,7 @@ void testLLVMBufferTest() {
101101 std::vector<int32_t > v (5 );
102102 std::vector<void *> args ({v.data ()});
103103 auto rv = IntImm::make (0 );
104- LLVMCodeGen cg (rv, {& a});
104+ LLVMCodeGen cg (rv, {a});
105105 EXPECT_EQ (cg.value <int >(args), 0 );
106106}
107107
@@ -116,7 +116,7 @@ void testLLVMBlockTest() {
116116 Store::make (a, IntImm::make (0 ), IntImm::make (4 ), IntImm::make (1 )),
117117 });
118118
119- LLVMCodeGen cg (block, {& a});
119+ LLVMCodeGen cg (block, {a});
120120 EXPECT_EQ (cg.value <int >(args), 0 );
121121 EXPECT_EQ (v[0 ], 4 );
122122 EXPECT_EQ (v[1 ], 4 );
@@ -133,7 +133,7 @@ void testLLVMLoadStoreTest() {
133133 IntImm::make (0 ),
134134 Load::make (a, IntImm::make (0 ), IntImm::make (1 )),
135135 IntImm::make (1 ));
136- LLVMCodeGen cg (store, {& a, & b});
136+ LLVMCodeGen cg (store, {a, b});
137137 std::vector<void *> args ({a_buffer.data (), b_buffer.data ()});
138138 EXPECT_EQ (cg.value <int >(args), 0 );
139139 EXPECT_EQ (a_buffer[0 ], 42 );
@@ -151,7 +151,7 @@ void testLLVMVecLoadStoreTest() {
151151 Ramp::make (0 , 1 , 4 ),
152152 Load::make (a, Ramp::make (0 , 1 , 4 ), Broadcast::make (IntImm::make (1 ), 4 )),
153153 Broadcast::make (IntImm::make (1 ), 4 ));
154- LLVMCodeGen cg (store, {& a, & b});
154+ LLVMCodeGen cg (store, {a, b});
155155 std::vector<void *> args ({a_buffer.data (), b_buffer.data ()});
156156 EXPECT_EQ (cg.value <int >(args), 0 );
157157 EXPECT_EQ (a_buffer[0 ], 1 );
@@ -176,7 +176,7 @@ void testLLVMMemcpyTest() {
176176 auto expr =
177177 For::make (i, 0 , N, Store::make (b, i, Load::make (a, i, mask), mask));
178178
179- LLVMCodeGen cg (expr, {& a, & b});
179+ LLVMCodeGen cg (expr, {a, b});
180180
181181 std::vector<void *> args ({a_buffer.data (), b_buffer.data ()});
182182 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -194,10 +194,9 @@ void testLLVMBzeroTest() {
194194
195195 auto mask = IntImm::make (1 );
196196 Var i (" i" , kInt32 );
197- auto expr =
198- For::make (i, 0 , N, Store::make (b, i, IntImm::make (0 ), mask));
197+ auto expr = For::make (i, 0 , N, Store::make (b, i, IntImm::make (0 ), mask));
199198
200- LLVMCodeGen cg (expr, {& b});
199+ LLVMCodeGen cg (expr, {b});
201200
202201 std::vector<void *> args ({b_buffer.data ()});
203202 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -227,7 +226,7 @@ void testLLVMElemwiseAdd() {
227226 Add::make (Load::make (a, i, mask), Load::make (b, i, mask)),
228227 mask));
229228
230- LLVMCodeGen cg (expr, {& a, & b, & c});
229+ LLVMCodeGen cg (expr, {a, b, c});
231230
232231 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
233232 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -257,7 +256,7 @@ void testLLVMElemwiseAddFloat() {
257256 N,
258257 Store::make (c, i, Load::make (a, i, mask) + Load::make (b, i, mask), mask));
259258
260- LLVMCodeGen cg (expr, {& a, & b, & c});
259+ LLVMCodeGen cg (expr, {a, b, c});
261260
262261 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
263262 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -282,10 +281,14 @@ void testLLVMElemwiseLog10Float() {
282281 auto expr = For::make (
283282 i,
284283 0 ,
285- N/4 ,
286- Store::make (b, Ramp::make (i * 4 , 1 , 4 ), log10 (Load::make (a, Ramp::make (i * 4 , 1 , 4 ), mask)), mask));
284+ N / 4 ,
285+ Store::make (
286+ b,
287+ Ramp::make (i * 4 , 1 , 4 ),
288+ log10 (Load::make (a, Ramp::make (i * 4 , 1 , 4 ), mask)),
289+ mask));
287290
288- LLVMCodeGen cg (expr, {& a, & b});
291+ LLVMCodeGen cg (expr, {a, b});
289292
290293 std::vector<void *> args ({a_buffer.data (), b_buffer.data ()});
291294 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -317,7 +320,7 @@ void testLLVMElemwiseMaxInt() {
317320 Max::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
318321 mask));
319322
320- LLVMCodeGen cg (expr, {& a, & b, & c});
323+ LLVMCodeGen cg (expr, {a, b, c});
321324
322325 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
323326 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -351,7 +354,7 @@ void testLLVMElemwiseMinInt() {
351354 Min::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
352355 mask));
353356
354- LLVMCodeGen cg (expr, {& a, & b, & c});
357+ LLVMCodeGen cg (expr, {a, b, c});
355358
356359 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
357360 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -385,7 +388,7 @@ void testLLVMElemwiseMaxNumFloat() {
385388 Max::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
386389 mask));
387390
388- LLVMCodeGen cg (expr, {& a, & b, & c});
391+ LLVMCodeGen cg (expr, {a, b, c});
389392
390393 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
391394 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -419,7 +422,7 @@ void testLLVMElemwiseMaxNumNaNFloat() {
419422 Max::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
420423 mask));
421424
422- LLVMCodeGen cg (expr, {& a, & b, & c});
425+ LLVMCodeGen cg (expr, {a, b, c});
423426
424427 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
425428 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -452,7 +455,7 @@ void testLLVMElemwiseMinNumFloat() {
452455 Min::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
453456 mask));
454457
455- LLVMCodeGen cg (expr, {& a, & b, & c});
458+ LLVMCodeGen cg (expr, {a, b, c});
456459
457460 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
458461 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -486,7 +489,7 @@ void testLLVMElemwiseMinNumNaNFloat() {
486489 Min::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
487490 mask));
488491
489- LLVMCodeGen cg (expr, {& a, & b, & c});
492+ LLVMCodeGen cg (expr, {a, b, c});
490493
491494 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
492495 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -520,7 +523,7 @@ void testLLVMElemwiseMaximumFloat() {
520523 Max::make (Load::make (a, i, mask), Load::make (b, i, mask), true ),
521524 mask));
522525
523- LLVMCodeGen cg (expr, {& a, & b, & c});
526+ LLVMCodeGen cg (expr, {a, b, c});
524527
525528 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
526529 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -554,7 +557,7 @@ void testLLVMElemwiseMaximumNaNFloat() {
554557 Max::make (Load::make (a, i, mask), Load::make (b, i, mask), true ),
555558 mask));
556559
557- LLVMCodeGen cg (expr, {& a, & b, & c});
560+ LLVMCodeGen cg (expr, {a, b, c});
558561
559562 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
560563 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -589,7 +592,7 @@ void testLLVMElemwiseMinimumFloat() {
589592 Min::make (Load::make (a, i, mask), Load::make (b, i, mask), true ),
590593 mask));
591594
592- LLVMCodeGen cg (expr, {& a, & b, & c});
595+ LLVMCodeGen cg (expr, {a, b, c});
593596
594597 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
595598 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -623,7 +626,7 @@ void testLLVMElemwiseMinimumNaNFloat() {
623626 Min::make (Load::make (a, i, mask), Load::make (b, i, mask), true ),
624627 mask));
625628
626- LLVMCodeGen cg (expr, {& a, & b, & c});
629+ LLVMCodeGen cg (expr, {a, b, c});
627630
628631 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
629632 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -668,7 +671,7 @@ void testLLVMCompareSelectIntEQ() {
668671 CompareSelectOperation::kEQ ),
669672 mask));
670673
671- LLVMCodeGen cg (expr, {& a, & b, & c});
674+ LLVMCodeGen cg (expr, {a, b, c});
672675
673676 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
674677 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -707,7 +710,7 @@ void testLLVMCompareSelectFloatEQ() {
707710 CompareSelectOperation::kEQ ),
708711 mask));
709712
710- LLVMCodeGen cg (expr, {& a, & b, & c});
713+ LLVMCodeGen cg (expr, {a, b, c});
711714
712715 std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
713716 ASSERT_EQ (cg.value <int >(args), 0 );
@@ -726,7 +729,7 @@ void testLLVMStoreFloat() {
726729 std::vector<float > result_buffer = {0 .0f };
727730 auto expr = Store::make (
728731 result, IntImm::make (0 ), FloatImm::make (3 .14f ), IntImm::make (1 ));
729- LLVMCodeGen cg (expr, {& result});
732+ LLVMCodeGen cg (expr, {result});
730733 std::vector<void *> args ({result_buffer.data ()});
731734 ASSERT_EQ (cg.value <int >(args), 0 );
732735 EXPECT_EQ (result_buffer[0 ], 3 .14f );
@@ -739,7 +742,7 @@ void testLLVMSimpleMath01() {
739742 Schedule sch = Schedule::make ({tensor});
740743 Stmt stmt = sch.Lower ();
741744 Buffer f_buf (tensor.function ().func_var (), kFloat32 , {N});
742- LLVMCodeGen cg (stmt, {& f_buf});
745+ LLVMCodeGen cg (stmt, {f_buf});
743746
744747 PaddedBuffer<float > f_v (N, " f_v" );
745748 std::vector<void *> args ({f_v.data ()});
@@ -764,7 +767,7 @@ void testLLVMComputeMul() {
764767 Schedule sch = Schedule::make ({c});
765768 Stmt s = sch.Lower ();
766769
767- LLVMCodeGen cg (s, {& a, & b, & c_buf});
770+ LLVMCodeGen cg (s, {a, b, c_buf});
768771
769772 std::vector<float > a_vec (N, 21 .0f );
770773 std::vector<float > b_vec (N, 2 .0f );
@@ -789,7 +792,7 @@ void testLLVMBroadcastAdd() {
789792 Schedule sch = Schedule::make ({c});
790793 Stmt s = sch.Lower ();
791794
792- LLVMCodeGen cg (s, {& a, & b, & c_buf});
795+ LLVMCodeGen cg (s, {a, b, c_buf});
793796
794797 std::vector<float > av (M * N);
795798 std::iota (av.begin (), av.end (), 0 );
@@ -805,6 +808,30 @@ void testLLVMBroadcastAdd() {
805808 }
806809 }
807810}
811+
812+ void testLLVMDynamicShapeAdd () {
813+ #if 0
814+ auto testWithSize = [](int32_t size) {
815+ Var n("n", kInt32);
816+ Buffer a(Var("a", kHandle), kFloat32, {n});
817+ Buffer b(Var("b", kHandle), kFloat32, {n});
818+ Buffer c(Var("c", kHandle), kFloat32, {n});
819+ Var i("i", kInt32);
820+ Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
821+ std::vector<float> aData(size, 1.0f);
822+ std::vector<float> bData(size, 2.0f);
823+ std::vector<float> cData(size, 0.0f);
824+ LLVMCodeGen cg(s, {a, b, c, n});
825+ std::vector<void*> args({aData.data(), bData.data(), cData.data(), size));
826+ cg.value<float>(args);
827+ ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
828+ };
829+ testWithSize(1);
830+ testWithSize(16);
831+ testWithSize(37);
832+ #endif
833+ }
834+
808835} // namespace jit
809836} // namespace torch
810837
0 commit comments