@@ -10075,71 +10075,77 @@ class AdjointGenerator
10075
10075
anti = gutils->cacheForReverse(
10076
10076
bb, anti, getIndex(orig, CacheType::Shadow));
10077
10077
} else {
10078
+ auto rule = [&]() {
10078
10079
#if LLVM_VERSION_MAJOR >= 11
10079
- anti = bb.CreateCall(orig->getFunctionType(),
10080
- orig->getCalledOperand(), args,
10081
- orig->getName() + "'mi");
10080
+ Value * anti = bb.CreateCall(orig->getFunctionType(),
10081
+ orig->getCalledOperand(), args,
10082
+ orig->getName() + "'mi");
10082
10083
#else
10083
- anti = bb.CreateCall(orig->getCalledValue(), args,
10084
- orig->getName() + "'mi");
10084
+ anti = bb.CreateCall(orig->getCalledValue(), args,
10085
+ orig->getName() + "'mi");
10085
10086
#endif
10086
- cast<CallInst>(anti)->setAttributes(orig->getAttributes());
10087
- cast<CallInst>(anti)->setCallingConv(orig->getCallingConv());
10088
- cast<CallInst>(anti)->setTailCallKind(orig->getTailCallKind());
10089
- cast<CallInst>(anti)->setDebugLoc(dbgLoc);
10087
+ cast<CallInst>(anti)->setAttributes(orig->getAttributes());
10088
+ cast<CallInst>(anti)->setCallingConv(orig->getCallingConv());
10089
+ cast<CallInst>(anti)->setTailCallKind(orig->getTailCallKind());
10090
+ cast<CallInst>(anti)->setDebugLoc(dbgLoc);
10090
10091
10091
10092
#if LLVM_VERSION_MAJOR >= 14
10092
- cast<CallInst>(anti)->addAttributeAtIndex(
10093
- AttributeList::ReturnIndex, Attribute::NoAlias);
10094
- cast<CallInst>(anti)->addAttributeAtIndex(
10095
- AttributeList::ReturnIndex, Attribute::NonNull);
10096
- #else
10097
- cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10098
- Attribute::NoAlias);
10099
- cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10100
- Attribute::NonNull);
10101
- #endif
10102
-
10103
- if (called->getName() == "malloc" ||
10104
- called->getName() == "_Znwm") {
10105
- if (auto ci = dyn_cast<ConstantInt>(args[0])) {
10106
- unsigned derefBytes = ci->getLimitedValue();
10107
- CallInst *cal =
10108
- cast<CallInst>(gutils->getNewFromOriginal(orig));
10093
+ cast<CallInst>(anti)->addAttributeAtIndex(
10094
+ AttributeList::ReturnIndex, Attribute::NoAlias);
10095
+ cast<CallInst>(anti)->addAttributeAtIndex(
10096
+ AttributeList::ReturnIndex, Attribute::NonNull);
10097
+ #else
10098
+ cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10099
+ Attribute::NoAlias);
10100
+ cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10101
+ Attribute::NonNull);
10102
+ #endif
10103
+
10104
+ if (called->getName() == "malloc" ||
10105
+ called->getName() == "_Znwm") {
10106
+ if (auto ci = dyn_cast<ConstantInt>(args[0])) {
10107
+ unsigned derefBytes = ci->getLimitedValue();
10108
+ CallInst *cal =
10109
+ cast<CallInst>(gutils->getNewFromOriginal(orig));
10109
10110
#if LLVM_VERSION_MAJOR >= 14
10110
- cast<CallInst>(anti)->addDereferenceableRetAttr(derefBytes);
10111
- cal->addDereferenceableRetAttr(derefBytes);
10111
+ cast<CallInst>(anti)->addDereferenceableRetAttr(derefBytes);
10112
+ cal->addDereferenceableRetAttr(derefBytes);
10112
10113
#ifndef FLANG
10113
- AttrBuilder B(called->getContext());
10114
- #else
10115
- AttrBuilder B;
10116
- #endif
10117
- B.addDereferenceableOrNullAttr(derefBytes);
10118
- cast<CallInst>(anti)->setAttributes(
10119
- cast<CallInst>(anti)->getAttributes().addRetAttributes(
10120
- orig->getContext(), B));
10121
- cal->setAttributes(cal->getAttributes().addRetAttributes(
10122
- orig->getContext(), B));
10123
- cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10124
- Attribute::NoAlias);
10125
- cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10126
- Attribute::NonNull);
10127
- #else
10128
- cast<CallInst>(anti)->addDereferenceableAttr(
10129
- llvm::AttributeList::ReturnIndex, derefBytes);
10130
- cal->addDereferenceableAttr(llvm::AttributeList::ReturnIndex,
10131
- derefBytes);
10132
- cast<CallInst>(anti)->addDereferenceableOrNullAttr(
10133
- llvm::AttributeList::ReturnIndex, derefBytes);
10134
- cal->addDereferenceableOrNullAttr(
10135
- llvm::AttributeList::ReturnIndex, derefBytes);
10136
- cal->addAttribute(AttributeList::ReturnIndex,
10137
- Attribute::NoAlias);
10138
- cal->addAttribute(AttributeList::ReturnIndex,
10139
- Attribute::NonNull);
10114
+ AttrBuilder B(called->getContext());
10115
+ #else
10116
+ AttrBuilder B;
10117
+ #endif
10118
+ B.addDereferenceableOrNullAttr(derefBytes);
10119
+ cast<CallInst>(anti)->setAttributes(
10120
+ cast<CallInst>(anti)->getAttributes().addRetAttributes(
10121
+ orig->getContext(), B));
10122
+ cal->setAttributes(cal->getAttributes().addRetAttributes(
10123
+ orig->getContext(), B));
10124
+ cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10125
+ Attribute::NoAlias);
10126
+ cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10127
+ Attribute::NonNull);
10128
+ #else
10129
+ cast<CallInst>(anti)->addDereferenceableAttr(
10130
+ llvm::AttributeList::ReturnIndex, derefBytes);
10131
+ cal->addDereferenceableAttr(
10132
+ llvm::AttributeList::ReturnIndex, derefBytes);
10133
+ cast<CallInst>(anti)->addDereferenceableOrNullAttr(
10134
+ llvm::AttributeList::ReturnIndex, derefBytes);
10135
+ cal->addDereferenceableOrNullAttr(
10136
+ llvm::AttributeList::ReturnIndex, derefBytes);
10137
+ cal->addAttribute(AttributeList::ReturnIndex,
10138
+ Attribute::NoAlias);
10139
+ cal->addAttribute(AttributeList::ReturnIndex,
10140
+ Attribute::NonNull);
10140
10141
#endif
10142
+ }
10141
10143
}
10142
- }
10144
+ return anti;
10145
+ };
10146
+
10147
+ anti = applyChainRule(orig->getType(), bb, rule);
10148
+
10143
10149
gutils->invertedPointers.erase(found);
10144
10150
if (&*bb.GetInsertPoint() == placeholder)
10145
10151
bb.SetInsertPoint(placeholder->getNextNode());
@@ -10163,6 +10169,7 @@ class AdjointGenerator
10163
10169
#else
10164
10170
replacement->setAlignment(Alignment);
10165
10171
#endif
10172
+
10166
10173
gutils->replaceAWithB(cast<Instruction>(anti), replacement);
10167
10174
gutils->erase(cast<Instruction>(anti));
10168
10175
anti = replacement;
@@ -10176,18 +10183,33 @@ class AdjointGenerator
10176
10183
backwardsShadow) ||
10177
10184
(Mode == DerivativeMode::ForwardModeSplit &&
10178
10185
backwardsShadow)) {
10179
- if (!inLoop)
10180
- zeroKnownAllocation(bb, anti, args, *called, gutils->TLI);
10186
+ if (!inLoop) {
10187
+ applyChainRule(
10188
+ bb,
10189
+ [&](Value *anti) {
10190
+ zeroKnownAllocation(bb, anti, args, *called,
10191
+ gutils->TLI);
10192
+ },
10193
+ anti);
10194
+ }
10181
10195
}
10182
10196
}
10183
10197
gutils->invertedPointers.insert(
10184
10198
std::make_pair(orig, InvertedPointerVH(gutils, anti)));
10185
10199
}
10186
10200
endAnti:;
10201
+
10202
+ bool isAlloca = anti ? isa<AllocaInst>(anti) : false;
10203
+ if (gutils->getWidth() != 1) {
10204
+ if (auto insertion = dyn_cast_or_null<InsertElementInst>(anti)) {
10205
+ isAlloca = isa<AllocaInst>(insertion->getOperand(1));
10206
+ }
10207
+ }
10208
+
10187
10209
if (((Mode == DerivativeMode::ReverseModeCombined && shouldFree()) ||
10188
10210
(Mode == DerivativeMode::ReverseModeGradient && shouldFree()) ||
10189
10211
(Mode == DerivativeMode::ForwardModeSplit && shouldFree())) &&
10190
- !isa<AllocaInst>(anti) ) {
10212
+ !isAlloca ) {
10191
10213
IRBuilder<> Builder2(call.getParent());
10192
10214
getReverseBuilder(Builder2);
10193
10215
assert(anti);
@@ -10198,16 +10220,19 @@ class AdjointGenerator
10198
10220
assert(
10199
10221
PointerType::getUnqual(Type::getInt8Ty(tofree->getContext())));
10200
10222
assert(Type::getInt8PtrTy(tofree->getContext()));
10201
- auto CI = freeKnownAllocation(Builder2, tofree, *called, dbgLoc,
10202
- gutils->TLI);
10203
- if (CI)
10223
+ auto rule = [&](Value *tofree) {
10224
+ auto CI = freeKnownAllocation(Builder2, tofree, *called, dbgLoc,
10225
+ gutils->TLI);
10226
+ if (CI)
10204
10227
#if LLVM_VERSION_MAJOR >= 14
10205
- CI->addAttributeAtIndex(AttributeList::FirstArgIndex,
10206
- Attribute::NonNull);
10228
+ CI->addAttributeAtIndex(AttributeList::FirstArgIndex,
10229
+ Attribute::NonNull);
10207
10230
#else
10208
- CI->addAttribute(AttributeList::FirstArgIndex,
10209
- Attribute::NonNull);
10231
+ CI->addAttribute(AttributeList::FirstArgIndex,
10232
+ Attribute::NonNull);
10210
10233
#endif
10234
+ };
10235
+ applyChainRule(Builder2, rule, tofree);
10211
10236
}
10212
10237
} else if (Mode == DerivativeMode::ForwardMode) {
10213
10238
IRBuilder<> Builder2(&call);
@@ -10284,18 +10309,26 @@ class AdjointGenerator
10284
10309
if (auto CI = dyn_cast<ConstantInt>(orig->getArgOperand(0))) {
10285
10310
B.SetInsertPoint(gutils->inversionAllocs);
10286
10311
}
10287
- auto replacement = B.CreateAlloca(
10288
- Type::getInt8Ty(orig->getContext()),
10289
- gutils->getNewFromOriginal(orig->getArgOperand(0)));
10290
- auto Alignment =
10291
- cast<ConstantInt>(
10292
- cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10293
- ->getLimitedValue();
10312
+
10313
+ auto rule = [&]() {
10314
+ auto replacement = B.CreateAlloca(
10315
+ Type::getInt8Ty(orig->getContext()),
10316
+ gutils->getNewFromOriginal(orig->getArgOperand(0)));
10317
+ auto Alignment =
10318
+ cast<ConstantInt>(
10319
+ cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10320
+ ->getLimitedValue();
10294
10321
#if LLVM_VERSION_MAJOR >= 10
10295
- replacement->setAlignment(Align(Alignment));
10322
+ replacement->setAlignment(Align(Alignment));
10296
10323
#else
10297
- replacement->setAlignment(Alignment);
10324
+ replacement->setAlignment(Alignment);
10298
10325
#endif
10326
+ return replacement;
10327
+ };
10328
+
10329
+ Value *replacement =
10330
+ applyChainRule(Type::getInt8Ty(orig->getContext()), B, rule);
10331
+
10299
10332
gutils->replaceAWithB(newCall, replacement);
10300
10333
gutils->erase(newCall);
10301
10334
return;
@@ -10337,21 +10370,29 @@ class AdjointGenerator
10337
10370
if (auto CI = dyn_cast<ConstantInt>(orig->getArgOperand(0))) {
10338
10371
B.SetInsertPoint(gutils->inversionAllocs);
10339
10372
}
10340
- auto replacement = B.CreateAlloca(
10341
- Type::getInt8Ty(orig->getContext()),
10342
- gutils->getNewFromOriginal(orig->getArgOperand(0)));
10343
- auto Alignment =
10344
- cast<ConstantInt>(
10345
- cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10346
- ->getLimitedValue();
10347
- // Don't set zero alignment
10348
- if (Alignment) {
10373
+
10374
+ auto rule = [&]() {
10375
+ auto replacement = B.CreateAlloca(
10376
+ Type::getInt8Ty(orig->getContext()),
10377
+ gutils->getNewFromOriginal(orig->getArgOperand(0)));
10378
+ auto Alignment =
10379
+ cast<ConstantInt>(
10380
+ cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10381
+ ->getLimitedValue();
10382
+ // Don't set zero alignment
10383
+ if (Alignment) {
10349
10384
#if LLVM_VERSION_MAJOR >= 10
10350
- replacement->setAlignment(Align(Alignment));
10385
+ replacement->setAlignment(Align(Alignment));
10351
10386
#else
10352
- replacement->setAlignment(Alignment);
10387
+ replacement->setAlignment(Alignment);
10353
10388
#endif
10354
- }
10389
+ }
10390
+ return replacement;
10391
+ };
10392
+
10393
+ Value *replacement =
10394
+ applyChainRule(Type::getInt8Ty(orig->getContext()), B, rule);
10395
+
10355
10396
gutils->replaceAWithB(newCall, replacement);
10356
10397
gutils->erase(newCall);
10357
10398
}
0 commit comments