Skip to content

Commit aee3008

Browse files
authored
Fix wrong DIFFE_TYPE for forward mode in GetOrCreateShadowFunction (rust-lang#553)
1 parent f08a1bf commit aee3008

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3542,7 +3542,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
35423542
std::pair<Argument *, std::set<int64_t>>(&a, {}));
35433543
DIFFE_TYPE typ;
35443544
if (a.getType()->isFPOrFPVectorTy()) {
3545-
typ = DIFFE_TYPE::OUT_DIFF;
3545+
typ = mode == DerivativeMode::ForwardMode ? DIFFE_TYPE::DUP_ARG
3546+
: DIFFE_TYPE::OUT_DIFF;
35463547
} else if (a.getType()->isIntegerTy() &&
35473548
cast<IntegerType>(a.getType())->getBitWidth() < 16) {
35483549
typ = DIFFE_TYPE::CONSTANT;
@@ -3554,7 +3555,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
35543555
types.push_back(typ);
35553556
}
35563557

3557-
DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy()
3558+
DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy() &&
3559+
mode != DerivativeMode::ForwardMode
35583560
? DIFFE_TYPE::OUT_DIFF
35593561
: DIFFE_TYPE::DUP_ARG;
35603562
if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() ||

0 commit comments

Comments
 (0)