diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 6276cc3a7c089..de3285cdd3f06 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -47,6 +47,7 @@ #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/BreadthFirstIterator.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/CommandLine.h" @@ -84,6 +85,9 @@ class DifferentiationTransformer { /// Context necessary for performing the transformations. ADContext context; + /// Cache used in getUnwrappedCurryThunkFunction. + llvm::DenseMap afdToSILFn; + /// Promotes the given `differentiable_function` instruction to a valid /// `@differentiable` function-typed value. SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst, @@ -96,6 +100,25 @@ class DifferentiationTransformer { SILBuilder &builder, SILLocation loc, DifferentiationInvoker invoker); + /// Emits a reference to a derivative function of `original`, differentiated + /// with respect to a superset of `desiredIndices`. Returns the `SILValue` for + /// the derivative function and the actual indices that the derivative + /// function is with respect to. + /// + /// Returns `None` on failure, signifying that a diagnostic has been emitted + /// using `invoker`. + std::optional> + emitDerivativeFunctionReference( + SILBuilder &builder, const AutoDiffConfig &desiredConfig, + AutoDiffDerivativeFunctionKind kind, SILValue original, + DifferentiationInvoker invoker, + SmallVectorImpl &newBuffersToDealloc); + + /// If the given function corresponds to AutoClosureExpr with either + /// SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction + /// corresponding to the function being wrapped in the thunk. + SILFunction *getUnwrappedCurryThunkFunction(SILFunction *originalFn); + public: /// Construct an `DifferentiationTransformer` for the given module. explicit DifferentiationTransformer(SILModuleTransform &transform) @@ -453,21 +476,63 @@ static SILValue reapplyFunctionConversion( llvm_unreachable("Unhandled function conversion instruction"); } -/// Emits a reference to a derivative function of `original`, differentiated -/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for -/// the derivative function and the actual indices that the derivative function -/// is with respect to. -/// -/// Returns `None` on failure, signifying that a diagnostic has been emitted -/// using `invoker`. -static std::optional> -emitDerivativeFunctionReference( - DifferentiationTransformer &transformer, SILBuilder &builder, - const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind, - SILValue original, DifferentiationInvoker invoker, - SmallVectorImpl &newBuffersToDealloc) { - ADContext &context = transformer.getContext(); +SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction( + SILFunction *originalFn) { + auto *autoCE = dyn_cast_or_null( + originalFn->getDeclRef().getAbstractClosureExpr()); + if (autoCE == nullptr) + return nullptr; + + auto *ae = dyn_cast_or_null(autoCE->getUnwrappedCurryThunkExpr()); + if (ae == nullptr) + return nullptr; + AbstractFunctionDecl *afd = cast(ae->getCalledValue( + /*skipFunctionConversions=*/true)); + auto silFnIt = afdToSILFn.find(afd); + if (silFnIt == afdToSILFn.end()) { + assert(afdToSILFn.empty() && "Expect all 'afdToSILFn' cache entries to be " + "filled at once on the first access attempt"); + + SILModule *module = getTransform().getModule(); + for (SILFunction ¤tFunc : module->getFunctions()) { + if (auto *currentAFD = + currentFunc.getDeclRef().getAbstractFunctionDecl()) { + // Update cache only with AFDs which might be potentially wrapped by a + // curry thunk. This includes member function references and references + // to functions having external property wrapper parameters (see + // ExprRewriter::buildDeclRef). If new use cases of curry thunks appear + // in future, the assertion after the loop will be a trigger for such + // cases being unhandled here. + // + // FIXME: References to functions having external property wrapper + // parameters are not handled since we can't now construct a test case + // for that due to the crash + // https://github.com/swiftlang/swift/issues/77613 + if (currentAFD->hasCurriedSelf()) { + auto [_, wasEmplace] = + afdToSILFn.try_emplace(currentAFD, ¤tFunc); + assert(wasEmplace && "Expect all 'afdToSILFn' cache entries to be " + "filled at once on the first access attempt"); + } + } + } + + silFnIt = afdToSILFn.find(afd); + assert(silFnIt != afdToSILFn.end() && + "Expect present curry thunk to SIL function mapping after " + "'afdToSILFn' cache fill"); + } + + return silFnIt->second; +} + +std::optional> +DifferentiationTransformer::emitDerivativeFunctionReference( + SILBuilder &builder, const AutoDiffConfig &desiredConfig, + AutoDiffDerivativeFunctionKind kind, SILValue original, + DifferentiationInvoker invoker, + SmallVectorImpl &newBuffersToDealloc) { // If `original` is itself an `DifferentiableFunctionExtractInst` whose kind // matches the given kind and desired differentiation parameter indices, // simply extract the derivative function of its function operand, retain the @@ -610,26 +675,36 @@ emitDerivativeFunctionReference( DifferentiabilityKind::Reverse, desiredParameterIndices, desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false); - if (transformer.canonicalizeDifferentiabilityWitness( - minimalWitness, invoker, IsNotSerialized)) + if (canonicalizeDifferentiabilityWitness(minimalWitness, invoker, + IsNotSerialized)) return std::nullopt; } assert(minimalWitness); - if (original->getFunction()->isSerialized() && - !hasPublicVisibility(minimalWitness->getLinkage())) { - enum { Inlinable = 0, DefaultArgument = 1 }; - unsigned fragileKind = Inlinable; - // FIXME: This is not a very robust way of determining if the function is - // a default argument. Also, we have not exhaustively listed all the kinds - // of fragility. - if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI) - fragileKind = DefaultArgument; - context.emitNondifferentiabilityError( - original, invoker, diag::autodiff_private_derivative_from_fragile, - fragileKind, - isa_and_nonnull( - originalFRI->getLoc().getAsASTNode())); - return std::nullopt; + if (original->getFunction()->isSerialized()) { + // When dealing with curry thunk, look at the function being wrapped + // inside implicit closure. If it has public visibility, the corresponding + // differentiability witness also has public visibility. It should be OK + // for implicit wrapper closure and its witness to have private linkage. + SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction(originalFn); + bool isWitnessPublic = + unwrappedFn == nullptr + ? hasPublicVisibility(minimalWitness->getLinkage()) + : hasPublicVisibility(unwrappedFn->getLinkage()); + if (!isWitnessPublic) { + enum { Inlinable = 0, DefaultArgument = 1 }; + unsigned fragileKind = Inlinable; + // FIXME: This is not a very robust way of determining if the function + // is a default argument. Also, we have not exhaustively listed all the + // kinds of fragility. + if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI) + fragileKind = DefaultArgument; + context.emitNondifferentiabilityError( + original, invoker, diag::autodiff_private_derivative_from_fragile, + fragileKind, + isa_and_nonnull( + originalFRI->getLoc().getAsASTNode())); + return std::nullopt; + } } // TODO(TF-482): Move generic requirement checking logic to // `getExactDifferentiabilityWitness` and @@ -1121,8 +1196,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction( for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP, AutoDiffDerivativeFunctionKind::VJP}) { auto derivativeFnAndIndices = emitDerivativeFunctionReference( - *this, builder, desiredConfig, derivativeFnKind, origFnOperand, - invoker, newBuffersToDealloc); + builder, desiredConfig, derivativeFnKind, origFnOperand, invoker, + newBuffersToDealloc); // Show an error at the operator, highlight the argument, and show a note // at the definition site of the argument. if (!derivativeFnAndIndices) diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index c04f5a99993a1..5b446df14536d 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -771,32 +771,6 @@ public func fragileDifferentiable(_ x: Float) -> Float { implicitlyDifferentiableFromFragile(x) } - -// FIXME: Differentiable curry thunk RequirementMachine error (rdar://87429620, https://github.com/apple/swift/issues/54819). -#if false -// TF-1208: Test curry thunk differentiation regression. -public struct Struct_54819 { - var x: Scalar -} -extension Struct_54819: Differentiable where Scalar: Differentiable { - @differentiable(reverse) - public static func id(x: Self) -> Self { - return x - } -} -@differentiable(reverse, wrt: x) -public func f_54819( - _ x: Struct_54819, - // NOTE(TF-1208): This diagnostic is unexpected because `Struct_54819.id` is marked `@differentiable`. - // xpected-error @+3 2 {{function is not differentiable}} - // xpected-note @+2 {{differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead}} - // xpected-note @+1 {{opaque non-'@differentiable' function is not differentiable}} - reduction: @differentiable(reverse) (Struct_54819) -> Struct_54819 = Struct_54819.id -) -> Struct_54819 { - reduction(x) -} -#endif - //===----------------------------------------------------------------------===// // Coroutines (SIL function yields, `begin_apply`) (not yet supported) //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/SILOptimizer/fragile_curry_thunk.swift b/test/AutoDiff/SILOptimizer/fragile_curry_thunk.swift new file mode 100644 index 0000000000000..d82f1f0daff27 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/fragile_curry_thunk.swift @@ -0,0 +1,73 @@ +// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s -o /dev/null + +import _Differentiation + +/// Minimal reproducer for both single and double curry thunk + +@inlinable +func caller( + of f: @differentiable(reverse) (_: Thing) -> Thing +) -> Int where Thing.TangentVector == Thing { + return 42 +} + +public struct Struct: Differentiable where Thing.TangentVector == Thing { + @inlinable + static func foo_single() -> Int { + return caller(of: callee_single) // No error expected + } + + @inlinable + @differentiable(reverse) + static func callee_single(input: Thing) -> Thing { + return input + } + + @inlinable + func foo_double() -> Int { + return caller(of: callee_double) // No error expected + } + + @inlinable + @differentiable(reverse) + func callee_double(input: Thing) -> Thing { + return input + } +} + +/// Reproducer from https://github.com/swiftlang/swift/issues/75776 + +public struct Solution2: Differentiable where Thing.TangentVector == Thing { + @inlinable + public static func optimization() -> Thing { + var initial = Thing.zero + let (_, delta) = valueWithGradient(at: initial, of: simulationWithLoss) // No error expected + initial.move(by: delta) + return initial + } + + @inlinable + @differentiable(reverse) + static func simulationWithLoss(input: Thing) -> Thing { + return input // implementation + } +} + +/// Reproducer from https://github.com/swiftlang/swift/issues/54819 + +public struct TF_688_Struct { + var x: Scalar +} +extension TF_688_Struct: Differentiable where Scalar: Differentiable { + @differentiable(reverse) + public static func id(x: Self) -> Self { + return x + } +} +@differentiable(reverse, wrt: x) +public func TF_688( + _ x: TF_688_Struct, + reduction: @differentiable(reverse) (TF_688_Struct) -> TF_688_Struct = TF_688_Struct.id // No error expected +) -> TF_688_Struct { + reduction(x) +} diff --git a/test/AutoDiff/SILOptimizer/generics.swift b/test/AutoDiff/SILOptimizer/generics.swift index 335821cddec61..48dfb48346c53 100644 --- a/test/AutoDiff/SILOptimizer/generics.swift +++ b/test/AutoDiff/SILOptimizer/generics.swift @@ -250,27 +250,6 @@ extension TF_682_Proto where Self : Differentiable, } } -// NOTE(TF-1208): Differentiation regression due to changes in curry thunk generation. -/* -// TF-688: Test generic curry thunk cloning. -public struct TF_688_Struct { - var x: Scalar -} -extension TF_688_Struct: Differentiable where Scalar: Differentiable { - @differentiable(reverse) - public static func id(x: Self) -> Self { - return x - } -} -@differentiable(reverse, wrt: x) -public func TF_688( - _ x: TF_688_Struct, - reduction: @differentiable(reverse) (TF_688_Struct) -> TF_688_Struct = TF_688_Struct.id -) -> TF_688_Struct { - reduction(x) -} -*/ - // TF-697: Test generic requirements of generated derivative function. protocol TF_697_Module: Differentiable { associatedtype Input diff --git a/test/AutoDiff/compiler_crashers/rdar87429620-differentiable-curry-thunk-reqmachine.swift b/test/AutoDiff/compiler_crashers_fixed/rdar87429620-differentiable-curry-thunk-reqmachine.swift similarity index 98% rename from test/AutoDiff/compiler_crashers/rdar87429620-differentiable-curry-thunk-reqmachine.swift rename to test/AutoDiff/compiler_crashers_fixed/rdar87429620-differentiable-curry-thunk-reqmachine.swift index 7709ee1f28a4d..a6020e815e14a 100644 --- a/test/AutoDiff/compiler_crashers/rdar87429620-differentiable-curry-thunk-reqmachine.swift +++ b/test/AutoDiff/compiler_crashers_fixed/rdar87429620-differentiable-curry-thunk-reqmachine.swift @@ -1,5 +1,4 @@ // RUN: %target-swift-frontend -emit-sil -verify %s -// XFAIL: * // rdar://87429620 // https://github.com/apple/swift/issues/54819