Skip to content

Commit 9d63e0f

Browse files
authored
[AutoDiff] Correctly propagate optional adjoint through switch_enum (#74985)
Fixes #74978
1 parent a8a1eb2 commit 9d63e0f

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

+12-10
Original file line numberDiff line numberDiff line change
@@ -765,14 +765,15 @@ class PullbackCloner::Implementation final
765765
SILValue wrappedAdjoint,
766766
SILType optionalTy);
767767

768-
/// Accumulate optional buffer from `wrappedAdjoint`.
768+
/// Accumulate adjoint of `wrappedAdjoint` into optionalBuffer.
769769
void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb,
770770
SILValue optionalBuffer,
771771
SILValue wrappedAdjoint);
772772

773-
/// Set optional value from `wrappedAdjoint`.
774-
void setAdjointValueForOptional(SILBasicBlock *bb, SILValue optionalValue,
775-
SILValue wrappedAdjoint);
773+
/// Accumulate adjoint of `wrappedAdjoint` into optionalValue.
774+
void accumulateAdjointValueForOptional(SILBasicBlock *bb,
775+
SILValue optionalValue,
776+
SILValue wrappedAdjoint);
776777

777778
//--------------------------------------------------------------------------//
778779
// Array literal initialization differentiation
@@ -2732,8 +2733,8 @@ void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
27322733
builder.createDeallocStack(pbLoc, optTanAdjBuf);
27332734
}
27342735

2735-
// Set the adjoint value for the incoming `Optional` value.
2736-
void PullbackCloner::Implementation::setAdjointValueForOptional(
2736+
// Accumulate adjoint for the incoming `Optional` value.
2737+
void PullbackCloner::Implementation::accumulateAdjointValueForOptional(
27372738
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
27382739
assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object);
27392740
auto pbLoc = getPullback().getLocation();
@@ -2745,10 +2746,11 @@ void PullbackCloner::Implementation::setAdjointValueForOptional(
27452746

27462747
auto optTanAdjVal = builder.emitLoadValueOperation(
27472748
pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2749+
27482750
recordTemporary(optTanAdjVal);
27492751
builder.createDeallocStack(pbLoc, optTanAdjBuf);
27502752

2751-
setAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal));
2753+
addAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal), pbLoc);
27522754
}
27532755

27542756
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
@@ -2959,12 +2961,12 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
29592961
// Handle `switch_enum` on `Optional`.
29602962
auto termInst = bbArg->getSingleTerminator();
29612963
if (isSwitchEnumInstOnOptional(termInst)) {
2962-
setAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy);
2964+
accumulateAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy);
29632965
} else {
29642966
blockTemporaries[getPullbackBlock(predBB)].insert(
29652967
concreteBBArgAdjCopy);
2966-
setAdjointValue(predBB, incomingValue,
2967-
makeConcreteAdjointValue(concreteBBArgAdjCopy));
2968+
addAdjointValue(predBB, incomingValue,
2969+
makeConcreteAdjointValue(concreteBBArgAdjCopy), pbLoc);
29682970
}
29692971
}
29702972
break;

test/AutoDiff/validation-test/optional.swift

+18
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,30 @@ func optional_nil_coalescing(_ maybeX: Float?) -> Float {
2222
*/
2323

2424
OptionalTests.test("Active") {
25+
@differentiable(reverse)
26+
func id(y: Float) -> Float? {
27+
return y
28+
}
29+
30+
@differentiable(reverse)
31+
func id2(y: Float?) -> Float {
32+
return y!
33+
}
34+
2535
@differentiable(reverse)
2636
func square(y: Float) -> Float? {
2737
return y * y
2838
}
2939

40+
@differentiable(reverse)
41+
func square2(y: Float?) -> Float {
42+
return y! * y!
43+
}
44+
45+
expectEqual(gradient(at: 10, of: {y in id(y:y)!}), .init(1.0))
46+
expectEqual(gradient(at: 10, of: {y in id2(y:y)}), .init(1.0))
3047
expectEqual(gradient(at: 10, of: {y in square(y:y)!}), .init(20.0))
48+
expectEqual(gradient(at: 10, of: {y in square2(y:y)}), .init(20.0))
3149
}
3250

3351
OptionalTests.test("Let") {

0 commit comments

Comments
 (0)