File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -3542,7 +3542,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
3542
3542
std::pair<Argument *, std::set<int64_t >>(&a, {}));
3543
3543
DIFFE_TYPE typ;
3544
3544
if (a.getType ()->isFPOrFPVectorTy ()) {
3545
- typ = DIFFE_TYPE::OUT_DIFF;
3545
+ typ = mode == DerivativeMode::ForwardMode ? DIFFE_TYPE::DUP_ARG
3546
+ : DIFFE_TYPE::OUT_DIFF;
3546
3547
} else if (a.getType ()->isIntegerTy () &&
3547
3548
cast<IntegerType>(a.getType ())->getBitWidth () < 16 ) {
3548
3549
typ = DIFFE_TYPE::CONSTANT;
@@ -3554,7 +3555,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
3554
3555
types.push_back (typ);
3555
3556
}
3556
3557
3557
- DIFFE_TYPE retType = fn->getReturnType ()->isFPOrFPVectorTy ()
3558
+ DIFFE_TYPE retType = fn->getReturnType ()->isFPOrFPVectorTy () &&
3559
+ mode != DerivativeMode::ForwardMode
3558
3560
? DIFFE_TYPE::OUT_DIFF
3559
3561
: DIFFE_TYPE::DUP_ARG;
3560
3562
if (fn->getReturnType ()->isVoidTy () || fn->getReturnType ()->isEmptyTy () ||
You can’t perform that action at this time.
0 commit comments