@@ -716,8 +716,8 @@ class AdjointGenerator
716716 if (constantval) {
717717 ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
718718 } else {
719- auto dif1 =
720- Builder2. CreateLoad (gutils->invertPointerM (orig_ptr, Builder2));
719+ auto dif1 = Builder2. CreateLoad (
720+ lookup (gutils->invertPointerM (orig_ptr, Builder2) , Builder2));
721721#if LLVM_VERSION_MAJOR >= 10
722722 dif1->setAlignment (SI.getAlign ());
723723#else
@@ -1340,8 +1340,6 @@ class AdjointGenerator
13401340
13411341 std::vector<SelectInst *> addToDiffe (Value *val, Value *dif,
13421342 IRBuilder<> &Builder, Type *T) {
1343- assert (Mode == DerivativeMode::ReverseModeGradient ||
1344- Mode == DerivativeMode::ReverseModeCombined);
13451343 return ((DiffeGradientUtils *)gutils)->addToDiffe (val, dif, Builder, T);
13461344 }
13471345
@@ -1928,7 +1926,8 @@ class AdjointGenerator
19281926 // (which thus == src and may be illegal)
19291927 if (gutils->isConstantValue (orig_src)) {
19301928 SmallVector<Value *, 4 > args;
1931- args.push_back (gutils->invertPointerM (orig_dst, Builder2));
1929+ args.push_back (
1930+ lookup (gutils->invertPointerM (orig_dst, Builder2), Builder2));
19321931 if (args[0 ]->getType ()->isIntegerTy ())
19331932 args[0 ] = Builder2.CreateIntToPtr (
19341933 args[0 ], Type::getInt8PtrTy (MTI->getContext ()));
@@ -1958,7 +1957,8 @@ class AdjointGenerator
19581957
19591958 } else {
19601959 SmallVector<Value *, 4 > args;
1961- auto dsto = gutils->invertPointerM (orig_dst, Builder2);
1960+ auto dsto =
1961+ lookup (gutils->invertPointerM (orig_dst, Builder2), Builder2);
19621962 if (dsto->getType ()->isIntegerTy ())
19631963 dsto = Builder2.CreateIntToPtr (
19641964 dsto, Type::getInt8PtrTy (dsto->getContext ()));
@@ -1968,7 +1968,8 @@ class AdjointGenerator
19681968 if (offset != 0 )
19691969 dsto = Builder2.CreateConstInBoundsGEP1_64 (dsto, offset);
19701970 args.push_back (Builder2.CreatePointerCast (dsto, secretpt));
1971- auto srco = gutils->invertPointerM (orig_src, Builder2);
1971+ auto srco =
1972+ lookup (gutils->invertPointerM (orig_src, Builder2), Builder2);
19721973 if (srco->getType ()->isIntegerTy ())
19731974 srco = Builder2.CreateIntToPtr (
19741975 srco, Type::getInt8PtrTy (srco->getContext ()));
@@ -2949,7 +2950,8 @@ class AdjointGenerator
29492950 IRBuilder<> Builder2 (call.getParent ());
29502951 getReverseBuilder (Builder2);
29512952 args.push_back (
2952- gutils->invertPointerM (call.getArgOperand (i), Builder2));
2953+ lookup (gutils->invertPointerM (call.getArgOperand (i), Builder2),
2954+ Builder2));
29532955 }
29542956 pre_args.push_back (
29552957 gutils->invertPointerM (call.getArgOperand (i), BuilderZ));
@@ -3715,7 +3717,8 @@ class AdjointGenerator
37153717 llvm::errs () << " warning could not automatically determine mpi "
37163718 " status type, assuming [24 x i8]\n " ;
37173719 }
3718- Value *d_req = gutils->invertPointerM (call.getOperand (6 ), Builder2);
3720+ Value *d_req = lookup (
3721+ gutils->invertPointerM (call.getOperand (6 ), Builder2), Builder2);
37193722 Value *args[] = {/* req*/ d_req,
37203723 /* status*/ IRBuilder<>(gutils->inversionAllocs )
37213724 .CreateAlloca (statusType)};
@@ -3769,7 +3772,8 @@ class AdjointGenerator
37693772 ConstantInt::get (Type::getInt8Ty (Builder2.getContext ()), 0 );
37703773 auto volatile_arg = ConstantInt::getFalse (Builder2.getContext ());
37713774 assert (!gutils->isConstantValue (call.getOperand (0 )));
3772- auto dbuf = gutils->invertPointerM (call.getOperand (0 ), Builder2);
3775+ auto dbuf = lookup (
3776+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
37733777 if (dbuf->getType ()->isIntegerTy ())
37743778 dbuf = Builder2.CreateIntToPtr (
37753779 dbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -3790,8 +3794,8 @@ class AdjointGenerator
37903794 memset->addParamAttr (0 , Attribute::NonNull);
37913795 } else if (funcName == " MPI_Isend" || funcName == " PMPI_Isend" ) {
37923796 assert (!gutils->isConstantValue (call.getOperand (0 )));
3793- Value *shadow =
3794- gutils->invertPointerM (call.getOperand (0 ), Builder2);
3797+ Value *shadow = lookup (
3798+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2) ;
37953799 if (Mode == DerivativeMode::ReverseModeCombined) {
37963800 assert (firstallocation);
37973801 firstallocation = lookup (firstallocation, Builder2);
@@ -3830,7 +3834,8 @@ class AdjointGenerator
38303834 getReverseBuilder (Builder2);
38313835
38323836 assert (!gutils->isConstantValue (call.getOperand (0 )));
3833- Value *d_req = gutils->invertPointerM (call.getOperand (0 ), Builder2);
3837+ Value *d_req = lookup (
3838+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
38343839 if (d_req->getType ()->isIntegerTy ()) {
38353840 d_req = Builder2.CreateIntToPtr (
38363841 d_req,
@@ -3908,8 +3913,8 @@ class AdjointGenerator
39083913 assert (!gutils->isConstantValue (call.getOperand (1 )));
39093914 Value *count =
39103915 lookup (gutils->getNewFromOriginal (call.getOperand (0 )), Builder2);
3911- Value *d_req_orig =
3912- gutils->invertPointerM (call.getOperand (1 ), Builder2);
3916+ Value *d_req_orig = lookup (
3917+ gutils->invertPointerM (call.getOperand (1 ), Builder2), Builder2) ;
39133918 if (d_req_orig->getType ()->isIntegerTy ()) {
39143919 d_req_orig = Builder2.CreateIntToPtr (
39153920 d_req_orig,
@@ -4007,7 +4012,8 @@ class AdjointGenerator
40074012 Mode == DerivativeMode::ReverseModeCombined) {
40084013 IRBuilder<> Builder2 (call.getParent ());
40094014 getReverseBuilder (Builder2);
4010- Value *shadow = gutils->invertPointerM (call.getOperand (0 ), Builder2);
4015+ Value *shadow = lookup (
4016+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
40114017
40124018 if (shadow->getType ()->isIntegerTy ())
40134019 shadow = Builder2.CreateIntToPtr (
@@ -4095,7 +4101,8 @@ class AdjointGenerator
40954101 IRBuilder<> Builder2 (call.getParent ());
40964102 getReverseBuilder (Builder2);
40974103
4098- Value *shadow = gutils->invertPointerM (call.getOperand (0 ), Builder2);
4104+ Value *shadow = lookup (
4105+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
40994106 if (shadow->getType ()->isIntegerTy ())
41004107 shadow = Builder2.CreateIntToPtr (
41014108 shadow, Type::getInt8PtrTy (call.getContext ()));
@@ -4165,7 +4172,8 @@ class AdjointGenerator
41654172 IRBuilder<> Builder2 (call.getParent ());
41664173 getReverseBuilder (Builder2);
41674174
4168- Value *shadow = gutils->invertPointerM (call.getOperand (0 ), Builder2);
4175+ Value *shadow = lookup (
4176+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
41694177 if (shadow->getType ()->isIntegerTy ())
41704178 shadow = Builder2.CreateIntToPtr (
41714179 shadow, Type::getInt8PtrTy (call.getContext ()));
@@ -4365,11 +4373,13 @@ class AdjointGenerator
43654373 report_fatal_error (" unhandled mpi_allreduce op" );
43664374 }
43674375
4368- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
4376+ Value *shadow_recvbuf =
4377+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
43694378 if (shadow_recvbuf->getType ()->isIntegerTy ())
43704379 shadow_recvbuf = Builder2.CreateIntToPtr (
43714380 shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
4372- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
4381+ Value *shadow_sendbuf =
4382+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
43734383 if (shadow_sendbuf->getType ()->isIntegerTy ())
43744384 shadow_sendbuf = Builder2.CreateIntToPtr (
43754385 shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -4552,11 +4562,13 @@ class AdjointGenerator
45524562 report_fatal_error (" unhandled mpi_allreduce op" );
45534563 }
45544564
4555- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
4565+ Value *shadow_recvbuf =
4566+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
45564567 if (shadow_recvbuf->getType ()->isIntegerTy ())
45574568 shadow_recvbuf = Builder2.CreateIntToPtr (
45584569 shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
4559- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
4570+ Value *shadow_sendbuf =
4571+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
45604572 if (shadow_sendbuf->getType ()->isIntegerTy ())
45614573 shadow_sendbuf = Builder2.CreateIntToPtr (
45624574 shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -4665,11 +4677,13 @@ class AdjointGenerator
46654677 Value *orig_root = call.getOperand (6 );
46664678 Value *orig_comm = call.getOperand (7 );
46674679
4668- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
4680+ Value *shadow_recvbuf =
4681+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
46694682 if (shadow_recvbuf->getType ()->isIntegerTy ())
46704683 shadow_recvbuf = Builder2.CreateIntToPtr (
46714684 shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
4672- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
4685+ Value *shadow_sendbuf =
4686+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
46734687 if (shadow_sendbuf->getType ()->isIntegerTy ())
46744688 shadow_sendbuf = Builder2.CreateIntToPtr (
46754689 shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -4820,11 +4834,13 @@ class AdjointGenerator
48204834 Value *orig_root = call.getOperand (6 );
48214835 Value *orig_comm = call.getOperand (7 );
48224836
4823- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
4837+ Value *shadow_recvbuf =
4838+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
48244839 if (shadow_recvbuf->getType ()->isIntegerTy ())
48254840 shadow_recvbuf = Builder2.CreateIntToPtr (
48264841 shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
4827- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
4842+ Value *shadow_sendbuf =
4843+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
48284844 if (shadow_sendbuf->getType ()->isIntegerTy ())
48294845 shadow_sendbuf = Builder2.CreateIntToPtr (
48304846 shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -5008,11 +5024,13 @@ class AdjointGenerator
50085024 Value *orig_recvcount = call.getOperand (4 );
50095025 Value *orig_comm = call.getOperand (6 );
50105026
5011- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
5027+ Value *shadow_recvbuf =
5028+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
50125029 if (shadow_recvbuf->getType ()->isIntegerTy ())
50135030 shadow_recvbuf = Builder2.CreateIntToPtr (
50145031 shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
5015- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
5032+ Value *shadow_sendbuf =
5033+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
50165034 if (shadow_sendbuf->getType ()->isIntegerTy ())
50175035 shadow_sendbuf = Builder2.CreateIntToPtr (
50185036 shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -5502,7 +5520,8 @@ class AdjointGenerator
55025520 diffe (orig, Builder2),
55035521 structarg1,
55045522 estride,
5505- gutils->invertPointerM (orig->getArgOperand (3 ), Builder2),
5523+ lookup (gutils->invertPointerM (orig->getArgOperand (3 ), Builder2),
5524+ Builder2),
55065525 lookup (gutils->getNewFromOriginal (orig->getArgOperand (4 )),
55075526 Builder2)};
55085527 firstdcall = Builder2.CreateCall (derivcall, args1);
@@ -5520,7 +5539,8 @@ class AdjointGenerator
55205539 diffe (orig, Builder2),
55215540 structarg2,
55225541 estride,
5523- gutils->invertPointerM (orig->getArgOperand (1 ), Builder2),
5542+ lookup (gutils->invertPointerM (orig->getArgOperand (1 ), Builder2),
5543+ Builder2),
55245544 lookup (gutils->getNewFromOriginal (orig->getArgOperand (2 )),
55255545 Builder2)};
55265546 seconddcall = Builder2.CreateCall (derivcall, args2);
@@ -7267,7 +7287,8 @@ class AdjointGenerator
72677287 IRBuilder<> Builder2 (call.getParent ());
72687288 getReverseBuilder (Builder2);
72697289 args.push_back (
7270- gutils->invertPointerM (orig->getArgOperand (i), Builder2));
7290+ lookup (gutils->invertPointerM (orig->getArgOperand (i), Builder2),
7291+ Builder2));
72717292 }
72727293 pre_args.push_back (
72737294 gutils->invertPointerM (orig->getArgOperand (i), BuilderZ));
@@ -7702,7 +7723,7 @@ class AdjointGenerator
77027723 llvm::errs () << " orig: " << *orig << " callval: " << *callval << " \n " ;
77037724 }
77047725 assert (!gutils->isConstantValue (callval));
7705- newcalled = gutils->invertPointerM (callval, Builder2);
7726+ newcalled = lookup ( gutils->invertPointerM (callval, Builder2) , Builder2);
77067727
77077728 auto ft = cast<FunctionType>(
77087729 cast<PointerType>(callval->getType ())->getElementType ());
0 commit comments