From 9db059cfa080600d367243aec1c8153195a0214b Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Fri, 5 Jul 2024 01:17:32 -0700 Subject: [PATCH 1/2] Correctly propagate optional adjoint through `switch_enum` Fixes #74978 --- .../Differentiation/PullbackCloner.cpp | 18 ++++++++++-------- test/AutoDiff/validation-test/optional.swift | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index e39bcb127b4a3..d276fdde8ed67 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -765,14 +765,15 @@ class PullbackCloner::Implementation final SILValue wrappedAdjoint, SILType optionalTy); - /// Accumulate optional buffer from `wrappedAdjoint`. + /// Accumulate adjoint of `wrappedAdjoint` into optionalBuffer. void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint); - /// Set optional value from `wrappedAdjoint`. - void setAdjointValueForOptional(SILBasicBlock *bb, SILValue optionalValue, - SILValue wrappedAdjoint); + /// Accumulate adjoint of `wrappedAdjoint` into optionalValue. + void accumulateAdjointValueForOptional(SILBasicBlock *bb, + SILValue optionalValue, + SILValue wrappedAdjoint); //--------------------------------------------------------------------------// // Array literal initialization differentiation @@ -2733,8 +2734,8 @@ void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer( builder.createDeallocStack(pbLoc, optTanAdjBuf); } -// Set the adjoint value for the incoming `Optional` value. -void PullbackCloner::Implementation::setAdjointValueForOptional( +// Accumulate adjoint for the incoming `Optional` value. +void PullbackCloner::Implementation::accumulateAdjointValueForOptional( SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) { assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object); auto pbLoc = getPullback().getLocation(); @@ -2746,10 +2747,11 @@ void PullbackCloner::Implementation::setAdjointValueForOptional( auto optTanAdjVal = builder.emitLoadValueOperation( pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take); + recordTemporary(optTanAdjVal); builder.createDeallocStack(pbLoc, optTanAdjBuf); - setAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal)); + addAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal), pbLoc); } SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor( @@ -2960,7 +2962,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { // Handle `switch_enum` on `Optional`. auto termInst = bbArg->getSingleTerminator(); if (isSwitchEnumInstOnOptional(termInst)) { - setAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy); + accumulateAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy); } else { blockTemporaries[getPullbackBlock(predBB)].insert( concreteBBArgAdjCopy); diff --git a/test/AutoDiff/validation-test/optional.swift b/test/AutoDiff/validation-test/optional.swift index e849140087737..4a2ce8c92b9c5 100644 --- a/test/AutoDiff/validation-test/optional.swift +++ b/test/AutoDiff/validation-test/optional.swift @@ -22,12 +22,30 @@ func optional_nil_coalescing(_ maybeX: Float?) -> Float { */ OptionalTests.test("Active") { + @differentiable(reverse) + func id(y: Float) -> Float? { + return y + } + + @differentiable(reverse) + func id2(y: Float?) -> Float { + return y! + } + @differentiable(reverse) func square(y: Float) -> Float? { return y * y } + @differentiable(reverse) + func square2(y: Float?) -> Float { + return y! * y! + } + + expectEqual(gradient(at: 10, of: {y in id(y:y)!}), .init(1.0)) + expectEqual(gradient(at: 10, of: {y in id2(y:y)}), .init(1.0)) expectEqual(gradient(at: 10, of: {y in square(y:y)!}), .init(20.0)) + expectEqual(gradient(at: 10, of: {y in square2(y:y)}), .init(20.0)) } OptionalTests.test("Let") { From 0c5ae930768c0f07d3d9e4bfdae3b195b8a02c3c Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Thu, 11 Jul 2024 15:15:11 -0700 Subject: [PATCH 2/2] Accumulate adjoints for normal values as well --- lib/SILOptimizer/Differentiation/PullbackCloner.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index d276fdde8ed67..f4e8b0da0c7a4 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -2966,8 +2966,8 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { } else { blockTemporaries[getPullbackBlock(predBB)].insert( concreteBBArgAdjCopy); - setAdjointValue(predBB, incomingValue, - makeConcreteAdjointValue(concreteBBArgAdjCopy)); + addAdjointValue(predBB, incomingValue, + makeConcreteAdjointValue(concreteBBArgAdjCopy), pbLoc); } } break;