@@ -384,9 +384,15 @@ class AdjointGenerator
384
384
Value *mask = nullptr, Value *orig_maskInit = nullptr) {
385
385
auto &DL = gutils->newFunc->getParent()->getDataLayout();
386
386
387
- assert (gutils->can_modref_map );
388
- assert (gutils->can_modref_map ->find (&I) != gutils->can_modref_map ->end ());
389
- bool can_modref = gutils->can_modref_map ->find (&I)->second ;
387
+ assert(Mode == DerivativeMode::ForwardMode ||
388
+ Mode == DerivativeMode::ForwardModeVector || gutils->can_modref_map);
389
+ assert(Mode == DerivativeMode::ForwardMode ||
390
+ Mode == DerivativeMode::ForwardModeVector ||
391
+ gutils->can_modref_map->find(&I) != gutils->can_modref_map->end());
392
+ bool can_modref = Mode == DerivativeMode::ForwardMode ||
393
+ Mode == DerivativeMode::ForwardModeVector
394
+ ? false
395
+ : gutils->can_modref_map->find(&I)->second;
390
396
391
397
constantval |= gutils->isConstantValue(&I);
392
398
@@ -5726,14 +5732,18 @@ class AdjointGenerator
5726
5732
IRBuilder<> BuilderZ(newCall);
5727
5733
BuilderZ.setFastMathFlags(getFast());
5728
5734
5729
- if (uncacheable_args_map.find (&call) == uncacheable_args_map.end ()) {
5735
+ if (uncacheable_args_map.find(&call) == uncacheable_args_map.end() &&
5736
+ Mode != DerivativeMode::ForwardMode &&
5737
+ Mode != DerivativeMode::ForwardModeVector) {
5730
5738
llvm::errs() << " call: " << call << "\n";
5731
5739
for (auto &pair : uncacheable_args_map) {
5732
5740
llvm::errs() << " + " << *pair.first << "\n";
5733
5741
}
5734
5742
}
5735
5743
5736
- assert (uncacheable_args_map.find (&call) != uncacheable_args_map.end ());
5744
+ assert(uncacheable_args_map.find(&call) != uncacheable_args_map.end() ||
5745
+ Mode == DerivativeMode::ForwardMode ||
5746
+ Mode == DerivativeMode::ForwardModeVector);
5737
5747
const std::map<Argument *, bool> &uncacheable_args =
5738
5748
uncacheable_args_map.find(&call)->second;
5739
5749
@@ -7613,7 +7623,9 @@ class AdjointGenerator
7613
7623
// If we need this value and it is illegal to recompute it (it writes or
7614
7624
// may load uncacheable data)
7615
7625
// Store and reload it
7616
- if (Mode != DerivativeMode::ReverseModeCombined && subretused &&
7626
+ if (Mode != DerivativeMode::ReverseModeCombined &&
7627
+ Mode != DerivativeMode::ForwardMode &&
7628
+ Mode != DerivativeMode::ForwardModeVector && subretused &&
7617
7629
(orig->mayWriteToMemory() ||
7618
7630
!gutils->legalRecompute(orig, ValueToValueMapTy(), nullptr))) {
7619
7631
if (!gutils->unnecessaryIntermediates.count(orig)) {
@@ -7719,8 +7731,7 @@ class AdjointGenerator
7719
7731
cast<Function>(called), subretType, argsInverted, gutils->TLI,
7720
7732
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
7721
7733
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
7722
- nextTypeInfo, uncacheable_args,
7723
- /* AtomicAdd*/ gutils->AtomicAdd );
7734
+ nextTypeInfo, {});
7724
7735
7725
7736
assert(newcalled);
7726
7737
FunctionType *FT = cast<FunctionType>(
0 commit comments