Skip to content

Commit 1e083a3

Browse files
authored
Fix reverse vector mode malloc (rust-lang#641)
* Fix reverse vector mode malloc
1 parent 661712f commit 1e083a3

File tree

6 files changed

+588
-121
lines changed

6 files changed

+588
-121
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 128 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -10075,71 +10075,77 @@ class AdjointGenerator
1007510075
anti = gutils->cacheForReverse(
1007610076
bb, anti, getIndex(orig, CacheType::Shadow));
1007710077
} else {
10078+
auto rule = [&]() {
1007810079
#if LLVM_VERSION_MAJOR >= 11
10079-
anti = bb.CreateCall(orig->getFunctionType(),
10080-
orig->getCalledOperand(), args,
10081-
orig->getName() + "'mi");
10080+
Value *anti = bb.CreateCall(orig->getFunctionType(),
10081+
orig->getCalledOperand(), args,
10082+
orig->getName() + "'mi");
1008210083
#else
10083-
anti = bb.CreateCall(orig->getCalledValue(), args,
10084-
orig->getName() + "'mi");
10084+
anti = bb.CreateCall(orig->getCalledValue(), args,
10085+
orig->getName() + "'mi");
1008510086
#endif
10086-
cast<CallInst>(anti)->setAttributes(orig->getAttributes());
10087-
cast<CallInst>(anti)->setCallingConv(orig->getCallingConv());
10088-
cast<CallInst>(anti)->setTailCallKind(orig->getTailCallKind());
10089-
cast<CallInst>(anti)->setDebugLoc(dbgLoc);
10087+
cast<CallInst>(anti)->setAttributes(orig->getAttributes());
10088+
cast<CallInst>(anti)->setCallingConv(orig->getCallingConv());
10089+
cast<CallInst>(anti)->setTailCallKind(orig->getTailCallKind());
10090+
cast<CallInst>(anti)->setDebugLoc(dbgLoc);
1009010091

1009110092
#if LLVM_VERSION_MAJOR >= 14
10092-
cast<CallInst>(anti)->addAttributeAtIndex(
10093-
AttributeList::ReturnIndex, Attribute::NoAlias);
10094-
cast<CallInst>(anti)->addAttributeAtIndex(
10095-
AttributeList::ReturnIndex, Attribute::NonNull);
10096-
#else
10097-
cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10098-
Attribute::NoAlias);
10099-
cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10100-
Attribute::NonNull);
10101-
#endif
10102-
10103-
if (called->getName() == "malloc" ||
10104-
called->getName() == "_Znwm") {
10105-
if (auto ci = dyn_cast<ConstantInt>(args[0])) {
10106-
unsigned derefBytes = ci->getLimitedValue();
10107-
CallInst *cal =
10108-
cast<CallInst>(gutils->getNewFromOriginal(orig));
10093+
cast<CallInst>(anti)->addAttributeAtIndex(
10094+
AttributeList::ReturnIndex, Attribute::NoAlias);
10095+
cast<CallInst>(anti)->addAttributeAtIndex(
10096+
AttributeList::ReturnIndex, Attribute::NonNull);
10097+
#else
10098+
cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10099+
Attribute::NoAlias);
10100+
cast<CallInst>(anti)->addAttribute(AttributeList::ReturnIndex,
10101+
Attribute::NonNull);
10102+
#endif
10103+
10104+
if (called->getName() == "malloc" ||
10105+
called->getName() == "_Znwm") {
10106+
if (auto ci = dyn_cast<ConstantInt>(args[0])) {
10107+
unsigned derefBytes = ci->getLimitedValue();
10108+
CallInst *cal =
10109+
cast<CallInst>(gutils->getNewFromOriginal(orig));
1010910110
#if LLVM_VERSION_MAJOR >= 14
10110-
cast<CallInst>(anti)->addDereferenceableRetAttr(derefBytes);
10111-
cal->addDereferenceableRetAttr(derefBytes);
10111+
cast<CallInst>(anti)->addDereferenceableRetAttr(derefBytes);
10112+
cal->addDereferenceableRetAttr(derefBytes);
1011210113
#ifndef FLANG
10113-
AttrBuilder B(called->getContext());
10114-
#else
10115-
AttrBuilder B;
10116-
#endif
10117-
B.addDereferenceableOrNullAttr(derefBytes);
10118-
cast<CallInst>(anti)->setAttributes(
10119-
cast<CallInst>(anti)->getAttributes().addRetAttributes(
10120-
orig->getContext(), B));
10121-
cal->setAttributes(cal->getAttributes().addRetAttributes(
10122-
orig->getContext(), B));
10123-
cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10124-
Attribute::NoAlias);
10125-
cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10126-
Attribute::NonNull);
10127-
#else
10128-
cast<CallInst>(anti)->addDereferenceableAttr(
10129-
llvm::AttributeList::ReturnIndex, derefBytes);
10130-
cal->addDereferenceableAttr(llvm::AttributeList::ReturnIndex,
10131-
derefBytes);
10132-
cast<CallInst>(anti)->addDereferenceableOrNullAttr(
10133-
llvm::AttributeList::ReturnIndex, derefBytes);
10134-
cal->addDereferenceableOrNullAttr(
10135-
llvm::AttributeList::ReturnIndex, derefBytes);
10136-
cal->addAttribute(AttributeList::ReturnIndex,
10137-
Attribute::NoAlias);
10138-
cal->addAttribute(AttributeList::ReturnIndex,
10139-
Attribute::NonNull);
10114+
AttrBuilder B(called->getContext());
10115+
#else
10116+
AttrBuilder B;
10117+
#endif
10118+
B.addDereferenceableOrNullAttr(derefBytes);
10119+
cast<CallInst>(anti)->setAttributes(
10120+
cast<CallInst>(anti)->getAttributes().addRetAttributes(
10121+
orig->getContext(), B));
10122+
cal->setAttributes(cal->getAttributes().addRetAttributes(
10123+
orig->getContext(), B));
10124+
cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10125+
Attribute::NoAlias);
10126+
cal->addAttributeAtIndex(AttributeList::ReturnIndex,
10127+
Attribute::NonNull);
10128+
#else
10129+
cast<CallInst>(anti)->addDereferenceableAttr(
10130+
llvm::AttributeList::ReturnIndex, derefBytes);
10131+
cal->addDereferenceableAttr(
10132+
llvm::AttributeList::ReturnIndex, derefBytes);
10133+
cast<CallInst>(anti)->addDereferenceableOrNullAttr(
10134+
llvm::AttributeList::ReturnIndex, derefBytes);
10135+
cal->addDereferenceableOrNullAttr(
10136+
llvm::AttributeList::ReturnIndex, derefBytes);
10137+
cal->addAttribute(AttributeList::ReturnIndex,
10138+
Attribute::NoAlias);
10139+
cal->addAttribute(AttributeList::ReturnIndex,
10140+
Attribute::NonNull);
1014010141
#endif
10142+
}
1014110143
}
10142-
}
10144+
return anti;
10145+
};
10146+
10147+
anti = applyChainRule(orig->getType(), bb, rule);
10148+
1014310149
gutils->invertedPointers.erase(found);
1014410150
if (&*bb.GetInsertPoint() == placeholder)
1014510151
bb.SetInsertPoint(placeholder->getNextNode());
@@ -10163,6 +10169,7 @@ class AdjointGenerator
1016310169
#else
1016410170
replacement->setAlignment(Alignment);
1016510171
#endif
10172+
1016610173
gutils->replaceAWithB(cast<Instruction>(anti), replacement);
1016710174
gutils->erase(cast<Instruction>(anti));
1016810175
anti = replacement;
@@ -10176,18 +10183,33 @@ class AdjointGenerator
1017610183
backwardsShadow) ||
1017710184
(Mode == DerivativeMode::ForwardModeSplit &&
1017810185
backwardsShadow)) {
10179-
if (!inLoop)
10180-
zeroKnownAllocation(bb, anti, args, *called, gutils->TLI);
10186+
if (!inLoop) {
10187+
applyChainRule(
10188+
bb,
10189+
[&](Value *anti) {
10190+
zeroKnownAllocation(bb, anti, args, *called,
10191+
gutils->TLI);
10192+
},
10193+
anti);
10194+
}
1018110195
}
1018210196
}
1018310197
gutils->invertedPointers.insert(
1018410198
std::make_pair(orig, InvertedPointerVH(gutils, anti)));
1018510199
}
1018610200
endAnti:;
10201+
10202+
bool isAlloca = anti ? isa<AllocaInst>(anti) : false;
10203+
if (gutils->getWidth() != 1) {
10204+
if (auto insertion = dyn_cast_or_null<InsertElementInst>(anti)) {
10205+
isAlloca = isa<AllocaInst>(insertion->getOperand(1));
10206+
}
10207+
}
10208+
1018710209
if (((Mode == DerivativeMode::ReverseModeCombined && shouldFree()) ||
1018810210
(Mode == DerivativeMode::ReverseModeGradient && shouldFree()) ||
1018910211
(Mode == DerivativeMode::ForwardModeSplit && shouldFree())) &&
10190-
!isa<AllocaInst>(anti)) {
10212+
!isAlloca) {
1019110213
IRBuilder<> Builder2(call.getParent());
1019210214
getReverseBuilder(Builder2);
1019310215
assert(anti);
@@ -10198,16 +10220,19 @@ class AdjointGenerator
1019810220
assert(
1019910221
PointerType::getUnqual(Type::getInt8Ty(tofree->getContext())));
1020010222
assert(Type::getInt8PtrTy(tofree->getContext()));
10201-
auto CI = freeKnownAllocation(Builder2, tofree, *called, dbgLoc,
10202-
gutils->TLI);
10203-
if (CI)
10223+
auto rule = [&](Value *tofree) {
10224+
auto CI = freeKnownAllocation(Builder2, tofree, *called, dbgLoc,
10225+
gutils->TLI);
10226+
if (CI)
1020410227
#if LLVM_VERSION_MAJOR >= 14
10205-
CI->addAttributeAtIndex(AttributeList::FirstArgIndex,
10206-
Attribute::NonNull);
10228+
CI->addAttributeAtIndex(AttributeList::FirstArgIndex,
10229+
Attribute::NonNull);
1020710230
#else
10208-
CI->addAttribute(AttributeList::FirstArgIndex,
10209-
Attribute::NonNull);
10231+
CI->addAttribute(AttributeList::FirstArgIndex,
10232+
Attribute::NonNull);
1021010233
#endif
10234+
};
10235+
applyChainRule(Builder2, rule, tofree);
1021110236
}
1021210237
} else if (Mode == DerivativeMode::ForwardMode) {
1021310238
IRBuilder<> Builder2(&call);
@@ -10284,18 +10309,26 @@ class AdjointGenerator
1028410309
if (auto CI = dyn_cast<ConstantInt>(orig->getArgOperand(0))) {
1028510310
B.SetInsertPoint(gutils->inversionAllocs);
1028610311
}
10287-
auto replacement = B.CreateAlloca(
10288-
Type::getInt8Ty(orig->getContext()),
10289-
gutils->getNewFromOriginal(orig->getArgOperand(0)));
10290-
auto Alignment =
10291-
cast<ConstantInt>(
10292-
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10293-
->getLimitedValue();
10312+
10313+
auto rule = [&]() {
10314+
auto replacement = B.CreateAlloca(
10315+
Type::getInt8Ty(orig->getContext()),
10316+
gutils->getNewFromOriginal(orig->getArgOperand(0)));
10317+
auto Alignment =
10318+
cast<ConstantInt>(
10319+
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10320+
->getLimitedValue();
1029410321
#if LLVM_VERSION_MAJOR >= 10
10295-
replacement->setAlignment(Align(Alignment));
10322+
replacement->setAlignment(Align(Alignment));
1029610323
#else
10297-
replacement->setAlignment(Alignment);
10324+
replacement->setAlignment(Alignment);
1029810325
#endif
10326+
return replacement;
10327+
};
10328+
10329+
Value *replacement =
10330+
applyChainRule(Type::getInt8Ty(orig->getContext()), B, rule);
10331+
1029910332
gutils->replaceAWithB(newCall, replacement);
1030010333
gutils->erase(newCall);
1030110334
return;
@@ -10337,21 +10370,29 @@ class AdjointGenerator
1033710370
if (auto CI = dyn_cast<ConstantInt>(orig->getArgOperand(0))) {
1033810371
B.SetInsertPoint(gutils->inversionAllocs);
1033910372
}
10340-
auto replacement = B.CreateAlloca(
10341-
Type::getInt8Ty(orig->getContext()),
10342-
gutils->getNewFromOriginal(orig->getArgOperand(0)));
10343-
auto Alignment =
10344-
cast<ConstantInt>(
10345-
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10346-
->getLimitedValue();
10347-
// Don't set zero alignment
10348-
if (Alignment) {
10373+
10374+
auto rule = [&]() {
10375+
auto replacement = B.CreateAlloca(
10376+
Type::getInt8Ty(orig->getContext()),
10377+
gutils->getNewFromOriginal(orig->getArgOperand(0)));
10378+
auto Alignment =
10379+
cast<ConstantInt>(
10380+
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
10381+
->getLimitedValue();
10382+
// Don't set zero alignment
10383+
if (Alignment) {
1034910384
#if LLVM_VERSION_MAJOR >= 10
10350-
replacement->setAlignment(Align(Alignment));
10385+
replacement->setAlignment(Align(Alignment));
1035110386
#else
10352-
replacement->setAlignment(Alignment);
10387+
replacement->setAlignment(Alignment);
1035310388
#endif
10354-
}
10389+
}
10390+
return replacement;
10391+
};
10392+
10393+
Value *replacement =
10394+
applyChainRule(Type::getInt8Ty(orig->getContext()), B, rule);
10395+
1035510396
gutils->replaceAWithB(newCall, replacement);
1035610397
gutils->erase(newCall);
1035710398
}

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 52 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -2687,55 +2687,74 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
26872687
anti =
26882688
shadowHandlers[called->getName().str()](NB, orig, args);
26892689
} else {
2690+
auto rule = [&]() {
26902691
#if LLVM_VERSION_MAJOR >= 11
2691-
anti = NB.CreateCall(orig->getFunctionType(),
2692-
orig->getCalledOperand(), args,
2693-
orig->getName() + "'mi");
2692+
Value *anti = NB.CreateCall(
2693+
orig->getFunctionType(), orig->getCalledOperand(),
2694+
args, orig->getName() + "'mi");
26942695
#else
2695-
anti = NB.CreateCall(orig->getCalledValue(), args,
2696-
orig->getName() + "'mi");
2696+
Value *anti = NB.CreateCall(orig->getCalledValue(), args,
2697+
orig->getName() + "'mi");
26972698
#endif
2698-
cast<CallInst>(anti)->setAttributes(orig->getAttributes());
2699-
cast<CallInst>(anti)->setCallingConv(
2700-
orig->getCallingConv());
2701-
cast<CallInst>(anti)->setTailCallKind(
2702-
orig->getTailCallKind());
2703-
cast<CallInst>(anti)->setDebugLoc(
2704-
getNewFromOriginal(I.getDebugLoc()));
2699+
cast<CallInst>(anti)->setAttributes(
2700+
orig->getAttributes());
2701+
cast<CallInst>(anti)->setCallingConv(
2702+
orig->getCallingConv());
2703+
cast<CallInst>(anti)->setTailCallKind(
2704+
orig->getTailCallKind());
2705+
cast<CallInst>(anti)->setDebugLoc(
2706+
getNewFromOriginal(I.getDebugLoc()));
27052707

27062708
#if LLVM_VERSION_MAJOR >= 14
2707-
cast<CallInst>(anti)->addAttributeAtIndex(
2708-
AttributeList::ReturnIndex, Attribute::NoAlias);
2709-
cast<CallInst>(anti)->addAttributeAtIndex(
2710-
AttributeList::ReturnIndex, Attribute::NonNull);
2709+
cast<CallInst>(anti)->addAttributeAtIndex(
2710+
AttributeList::ReturnIndex, Attribute::NoAlias);
2711+
cast<CallInst>(anti)->addAttributeAtIndex(
2712+
AttributeList::ReturnIndex, Attribute::NonNull);
27112713
#else
2712-
cast<CallInst>(anti)->addAttribute(
2713-
AttributeList::ReturnIndex, Attribute::NoAlias);
2714-
cast<CallInst>(anti)->addAttribute(
2715-
AttributeList::ReturnIndex, Attribute::NonNull);
2714+
cast<CallInst>(anti)->addAttribute(
2715+
AttributeList::ReturnIndex, Attribute::NoAlias);
2716+
cast<CallInst>(anti)->addAttribute(
2717+
AttributeList::ReturnIndex, Attribute::NonNull);
27162718
#endif
2719+
return anti;
2720+
};
2721+
2722+
anti = applyChainRule(orig->getType(), NB, rule);
2723+
27172724
if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
2718-
AllocaInst *replacement = NB.CreateAlloca(
2719-
Type::getInt8Ty(orig->getContext()), args[0]);
2720-
replacement->takeName(anti);
2721-
auto Alignment =
2722-
cast<ConstantInt>(
2723-
cast<ConstantAsMetadata>(MD->getOperand(0))
2724-
->getValue())
2725-
->getLimitedValue();
2725+
auto rule = [&](Value *anti) {
2726+
AllocaInst *replacement = NB.CreateAlloca(
2727+
Type::getInt8Ty(orig->getContext()), args[0]);
2728+
replacement->takeName(anti);
2729+
auto Alignment =
2730+
cast<ConstantInt>(
2731+
cast<ConstantAsMetadata>(MD->getOperand(0))
2732+
->getValue())
2733+
->getLimitedValue();
27262734
#if LLVM_VERSION_MAJOR >= 10
2727-
replacement->setAlignment(Align(Alignment));
2735+
replacement->setAlignment(Align(Alignment));
27282736
#else
2729-
replacement->setAlignment(Alignment);
2737+
replacement->setAlignment(Alignment);
27302738
#endif
2731-
replacement->setDebugLoc(
2732-
getNewFromOriginal(I.getDebugLoc()));
2739+
replacement->setDebugLoc(
2740+
getNewFromOriginal(I.getDebugLoc()));
2741+
return replacement;
2742+
};
2743+
2744+
Value *replacement = applyChainRule(
2745+
Type::getInt8Ty(orig->getContext()), NB, rule, anti);
2746+
27332747
replaceAWithB(cast<Instruction>(anti), replacement);
27342748
erase(cast<Instruction>(anti));
27352749
anti = replacement;
27362750
}
27372751

2738-
zeroKnownAllocation(NB, anti, args, *called, TLI);
2752+
applyChainRule(
2753+
NB,
2754+
[&](Value *anti) {
2755+
zeroKnownAllocation(NB, anti, args, *called, TLI);
2756+
},
2757+
anti);
27392758
}
27402759
} else {
27412760
llvm_unreachable("Unknown shadow rematerialization value");

0 commit comments

Comments
 (0)