Skip to content

Commit 38edc2b

Browse files
committed
[AutoDiff] Support curry thunks differentiation in inlinable funcs
Inside inlinable functions, we expect functions to either be explicitly marked as differentiable or have a public explicit derivative defined. This is obviously not possible for single and double curry thunks which are a special case of `AutoClosureExpr`. Instead of looking at the thunk itself, we unwrap it and look at the function being wrapped. While the thunk itself and its differentiability witness will not have public visibility, it's not an issue for the case where the function being wrapped (and its witness) have public visibility. Fixes swiftlang#54819 Fixes swiftlang#75776
1 parent 45657fe commit 38edc2b

File tree

6 files changed

+182
-63
lines changed

6 files changed

+182
-63
lines changed

include/swift/AST/Expr.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4387,6 +4387,8 @@ class ClosureExpr : public AbstractClosureExpr {
43874387
class AutoClosureExpr : public AbstractClosureExpr {
43884388
BraceStmt *Body;
43894389

4390+
ApplyExpr *getUnwrappedCurryThunkImpl() const;
4391+
43904392
public:
43914393
enum class Kind : uint8_t {
43924394
// An autoclosure with type () -> Result. Formed from type checking an
@@ -4448,6 +4450,10 @@ class AutoClosureExpr : public AbstractClosureExpr {
44484450
/// - otherwise, returns nullptr for convenience.
44494451
Expr *getUnwrappedCurryThunkExpr() const;
44504452

4453+
/// Same as getUnwrappedCurryThunkExpr, but get the called ValueDecl instead
4454+
/// of the expr.
4455+
ValueDecl *getUnwrappedCurryThunkCalledValue() const;
4456+
44514457
// Implement isa/cast/dyncast/etc.
44524458
static bool classof(const Expr *E) {
44534459
return E->getKind() == ExprKind::AutoClosure;

lib/AST/Expr.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,7 +2133,7 @@ Expr *AutoClosureExpr::getSingleExpressionBody() const {
21332133
return cast<ReturnStmt>(Body->getLastElement().get<Stmt *>())->getResult();
21342134
}
21352135

2136-
Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
2136+
ApplyExpr *AutoClosureExpr::getUnwrappedCurryThunkImpl() const {
21372137
auto maybeUnwrapOpenExistential = [](Expr *expr) {
21382138
if (auto *openExistential = dyn_cast<OpenExistentialExpr>(expr)) {
21392139
expr = openExistential->getSubExpr()->getSemanticsProvidingExpr();
@@ -2177,7 +2177,7 @@ Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
21772177
body = maybeUnwrapConversions(body);
21782178

21792179
if (auto *outerCall = dyn_cast<ApplyExpr>(body)) {
2180-
return outerCall->getFn();
2180+
return outerCall;
21812181
}
21822182

21832183
assert(false && "Malformed curry thunk?");
@@ -2198,7 +2198,7 @@ Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
21982198
if (auto *outerCall = dyn_cast<ApplyExpr>(innerBody)) {
21992199
auto outerFn = maybeUnwrapConversions(outerCall->getFn());
22002200
if (auto *innerCall = dyn_cast<ApplyExpr>(outerFn)) {
2201-
return innerCall->getFn();
2201+
return innerCall;
22022202
}
22032203
}
22042204
}
@@ -2211,6 +2211,20 @@ Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
22112211
return nullptr;
22122212
}
22132213

2214+
Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
2215+
ApplyExpr *ae = getUnwrappedCurryThunkImpl();
2216+
if (ae == nullptr)
2217+
return nullptr;
2218+
return ae->getFn();
2219+
}
2220+
2221+
ValueDecl *AutoClosureExpr::getUnwrappedCurryThunkCalledValue() const {
2222+
ApplyExpr *ae = getUnwrappedCurryThunkImpl();
2223+
if (ae == nullptr)
2224+
return nullptr;
2225+
return ae->getCalledValue(/*skipFunctionConversions=*/true);
2226+
}
2227+
22142228
FORWARD_SOURCE_LOCS_TO(UnresolvedPatternExpr, subPattern)
22152229

22162230
TypeExpr::TypeExpr(TypeRepr *Repr)

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
#include "llvm/ADT/SmallSet.h"
5252
#include "llvm/Support/CommandLine.h"
5353

54+
#include <unordered_map>
55+
5456
using namespace swift;
5557
using namespace swift::autodiff;
5658
using llvm::DenseMap;
@@ -84,6 +86,9 @@ class DifferentiationTransformer {
8486
/// Context necessary for performing the transformations.
8587
ADContext context;
8688

89+
/// Cache used in getUnwrappedCurryThunkFunction.
90+
std::unordered_map<AbstractFunctionDecl *, SILFunction *> afdToSILFn;
91+
8792
/// Promotes the given `differentiable_function` instruction to a valid
8893
/// `@differentiable` function-typed value.
8994
SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst,
@@ -96,6 +101,25 @@ class DifferentiationTransformer {
96101
SILBuilder &builder, SILLocation loc,
97102
DifferentiationInvoker invoker);
98103

104+
/// Emits a reference to a derivative function of `original`, differentiated
105+
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
106+
/// the derivative function and the actual indices that the derivative
107+
/// function is with respect to.
108+
///
109+
/// Returns `None` on failure, signifying that a diagnostic has been emitted
110+
/// using `invoker`.
111+
std::optional<std::pair<SILValue, AutoDiffConfig>>
112+
emitDerivativeFunctionReference(
113+
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
114+
AutoDiffDerivativeFunctionKind kind, SILValue original,
115+
DifferentiationInvoker invoker,
116+
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc);
117+
118+
/// If the given function corresponds to AutoClosureExpr with either
119+
/// SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction
120+
/// corresponding to the function being wrapped in the thunk.
121+
SILFunction *getUnwrappedCurryThunkFunction(SILFunction *originalFn);
122+
99123
public:
100124
/// Construct an `DifferentiationTransformer` for the given module.
101125
explicit DifferentiationTransformer(SILModuleTransform &transform)
@@ -453,21 +477,40 @@ static SILValue reapplyFunctionConversion(
453477
llvm_unreachable("Unhandled function conversion instruction");
454478
}
455479

456-
/// Emits a reference to a derivative function of `original`, differentiated
457-
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
458-
/// the derivative function and the actual indices that the derivative function
459-
/// is with respect to.
460-
///
461-
/// Returns `None` on failure, signifying that a diagnostic has been emitted
462-
/// using `invoker`.
463-
static std::optional<std::pair<SILValue, AutoDiffConfig>>
464-
emitDerivativeFunctionReference(
465-
DifferentiationTransformer &transformer, SILBuilder &builder,
466-
const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
467-
SILValue original, DifferentiationInvoker invoker,
468-
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
469-
ADContext &context = transformer.getContext();
480+
SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction(
481+
SILFunction *originalFn) {
482+
auto *abstractCE = originalFn->getDeclRef().getAbstractClosureExpr();
483+
if (abstractCE == nullptr)
484+
return nullptr;
485+
auto *autoCE = dyn_cast<AutoClosureExpr>(abstractCE);
486+
if (autoCE == nullptr)
487+
return nullptr;
488+
489+
auto *afd =
490+
cast<AbstractFunctionDecl>(autoCE->getUnwrappedCurryThunkCalledValue());
491+
492+
auto silFnIt = afdToSILFn.find(afd);
493+
if (silFnIt == afdToSILFn.end()) {
494+
assert(afdToSILFn.empty());
470495

496+
SILModule *module = getTransform().getModule();
497+
for (SILFunction &currentFunc : module->getFunctions())
498+
if (auto *currentAFD = currentFunc.getDeclRef().getAbstractFunctionDecl())
499+
afdToSILFn.emplace(currentAFD, &currentFunc);
500+
501+
silFnIt = afdToSILFn.find(afd);
502+
assert(silFnIt != afdToSILFn.end());
503+
}
504+
505+
return silFnIt->second;
506+
}
507+
508+
std::optional<std::pair<SILValue, AutoDiffConfig>>
509+
DifferentiationTransformer::emitDerivativeFunctionReference(
510+
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
511+
AutoDiffDerivativeFunctionKind kind, SILValue original,
512+
DifferentiationInvoker invoker,
513+
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
471514
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
472515
// matches the given kind and desired differentiation parameter indices,
473516
// simply extract the derivative function of its function operand, retain the
@@ -610,26 +653,36 @@ emitDerivativeFunctionReference(
610653
DifferentiabilityKind::Reverse, desiredParameterIndices,
611654
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
612655
/*vjp*/ nullptr, /*isSerialized*/ false);
613-
if (transformer.canonicalizeDifferentiabilityWitness(
614-
minimalWitness, invoker, IsNotSerialized))
656+
if (canonicalizeDifferentiabilityWitness(minimalWitness, invoker,
657+
IsNotSerialized))
615658
return std::nullopt;
616659
}
617660
assert(minimalWitness);
618-
if (original->getFunction()->isSerialized() &&
619-
!hasPublicVisibility(minimalWitness->getLinkage())) {
620-
enum { Inlinable = 0, DefaultArgument = 1 };
621-
unsigned fragileKind = Inlinable;
622-
// FIXME: This is not a very robust way of determining if the function is
623-
// a default argument. Also, we have not exhaustively listed all the kinds
624-
// of fragility.
625-
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
626-
fragileKind = DefaultArgument;
627-
context.emitNondifferentiabilityError(
628-
original, invoker, diag::autodiff_private_derivative_from_fragile,
629-
fragileKind,
630-
isa_and_nonnull<AbstractClosureExpr>(
631-
originalFRI->getLoc().getAsASTNode<Expr>()));
632-
return std::nullopt;
661+
if (original->getFunction()->isSerialized()) {
662+
// When dealing with curry thunk, look at the function being wrapped
663+
// inside implicit closure. If it has public visibility, the corresponding
664+
// differentiability witness also has public visibility. It should be OK
665+
// for implicit wrapper closure and its witness to have private linkage.
666+
SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction(originalFn);
667+
bool isWitnessPublic =
668+
unwrappedFn == nullptr
669+
? hasPublicVisibility(minimalWitness->getLinkage())
670+
: hasPublicVisibility(unwrappedFn->getLinkage());
671+
if (!isWitnessPublic) {
672+
enum { Inlinable = 0, DefaultArgument = 1 };
673+
unsigned fragileKind = Inlinable;
674+
// FIXME: This is not a very robust way of determining if the function
675+
// is a default argument. Also, we have not exhaustively listed all the
676+
// kinds of fragility.
677+
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
678+
fragileKind = DefaultArgument;
679+
context.emitNondifferentiabilityError(
680+
original, invoker, diag::autodiff_private_derivative_from_fragile,
681+
fragileKind,
682+
isa_and_nonnull<AbstractClosureExpr>(
683+
originalFRI->getLoc().getAsASTNode<Expr>()));
684+
return std::nullopt;
685+
}
633686
}
634687
// TODO(TF-482): Move generic requirement checking logic to
635688
// `getExactDifferentiabilityWitness` and
@@ -1121,8 +1174,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
11211174
for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
11221175
AutoDiffDerivativeFunctionKind::VJP}) {
11231176
auto derivativeFnAndIndices = emitDerivativeFunctionReference(
1124-
*this, builder, desiredConfig, derivativeFnKind, origFnOperand,
1125-
invoker, newBuffersToDealloc);
1177+
builder, desiredConfig, derivativeFnKind, origFnOperand, invoker,
1178+
newBuffersToDealloc);
11261179
// Show an error at the operator, highlight the argument, and show a note
11271180
// at the definition site of the argument.
11281181
if (!derivativeFnAndIndices)

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -772,32 +772,6 @@ public func fragileDifferentiable(_ x: Float) -> Float {
772772
implicitlyDifferentiableFromFragile(x)
773773
}
774774

775-
776-
// FIXME: Differentiable curry thunk RequirementMachine error (rdar://87429620, https://github.com/apple/swift/issues/54819).
777-
#if false
778-
// TF-1208: Test curry thunk differentiation regression.
779-
public struct Struct_54819<Scalar> {
780-
var x: Scalar
781-
}
782-
extension Struct_54819: Differentiable where Scalar: Differentiable {
783-
@differentiable(reverse)
784-
public static func id(x: Self) -> Self {
785-
return x
786-
}
787-
}
788-
@differentiable(reverse, wrt: x)
789-
public func f_54819<Scalar: Differentiable>(
790-
_ x: Struct_54819<Scalar>,
791-
// NOTE(TF-1208): This diagnostic is unexpected because `Struct_54819.id` is marked `@differentiable`.
792-
// xpected-error @+3 2 {{function is not differentiable}}
793-
// xpected-note @+2 {{differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead}}
794-
// xpected-note @+1 {{opaque non-'@differentiable' function is not differentiable}}
795-
reduction: @differentiable(reverse) (Struct_54819<Scalar>) -> Struct_54819<Scalar> = Struct_54819.id
796-
) -> Struct_54819<Scalar> {
797-
reduction(x)
798-
}
799-
#endif
800-
801775
//===----------------------------------------------------------------------===//
802776
// Coroutines (SIL function yields, `begin_apply`) (not yet supported)
803777
//===----------------------------------------------------------------------===//
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s -o /dev/null
2+
3+
import _Differentiation
4+
5+
/// Minimal reproducer for both single and double curry thunk
6+
7+
@inlinable
8+
func caller<Thing: Differentiable & FloatingPoint>(
9+
of f: @differentiable(reverse) (_: Thing) -> Thing
10+
) -> Int where Thing.TangentVector == Thing {
11+
return 42
12+
}
13+
14+
public struct Struct<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
15+
@inlinable
16+
static func foo_single() -> Int {
17+
return caller(of: callee_single) // No error expected
18+
}
19+
20+
@inlinable
21+
@differentiable(reverse)
22+
static func callee_single(input: Thing) -> Thing {
23+
return input
24+
}
25+
26+
@inlinable
27+
func foo_double() -> Int {
28+
return caller(of: callee_double) // No error expected
29+
}
30+
31+
@inlinable
32+
@differentiable(reverse)
33+
func callee_double(input: Thing) -> Thing {
34+
return input
35+
}
36+
}
37+
38+
/// Reproducer from issue https://github.com/swiftlang/swift/issues/75776
39+
40+
public struct Solution2<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
41+
@inlinable
42+
public static func optimization() -> Thing {
43+
var initial = Thing.zero
44+
let (_, delta) = valueWithGradient(at: initial, of: simulationWithLoss) // No error expected
45+
initial.move(by: delta)
46+
return initial
47+
}
48+
49+
@inlinable
50+
@differentiable(reverse)
51+
static func simulationWithLoss(input: Thing) -> Thing {
52+
return input // implementation
53+
}
54+
}
55+
56+
/// Reproducer from https://github.com/swiftlang/swift/issues/54819
57+
58+
public struct TF_688_Struct<Scalar> {
59+
var x: Scalar
60+
}
61+
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
62+
@differentiable(reverse)
63+
public static func id(x: Self) -> Self {
64+
return x
65+
}
66+
}
67+
@differentiable(reverse, wrt: x)
68+
public func TF_688<Scalar: Differentiable>(
69+
_ x: TF_688_Struct<Scalar>,
70+
reduction: @differentiable(reverse) (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id // No error expected
71+
) -> TF_688_Struct<Scalar> {
72+
reduction(x)
73+
}

test/AutoDiff/compiler_crashers/rdar87429620-differentiable-curry-thunk-reqmachine.swift renamed to test/AutoDiff/compiler_crashers_fixed/rdar87429620-differentiable-curry-thunk-reqmachine.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
2-
// XFAIL: *
32

43
// rdar://87429620
54
// https://github.com/apple/swift/issues/54819

0 commit comments

Comments
 (0)