@@ -765,14 +765,15 @@ class PullbackCloner::Implementation final
765
765
SILValue wrappedAdjoint,
766
766
SILType optionalTy);
767
767
768
- // / Accumulate optional buffer from `wrappedAdjoint`.
768
+ // / Accumulate adjoint of `wrappedAdjoint` into optionalBuffer .
769
769
void accumulateAdjointForOptionalBuffer (SILBasicBlock *bb,
770
770
SILValue optionalBuffer,
771
771
SILValue wrappedAdjoint);
772
772
773
- // / Set optional value from `wrappedAdjoint`.
774
- void setAdjointValueForOptional (SILBasicBlock *bb, SILValue optionalValue,
775
- SILValue wrappedAdjoint);
773
+ // / Accumulate adjoint of `wrappedAdjoint` into optionalValue.
774
+ void accumulateAdjointValueForOptional (SILBasicBlock *bb,
775
+ SILValue optionalValue,
776
+ SILValue wrappedAdjoint);
776
777
777
778
// --------------------------------------------------------------------------//
778
779
// Array literal initialization differentiation
@@ -2732,8 +2733,8 @@ void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
2732
2733
builder.createDeallocStack (pbLoc, optTanAdjBuf);
2733
2734
}
2734
2735
2735
- // Set the adjoint value for the incoming `Optional` value.
2736
- void PullbackCloner::Implementation::setAdjointValueForOptional (
2736
+ // Accumulate adjoint for the incoming `Optional` value.
2737
+ void PullbackCloner::Implementation::accumulateAdjointValueForOptional (
2737
2738
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2738
2739
assert (getTangentValueCategory (optionalValue) == SILValueCategory::Object);
2739
2740
auto pbLoc = getPullback ().getLocation ();
@@ -2745,10 +2746,11 @@ void PullbackCloner::Implementation::setAdjointValueForOptional(
2745
2746
2746
2747
auto optTanAdjVal = builder.emitLoadValueOperation (
2747
2748
pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2749
+
2748
2750
recordTemporary (optTanAdjVal);
2749
2751
builder.createDeallocStack (pbLoc, optTanAdjBuf);
2750
2752
2751
- setAdjointValue (bb, optionalValue, makeConcreteAdjointValue (optTanAdjVal));
2753
+ addAdjointValue (bb, optionalValue, makeConcreteAdjointValue (optTanAdjVal), pbLoc );
2752
2754
}
2753
2755
2754
2756
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor (
@@ -2959,12 +2961,12 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2959
2961
// Handle `switch_enum` on `Optional`.
2960
2962
auto termInst = bbArg->getSingleTerminator ();
2961
2963
if (isSwitchEnumInstOnOptional (termInst)) {
2962
- setAdjointValueForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2964
+ accumulateAdjointValueForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2963
2965
} else {
2964
2966
blockTemporaries[getPullbackBlock (predBB)].insert (
2965
2967
concreteBBArgAdjCopy);
2966
- setAdjointValue (predBB, incomingValue,
2967
- makeConcreteAdjointValue (concreteBBArgAdjCopy));
2968
+ addAdjointValue (predBB, incomingValue,
2969
+ makeConcreteAdjointValue (concreteBBArgAdjCopy), pbLoc );
2968
2970
}
2969
2971
}
2970
2972
break ;
0 commit comments