Skip to content

Commit eac9b7c

Browse files
authored
Loop-level rematerialization (rust-lang#516)
* Add primitive loop rematerialization
* Rematerialize loops
* Add and fix test
1 parent 8c606a5 commit eac9b7c

File tree

7 files changed

+600
-149
lines changed

7 files changed

+600
-149
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -767,7 +767,7 @@ class AdjointGenerator
767767
for (auto pair : gutils->rematerializableAllocations) {
768768
if (is_value_needed_in_reverse<ValueType::Primal>(
769769
TR, gutils, pair.first, Mode, Seen, oldUnreachable)) {
770-
if (pair.second.second.count(&SI)) {
770+
if (pair.second.stores.count(&SI)) {
771771
return;
772772
}
773773
}
@@ -2539,7 +2539,7 @@ class AdjointGenerator
25392539
for (auto pair : gutils->rematerializableAllocations) {
25402540
if (is_value_needed_in_reverse<ValueType::Primal>(
25412541
TR, gutils, pair.first, Mode, Seen, oldUnreachable)) {
2542-
if (pair.second.second.count(&MS)) {
2542+
if (pair.second.stores.count(&MS)) {
25432543
rematerialized = true;
25442544
break;
25452545
}
@@ -5398,8 +5398,12 @@ class AdjointGenerator
53985398
Builder2, /*lookup*/ true));
53995399
cal->setCallingConv(dwait->getCallingConv());
54005400
cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
5401+
#if LLVM_VERSION_MAJOR >= 14
5402+
cal->addFnAttr(Attribute::AlwaysInline);
5403+
#else
54015404
cal->addAttribute(AttributeList::FunctionIndex,
54025405
Attribute::AlwaysInline);
5406+
#endif
54035407
Builder2.CreateBr(endBlock);
54045408

54055409
Builder2.SetInsertPoint(endBlock);
@@ -5536,8 +5540,12 @@ class AdjointGenerator
55365540
Builder2, /*lookup*/ true));
55375541
cal->setCallingConv(dwait->getCallingConv());
55385542
cal->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
5543+
#if LLVM_VERSION_MAJOR >= 14
5544+
cal->addFnAttr(Attribute::AlwaysInline);
5545+
#else
55395546
cal->addAttribute(AttributeList::FunctionIndex,
55405547
Attribute::AlwaysInline);
5548+
#endif
55415549
Builder2.CreateBr(eloopBlock);
55425550

55435551
Builder2.SetInsertPoint(eloopBlock);
@@ -8060,7 +8068,7 @@ class AdjointGenerator
80608068
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
80618069
bool rematerializedPrimal = false;
80628070
for (auto pair : gutils->rematerializableAllocations) {
8063-
if (pair.second.second.count(orig) &&
8071+
if (pair.second.stores.count(orig) &&
80648072
is_value_needed_in_reverse<ValueType::Primal>(
80658073
TR, gutils, pair.first, Mode, Seen, oldUnreachable)) {
80668074
rematerializedPrimal = true;
@@ -8818,7 +8826,7 @@ class AdjointGenerator
88188826

88198827
// If a rematerializable allocation.
88208828
for (auto rmat : gutils->rematerializableAllocations) {
8821-
if (rmat.second.second.count(orig)) {
8829+
if (rmat.second.stores.count(orig)) {
88228830
// Leave the original free behavior since this won't be used
88238831
// in the reverse pass in split mode
88248832
if (Mode == DerivativeMode::ReverseModePrimal) {

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -364,8 +364,8 @@ static inline bool is_value_needed_in_reverse(
364364
// we'll set it to unused, then check the gep, then here we'll
365365
// directly say unused by induction instead of checking the final
366366
// loads.
367-
if (pair.second.second.count(SI))
368-
for (LoadInst *L : pair.second.first)
367+
if (pair.second.stores.count(SI))
368+
for (LoadInst *L : pair.second.loads)
369369
if (is_value_needed_in_reverse<VT>(TR, gutils, L, mode, seen,
370370
oldUnreachable)) {
371371
return seen[idx] = true;
@@ -535,21 +535,20 @@ static inline int cmpLoopNest(Loop *prev, Loop *next) {
535535
return -1;
536536
}
537537

538-
static inline void
539-
minCut(const DataLayout &DL, LoopInfo &OrigLI,
540-
const SmallPtrSetImpl<Value *> &Recomputes,
541-
const SmallPtrSetImpl<Value *> &Intermediates,
542-
SmallPtrSetImpl<Value *> &Required, SmallPtrSetImpl<Value *> &MinReq,
543-
const ValueMap<Value *, std::pair<SmallPtrSet<LoadInst *, 1>,
544-
SmallPtrSet<Instruction *, 1>>>
545-
&rematerializableAllocations) {
538+
static inline void minCut(const DataLayout &DL, LoopInfo &OrigLI,
539+
const SmallPtrSetImpl<Value *> &Recomputes,
540+
const SmallPtrSetImpl<Value *> &Intermediates,
541+
SmallPtrSetImpl<Value *> &Required,
542+
SmallPtrSetImpl<Value *> &MinReq,
543+
const ValueMap<Value *, GradientUtils::Rematerializer>
544+
&rematerializableAllocations) {
546545
Graph G;
547546
for (auto V : Intermediates) {
548547
G[Node(V, false)].insert(Node(V, true));
549548
for (auto U : V->users()) {
550549
if (auto I = dyn_cast<Instruction>(U)) {
551550
for (auto pair : rematerializableAllocations) {
552-
if (Intermediates.count(pair.first) && pair.second.second.count(I))
551+
if (Intermediates.count(pair.first) && pair.second.stores.count(I))
553552
G[Node(V, true)].insert(Node(pair.first, false));
554553
}
555554
}
@@ -560,7 +559,7 @@ minCut(const DataLayout &DL, LoopInfo &OrigLI,
560559
}
561560
for (auto pair : rematerializableAllocations) {
562561
if (Intermediates.count(pair.first)) {
563-
for (LoadInst *L : pair.second.first) {
562+
for (LoadInst *L : pair.second.loads) {
564563
if (Intermediates.count(L)) {
565564
G[Node(pair.first, true)].insert(Node(L, false));
566565
}

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -105,8 +105,7 @@ struct CacheAnalysis {
105105

106106
const ValueMap<const CallInst *, SmallPtrSet<const CallInst *, 1>>
107107
&allocationsWithGuaranteedFree;
108-
const ValueMap<Value *, std::pair<SmallPtrSet<LoadInst *, 1>,
109-
SmallPtrSet<Instruction *, 1>>>
108+
const ValueMap<Value *, GradientUtils::Rematerializer>
110109
&rematerializableAllocations;
111110
TypeResults &TR;
112111
AAResults &AA;
@@ -123,8 +122,7 @@ struct CacheAnalysis {
123122
CacheAnalysis(
124123
const ValueMap<const CallInst *, SmallPtrSet<const CallInst *, 1>>
125124
&allocationsWithGuaranteedFree,
126-
const ValueMap<Value *, std::pair<SmallPtrSet<LoadInst *, 1>,
127-
SmallPtrSet<Instruction *, 1>>>
125+
const ValueMap<Value *, GradientUtils::Rematerializer>
128126
&rematerializableAllocations,
129127
TypeResults &TR, AAResults &AA, Function *oldFunc, ScalarEvolution &SE,
130128
LoopInfo &OrigLI, DominatorTree &OrigDT, TargetLibraryInfo &TLI,
@@ -269,6 +267,13 @@ struct CacheAnalysis {
269267
}
270268
}
271269

270+
// Any load from a rematerializable allocation is definitionally
271+
// reloadable. Notably we don't need to perform the allFollowersOf
272+
// check, as the loop-scope caching should allow us to ignore
273+
// such stores.
274+
if (rematerializableAllocations.count(obj))
275+
return false;
276+
272277
// If not running combined, check if pointer operand is overwritten
273278
// by a subsequent call (i.e. not this function).
274279
bool can_modref = false;

0 commit comments

Comments
 (0)