Skip to content

[AutoDiff] Support curry thunks differentiation in fragile funcs #77615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 108 additions & 33 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/BreadthFirstIterator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -84,6 +85,9 @@ class DifferentiationTransformer {
/// Context necessary for performing the transformations.
ADContext context;

/// Cache used in getUnwrappedCurryThunkFunction.
llvm::DenseMap<AbstractFunctionDecl *, SILFunction *> afdToSILFn;

/// Promotes the given `differentiable_function` instruction to a valid
/// `@differentiable` function-typed value.
SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst,
Expand All @@ -96,6 +100,25 @@ class DifferentiationTransformer {
SILBuilder &builder, SILLocation loc,
DifferentiationInvoker invoker);

/// Emits a reference to a derivative function of `original`, differentiated
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
/// the derivative function and the actual indices that the derivative
/// function is with respect to.
///
/// Returns `None` on failure, signifying that a diagnostic has been emitted
/// using `invoker`.
std::optional<std::pair<SILValue, AutoDiffConfig>>
emitDerivativeFunctionReference(
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
AutoDiffDerivativeFunctionKind kind, SILValue original,
DifferentiationInvoker invoker,
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc);

/// If the given function corresponds to AutoClosureExpr with either
/// SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction
/// corresponding to the function being wrapped in the thunk.
SILFunction *getUnwrappedCurryThunkFunction(SILFunction *originalFn);

public:
/// Construct an `DifferentiationTransformer` for the given module.
explicit DifferentiationTransformer(SILModuleTransform &transform)
Expand Down Expand Up @@ -453,21 +476,63 @@ static SILValue reapplyFunctionConversion(
llvm_unreachable("Unhandled function conversion instruction");
}

/// Emits a reference to a derivative function of `original`, differentiated
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
/// the derivative function and the actual indices that the derivative function
/// is with respect to.
///
/// Returns `None` on failure, signifying that a diagnostic has been emitted
/// using `invoker`.
static std::optional<std::pair<SILValue, AutoDiffConfig>>
emitDerivativeFunctionReference(
DifferentiationTransformer &transformer, SILBuilder &builder,
const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
SILValue original, DifferentiationInvoker invoker,
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
ADContext &context = transformer.getContext();
SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction(
SILFunction *originalFn) {
auto *autoCE = dyn_cast_or_null<AutoClosureExpr>(
originalFn->getDeclRef().getAbstractClosureExpr());
if (autoCE == nullptr)
return nullptr;

auto *ae = dyn_cast_or_null<ApplyExpr>(autoCE->getUnwrappedCurryThunkExpr());
if (ae == nullptr)
return nullptr;

AbstractFunctionDecl *afd = cast<AbstractFunctionDecl>(ae->getCalledValue(
/*skipFunctionConversions=*/true));
auto silFnIt = afdToSILFn.find(afd);
if (silFnIt == afdToSILFn.end()) {
assert(afdToSILFn.empty() && "Expect all 'afdToSILFn' cache entries to be "
"filled at once on the first access attempt");

SILModule *module = getTransform().getModule();
for (SILFunction &currentFunc : module->getFunctions()) {
if (auto *currentAFD =
currentFunc.getDeclRef().getAbstractFunctionDecl()) {
// Update cache only with AFDs which might be potentially wrapped by a
// curry thunk. This includes member function references and references
// to functions having external property wrapper parameters (see
// ExprRewriter::buildDeclRef). If new use cases of curry thunks appear
// in future, the assertion after the loop will be a trigger for such
// cases being unhandled here.
//
// FIXME: References to functions having external property wrapper
// parameters are not handled since we can't now construct a test case
// for that due to the crash
// https://github.com/swiftlang/swift/issues/77613
if (currentAFD->hasCurriedSelf()) {
auto [_, wasEmplace] =
afdToSILFn.try_emplace(currentAFD, &currentFunc);
assert(wasEmplace && "Expect all 'afdToSILFn' cache entries to be "
"filled at once on the first access attempt");
}
}
}

silFnIt = afdToSILFn.find(afd);
assert(silFnIt != afdToSILFn.end() &&
"Expect present curry thunk to SIL function mapping after "
"'afdToSILFn' cache fill");
}

return silFnIt->second;
}

std::optional<std::pair<SILValue, AutoDiffConfig>>
DifferentiationTransformer::emitDerivativeFunctionReference(
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
AutoDiffDerivativeFunctionKind kind, SILValue original,
DifferentiationInvoker invoker,
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
// matches the given kind and desired differentiation parameter indices,
// simply extract the derivative function of its function operand, retain the
Expand Down Expand Up @@ -610,26 +675,36 @@ emitDerivativeFunctionReference(
DifferentiabilityKind::Reverse, desiredParameterIndices,
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
/*vjp*/ nullptr, /*isSerialized*/ false);
if (transformer.canonicalizeDifferentiabilityWitness(
minimalWitness, invoker, IsNotSerialized))
if (canonicalizeDifferentiabilityWitness(minimalWitness, invoker,
IsNotSerialized))
return std::nullopt;
}
assert(minimalWitness);
if (original->getFunction()->isSerialized() &&
!hasPublicVisibility(minimalWitness->getLinkage())) {
enum { Inlinable = 0, DefaultArgument = 1 };
unsigned fragileKind = Inlinable;
// FIXME: This is not a very robust way of determining if the function is
// a default argument. Also, we have not exhaustively listed all the kinds
// of fragility.
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
fragileKind = DefaultArgument;
context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_private_derivative_from_fragile,
fragileKind,
isa_and_nonnull<AbstractClosureExpr>(
originalFRI->getLoc().getAsASTNode<Expr>()));
return std::nullopt;
if (original->getFunction()->isSerialized()) {
// When dealing with curry thunk, look at the function being wrapped
// inside implicit closure. If it has public visibility, the corresponding
// differentiability witness also has public visibility. It should be OK
// for implicit wrapper closure and its witness to have private linkage.
SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction(originalFn);
bool isWitnessPublic =
unwrappedFn == nullptr
? hasPublicVisibility(minimalWitness->getLinkage())
: hasPublicVisibility(unwrappedFn->getLinkage());
if (!isWitnessPublic) {
enum { Inlinable = 0, DefaultArgument = 1 };
unsigned fragileKind = Inlinable;
// FIXME: This is not a very robust way of determining if the function
// is a default argument. Also, we have not exhaustively listed all the
// kinds of fragility.
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
fragileKind = DefaultArgument;
context.emitNondifferentiabilityError(
original, invoker, diag::autodiff_private_derivative_from_fragile,
fragileKind,
isa_and_nonnull<AbstractClosureExpr>(
originalFRI->getLoc().getAsASTNode<Expr>()));
return std::nullopt;
}
}
// TODO(TF-482): Move generic requirement checking logic to
// `getExactDifferentiabilityWitness` and
Expand Down Expand Up @@ -1121,8 +1196,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
AutoDiffDerivativeFunctionKind::VJP}) {
auto derivativeFnAndIndices = emitDerivativeFunctionReference(
*this, builder, desiredConfig, derivativeFnKind, origFnOperand,
invoker, newBuffersToDealloc);
builder, desiredConfig, derivativeFnKind, origFnOperand, invoker,
newBuffersToDealloc);
// Show an error at the operator, highlight the argument, and show a note
// at the definition site of the argument.
if (!derivativeFnAndIndices)
Expand Down
26 changes: 0 additions & 26 deletions test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -771,32 +771,6 @@ public func fragileDifferentiable(_ x: Float) -> Float {
implicitlyDifferentiableFromFragile(x)
}


// FIXME: Differentiable curry thunk RequirementMachine error (rdar://87429620, https://github.com/apple/swift/issues/54819).
#if false
// TF-1208: Test curry thunk differentiation regression.
public struct Struct_54819<Scalar> {
var x: Scalar
}
extension Struct_54819: Differentiable where Scalar: Differentiable {
@differentiable(reverse)
public static func id(x: Self) -> Self {
return x
}
}
@differentiable(reverse, wrt: x)
public func f_54819<Scalar: Differentiable>(
_ x: Struct_54819<Scalar>,
// NOTE(TF-1208): This diagnostic is unexpected because `Struct_54819.id` is marked `@differentiable`.
// xpected-error @+3 2 {{function is not differentiable}}
// xpected-note @+2 {{differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead}}
// xpected-note @+1 {{opaque non-'@differentiable' function is not differentiable}}
reduction: @differentiable(reverse) (Struct_54819<Scalar>) -> Struct_54819<Scalar> = Struct_54819.id
) -> Struct_54819<Scalar> {
reduction(x)
}
#endif

//===----------------------------------------------------------------------===//
// Coroutines (SIL function yields, `begin_apply`) (not yet supported)
//===----------------------------------------------------------------------===//
Expand Down
73 changes: 73 additions & 0 deletions test/AutoDiff/SILOptimizer/fragile_curry_thunk.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s -o /dev/null

import _Differentiation

/// Minimal reproducer for both single and double curry thunk

@inlinable
func caller<Thing: Differentiable & FloatingPoint>(
of f: @differentiable(reverse) (_: Thing) -> Thing
) -> Int where Thing.TangentVector == Thing {
return 42
}

public struct Struct<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
@inlinable
static func foo_single() -> Int {
return caller(of: callee_single) // No error expected
}

@inlinable
@differentiable(reverse)
static func callee_single(input: Thing) -> Thing {
return input
}

@inlinable
func foo_double() -> Int {
return caller(of: callee_double) // No error expected
}

@inlinable
@differentiable(reverse)
func callee_double(input: Thing) -> Thing {
return input
}
}

/// Reproducer from https://github.com/swiftlang/swift/issues/75776

public struct Solution2<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
@inlinable
public static func optimization() -> Thing {
var initial = Thing.zero
let (_, delta) = valueWithGradient(at: initial, of: simulationWithLoss) // No error expected
initial.move(by: delta)
return initial
}

@inlinable
@differentiable(reverse)
static func simulationWithLoss(input: Thing) -> Thing {
return input // implementation
}
}

/// Reproducer from https://github.com/swiftlang/swift/issues/54819

public struct TF_688_Struct<Scalar> {
var x: Scalar
}
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
@differentiable(reverse)
public static func id(x: Self) -> Self {
return x
}
}
@differentiable(reverse, wrt: x)
public func TF_688<Scalar: Differentiable>(
_ x: TF_688_Struct<Scalar>,
reduction: @differentiable(reverse) (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id // No error expected
) -> TF_688_Struct<Scalar> {
reduction(x)
}
21 changes: 0 additions & 21 deletions test/AutoDiff/SILOptimizer/generics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -250,27 +250,6 @@ extension TF_682_Proto where Self : Differentiable,
}
}

// NOTE(TF-1208): Differentiation regression due to changes in curry thunk generation.
/*
// TF-688: Test generic curry thunk cloning.
public struct TF_688_Struct<Scalar> {
var x: Scalar
}
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
@differentiable(reverse)
public static func id(x: Self) -> Self {
return x
}
}
@differentiable(reverse, wrt: x)
public func TF_688<Scalar: Differentiable>(
_ x: TF_688_Struct<Scalar>,
reduction: @differentiable(reverse) (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
) -> TF_688_Struct<Scalar> {
reduction(x)
}
*/

// TF-697: Test generic requirements of generated derivative function.
protocol TF_697_Module: Differentiable {
associatedtype Input
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// RUN: %target-swift-frontend -emit-sil -verify %s
// XFAIL: *

// rdar://87429620
// https://github.com/apple/swift/issues/54819
Expand Down