Skip to content

Commit 6e45ead

Browse files
authored
preparations for new modes (rust-lang#350)
1 parent e97c0d1 commit 6e45ead

File tree

7 files changed

+126
-128
lines changed

7 files changed

+126
-128
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -384,9 +384,15 @@ class AdjointGenerator
384384
Value *mask = nullptr, Value *orig_maskInit = nullptr) {
385385
auto &DL = gutils->newFunc->getParent()->getDataLayout();
386386

387-
assert(gutils->can_modref_map);
388-
assert(gutils->can_modref_map->find(&I) != gutils->can_modref_map->end());
389-
bool can_modref = gutils->can_modref_map->find(&I)->second;
387+
assert(Mode == DerivativeMode::ForwardMode ||
388+
Mode == DerivativeMode::ForwardModeVector || gutils->can_modref_map);
389+
assert(Mode == DerivativeMode::ForwardMode ||
390+
Mode == DerivativeMode::ForwardModeVector ||
391+
gutils->can_modref_map->find(&I) != gutils->can_modref_map->end());
392+
bool can_modref = Mode == DerivativeMode::ForwardMode ||
393+
Mode == DerivativeMode::ForwardModeVector
394+
? false
395+
: gutils->can_modref_map->find(&I)->second;
390396

391397
constantval |= gutils->isConstantValue(&I);
392398

@@ -5726,14 +5732,18 @@ class AdjointGenerator
57265732
IRBuilder<> BuilderZ(newCall);
57275733
BuilderZ.setFastMathFlags(getFast());
57285734

5729-
if (uncacheable_args_map.find(&call) == uncacheable_args_map.end()) {
5735+
if (uncacheable_args_map.find(&call) == uncacheable_args_map.end() &&
5736+
Mode != DerivativeMode::ForwardMode &&
5737+
Mode != DerivativeMode::ForwardModeVector) {
57305738
llvm::errs() << " call: " << call << "\n";
57315739
for (auto &pair : uncacheable_args_map) {
57325740
llvm::errs() << " + " << *pair.first << "\n";
57335741
}
57345742
}
57355743

5736-
assert(uncacheable_args_map.find(&call) != uncacheable_args_map.end());
5744+
assert(uncacheable_args_map.find(&call) != uncacheable_args_map.end() ||
5745+
Mode == DerivativeMode::ForwardMode ||
5746+
Mode == DerivativeMode::ForwardModeVector);
57375747
const std::map<Argument *, bool> &uncacheable_args =
57385748
uncacheable_args_map.find(&call)->second;
57395749

@@ -7613,7 +7623,9 @@ class AdjointGenerator
76137623
// If we need this value and it is illegal to recompute it (it writes or
76147624
// may load uncacheable data)
76157625
// Store and reload it
7616-
if (Mode != DerivativeMode::ReverseModeCombined && subretused &&
7626+
if (Mode != DerivativeMode::ReverseModeCombined &&
7627+
Mode != DerivativeMode::ForwardMode &&
7628+
Mode != DerivativeMode::ForwardModeVector && subretused &&
76177629
(orig->mayWriteToMemory() ||
76187630
!gutils->legalRecompute(orig, ValueToValueMapTy(), nullptr))) {
76197631
if (!gutils->unnecessaryIntermediates.count(orig)) {
@@ -7719,8 +7731,7 @@ class AdjointGenerator
77197731
cast<Function>(called), subretType, argsInverted, gutils->TLI,
77207732
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
77217733
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
7722-
nextTypeInfo, uncacheable_args,
7723-
/*AtomicAdd*/ gutils->AtomicAdd);
7734+
nextTypeInfo, {});
77247735

77257736
assert(newcalled);
77267737
FunctionType *FT = cast<FunctionType>(

enzyme/Enzyme/CApi.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,7 @@ LLVMValueRef EnzymeCreateForwardDiff(
336336
CDIFFE_TYPE *constant_args, size_t constant_args_size,
337337
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
338338
CDerivativeMode mode, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
339-
uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t AtomicAdd,
340-
uint8_t PostOpt) {
339+
uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t PostOpt) {
341340
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
342341
(DIFFE_TYPE *)constant_args +
343342
constant_args_size);
@@ -352,7 +351,7 @@ LLVMValueRef EnzymeCreateForwardDiff(
352351
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
353352
eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, (DerivativeMode)mode,
354353
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
355-
uncacheable_args, AtomicAdd, PostOpt));
354+
uncacheable_args, PostOpt));
356355
}
357356
LLVMValueRef EnzymeCreatePrimalAndGradient(
358357
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,

enzyme/Enzyme/CApi.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,14 @@ typedef enum {
116116
DEM_ReverseModeCombined = 3,
117117
} CDerivativeMode;
118118

119-
LLVMValueRef EnzymeCreateForwardDiff(
120-
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
121-
CDIFFE_TYPE *constant_args, size_t constant_args_size,
122-
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
123-
CDerivativeMode mode, LLVMTypeRef additionalArg,
124-
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
125-
size_t uncacheable_args_size, uint8_t AtomicAdd, uint8_t PostOpt);
119+
LLVMValueRef
120+
EnzymeCreateForwardDiff(EnzymeLogicRef, LLVMValueRef todiff,
121+
CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
122+
size_t constant_args_size, EnzymeTypeAnalysisRef TA,
123+
uint8_t returnValue, uint8_t dretUsed,
124+
CDerivativeMode mode, LLVMTypeRef additionalArg,
125+
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
126+
size_t uncacheable_args_size, uint8_t PostOpt);
126127

127128
LLVMValueRef EnzymeCreatePrimalAndGradient(
128129
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,

enzyme/Enzyme/Enzyme.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,7 @@ class Enzyme : public ModulePass {
899899
newFunc = Logic.CreateForwardDiff(
900900
cast<Function>(fn), retType, constants, TLI, TA,
901901
/*should return*/ false, /*dretPtr*/ false, mode,
902-
/*addedType*/ nullptr, type_args, volatile_args, AtomicAdd, PostOpt);
902+
/*addedType*/ nullptr, type_args, volatile_args, PostOpt);
903903
break;
904904
case DerivativeMode::ReverseModeCombined:
905905
newFunc = Logic.CreatePrimalAndGradient(

0 commit comments

Comments
 (0)