@@ -765,14 +765,15 @@ class PullbackCloner::Implementation final
765
765
SILValue wrappedAdjoint,
766
766
SILType optionalTy);
767
767
768
- // / Accumulate optional buffer from `wrappedAdjoint`.
768
+ // / Accumulate adjoint of `wrappedAdjoint` into optionalBuffer .
769
769
void accumulateAdjointForOptionalBuffer (SILBasicBlock *bb,
770
770
SILValue optionalBuffer,
771
771
SILValue wrappedAdjoint);
772
772
773
- // / Set optional value from `wrappedAdjoint`.
774
- void setAdjointValueForOptional (SILBasicBlock *bb, SILValue optionalValue,
775
- SILValue wrappedAdjoint);
773
+ // / Accumulate adjoint of `wrappedAdjoint` into optionalValue.
774
+ void accumulateAdjointValueForOptional (SILBasicBlock *bb,
775
+ SILValue optionalValue,
776
+ SILValue wrappedAdjoint);
776
777
777
778
// --------------------------------------------------------------------------//
778
779
// Array literal initialization differentiation
@@ -2734,7 +2735,7 @@ void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
2734
2735
}
2735
2736
2736
2737
// Set the adjoint value for the incoming `Optional` value.
2737
- void PullbackCloner::Implementation::setAdjointValueForOptional (
2738
+ void PullbackCloner::Implementation::accumulateAdjointValueForOptional (
2738
2739
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2739
2740
assert (getTangentValueCategory (optionalValue) == SILValueCategory::Object);
2740
2741
auto pbLoc = getPullback ().getLocation ();
@@ -2746,10 +2747,11 @@ void PullbackCloner::Implementation::setAdjointValueForOptional(
2746
2747
2747
2748
auto optTanAdjVal = builder.emitLoadValueOperation (
2748
2749
pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2750
+
2749
2751
recordTemporary (optTanAdjVal);
2750
2752
builder.createDeallocStack (pbLoc, optTanAdjBuf);
2751
2753
2752
- setAdjointValue (bb, optionalValue, makeConcreteAdjointValue (optTanAdjVal));
2754
+ addAdjointValue (bb, optionalValue, makeConcreteAdjointValue (optTanAdjVal), pbLoc );
2753
2755
}
2754
2756
2755
2757
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor (
@@ -2960,7 +2962,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2960
2962
// Handle `switch_enum` on `Optional`.
2961
2963
auto termInst = bbArg->getSingleTerminator ();
2962
2964
if (isSwitchEnumInstOnOptional (termInst)) {
2963
- setAdjointValueForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2965
+ accumulateAdjointValueForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2964
2966
} else {
2965
2967
blockTemporaries[getPullbackBlock (predBB)].insert (
2966
2968
concreteBBArgAdjCopy);
0 commit comments