Skip to content

Commit e4def7c

Browse files
committed
Correctly propagate optional adjoint through switch_enum
Fixes #74978
1 parent 9580b21 commit e4def7c

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -765,14 +765,15 @@ class PullbackCloner::Implementation final
765765
SILValue wrappedAdjoint,
766766
SILType optionalTy);
767767

768-
/// Accumulate optional buffer from `wrappedAdjoint`.
768+
/// Accumulate adjoint of `wrappedAdjoint` into optionalBuffer.
769769
void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb,
770770
SILValue optionalBuffer,
771771
SILValue wrappedAdjoint);
772772

773-
/// Set optional value from `wrappedAdjoint`.
774-
void setAdjointValueForOptional(SILBasicBlock *bb, SILValue optionalValue,
775-
SILValue wrappedAdjoint);
773+
/// Accumulate adjoint of `wrappedAdjoint` into optionalValue.
774+
void accumulateAdjointValueForOptional(SILBasicBlock *bb,
775+
SILValue optionalValue,
776+
SILValue wrappedAdjoint);
776777

777778
//--------------------------------------------------------------------------//
778779
// Array literal initialization differentiation
@@ -2734,7 +2735,7 @@ void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
27342735
}
27352736

27362737
// Set the adjoint value for the incoming `Optional` value.
2737-
void PullbackCloner::Implementation::setAdjointValueForOptional(
2738+
void PullbackCloner::Implementation::accumulateAdjointValueForOptional(
27382739
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
27392740
assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object);
27402741
auto pbLoc = getPullback().getLocation();
@@ -2746,10 +2747,11 @@ void PullbackCloner::Implementation::setAdjointValueForOptional(
27462747

27472748
auto optTanAdjVal = builder.emitLoadValueOperation(
27482749
pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2750+
27492751
recordTemporary(optTanAdjVal);
27502752
builder.createDeallocStack(pbLoc, optTanAdjBuf);
27512753

2752-
setAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal));
2754+
addAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal), pbLoc);
27532755
}
27542756

27552757
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
@@ -2960,7 +2962,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
29602962
// Handle `switch_enum` on `Optional`.
29612963
auto termInst = bbArg->getSingleTerminator();
29622964
if (isSwitchEnumInstOnOptional(termInst)) {
2963-
setAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy);
2965+
accumulateAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy);
29642966
} else {
29652967
blockTemporaries[getPullbackBlock(predBB)].insert(
29662968
concreteBBArgAdjCopy);

test/AutoDiff/validation-test/optional.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,30 @@ func optional_nil_coalescing(_ maybeX: Float?) -> Float {
2222
*/
2323

2424
OptionalTests.test("Active") {
25+
@differentiable(reverse)
26+
func id(y: Float) -> Float? {
27+
return y
28+
}
29+
30+
@differentiable(reverse)
31+
func id2(y: Float?) -> Float {
32+
return y!
33+
}
34+
2535
@differentiable(reverse)
2636
func square(y: Float) -> Float? {
2737
return y * y
2838
}
2939

40+
@differentiable(reverse)
41+
func square2(y: Float?) -> Float {
42+
return y! * y!
43+
}
44+
45+
expectEqual(gradient(at: 10, of: {y in id(y:y)!}), .init(1.0))
46+
expectEqual(gradient(at: 10, of: {y in id2(y:y)}), .init(1.0))
3047
expectEqual(gradient(at: 10, of: {y in square(y:y)!}), .init(20.0))
48+
expectEqual(gradient(at: 10, of: {y in square2(y:y)}), .init(20.0))
3149
}
3250

3351
OptionalTests.test("Let") {

0 commit comments

Comments
 (0)