51
51
#include " llvm/ADT/SmallSet.h"
52
52
#include " llvm/Support/CommandLine.h"
53
53
54
+ #include < unordered_map>
55
+
54
56
using namespace swift ;
55
57
using namespace swift ::autodiff;
56
58
using llvm::DenseMap;
@@ -84,6 +86,9 @@ class DifferentiationTransformer {
84
86
// / Context necessary for performing the transformations.
85
87
ADContext context;
86
88
89
+ // / Cache used in getUnwrappedCurryThunkFunction.
90
+ std::unordered_map<AbstractFunctionDecl *, SILFunction *> afdToSILFn;
91
+
87
92
// / Promotes the given `differentiable_function` instruction to a valid
88
93
// / `@differentiable` function-typed value.
89
94
SILValue promoteToDifferentiableFunction (DifferentiableFunctionInst *inst,
@@ -96,6 +101,25 @@ class DifferentiationTransformer {
96
101
SILBuilder &builder, SILLocation loc,
97
102
DifferentiationInvoker invoker);
98
103
104
+ // / Emits a reference to a derivative function of `original`, differentiated
105
+ // / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
106
+ // / the derivative function and the actual indices that the derivative
107
+ // / function is with respect to.
108
+ // /
109
+ // / Returns `None` on failure, signifying that a diagnostic has been emitted
110
+ // / using `invoker`.
111
+ std::optional<std::pair<SILValue, AutoDiffConfig>>
112
+ emitDerivativeFunctionReference (
113
+ SILBuilder &builder, const AutoDiffConfig &desiredConfig,
114
+ AutoDiffDerivativeFunctionKind kind, SILValue original,
115
+ DifferentiationInvoker invoker,
116
+ SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc);
117
+
118
+ // / If the given function corresponds to AutoClosureExpr with either
119
+ // / SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction
120
+ // / corresponding to the function being wrapped in the thunk.
121
+ SILFunction *getUnwrappedCurryThunkFunction (SILFunction *originalFn);
122
+
99
123
public:
100
124
// / Construct an `DifferentiationTransformer` for the given module.
101
125
explicit DifferentiationTransformer (SILModuleTransform &transform)
@@ -453,21 +477,40 @@ static SILValue reapplyFunctionConversion(
453
477
llvm_unreachable (" Unhandled function conversion instruction" );
454
478
}
455
479
456
- // / Emits a reference to a derivative function of `original`, differentiated
457
- // / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
458
- // / the derivative function and the actual indices that the derivative function
459
- // / is with respect to.
460
- // /
461
- // / Returns `None` on failure, signifying that a diagnostic has been emitted
462
- // / using `invoker`.
463
- static std::optional<std::pair<SILValue, AutoDiffConfig>>
464
- emitDerivativeFunctionReference (
465
- DifferentiationTransformer &transformer, SILBuilder &builder,
466
- const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
467
- SILValue original, DifferentiationInvoker invoker,
468
- SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
469
- ADContext &context = transformer.getContext ();
480
+ SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction (
481
+ SILFunction *originalFn) {
482
+ auto *abstractCE = originalFn->getDeclRef ().getAbstractClosureExpr ();
483
+ if (abstractCE == nullptr )
484
+ return nullptr ;
485
+ auto *autoCE = dyn_cast<AutoClosureExpr>(abstractCE);
486
+ if (autoCE == nullptr )
487
+ return nullptr ;
488
+
489
+ auto *afd =
490
+ cast<AbstractFunctionDecl>(autoCE->getUnwrappedCurryThunkCalledValue ());
491
+
492
+ auto silFnIt = afdToSILFn.find (afd);
493
+ if (silFnIt == afdToSILFn.end ()) {
494
+ assert (afdToSILFn.empty ());
470
495
496
+ SILModule *module = getTransform ().getModule ();
497
+ for (SILFunction ¤tFunc : module->getFunctions ())
498
+ if (auto *currentAFD = currentFunc.getDeclRef ().getAbstractFunctionDecl ())
499
+ afdToSILFn.emplace (currentAFD, ¤tFunc);
500
+
501
+ silFnIt = afdToSILFn.find (afd);
502
+ assert (silFnIt != afdToSILFn.end ());
503
+ }
504
+
505
+ return silFnIt->second ;
506
+ }
507
+
508
+ std::optional<std::pair<SILValue, AutoDiffConfig>>
509
+ DifferentiationTransformer::emitDerivativeFunctionReference (
510
+ SILBuilder &builder, const AutoDiffConfig &desiredConfig,
511
+ AutoDiffDerivativeFunctionKind kind, SILValue original,
512
+ DifferentiationInvoker invoker,
513
+ SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
471
514
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
472
515
// matches the given kind and desired differentiation parameter indices,
473
516
// simply extract the derivative function of its function operand, retain the
@@ -610,26 +653,36 @@ emitDerivativeFunctionReference(
610
653
DifferentiabilityKind::Reverse, desiredParameterIndices,
611
654
desiredResultIndices, derivativeConstrainedGenSig, /* jvp*/ nullptr ,
612
655
/* vjp*/ nullptr , /* isSerialized*/ false );
613
- if (transformer. canonicalizeDifferentiabilityWitness (
614
- minimalWitness, invoker, IsNotSerialized))
656
+ if (canonicalizeDifferentiabilityWitness (minimalWitness, invoker,
657
+ IsNotSerialized))
615
658
return std::nullopt;
616
659
}
617
660
assert (minimalWitness);
618
- if (original->getFunction ()->isSerialized () &&
619
- !hasPublicVisibility (minimalWitness->getLinkage ())) {
620
- enum { Inlinable = 0 , DefaultArgument = 1 };
621
- unsigned fragileKind = Inlinable;
622
- // FIXME: This is not a very robust way of determining if the function is
623
- // a default argument. Also, we have not exhaustively listed all the kinds
624
- // of fragility.
625
- if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
626
- fragileKind = DefaultArgument;
627
- context.emitNondifferentiabilityError (
628
- original, invoker, diag::autodiff_private_derivative_from_fragile,
629
- fragileKind,
630
- isa_and_nonnull<AbstractClosureExpr>(
631
- originalFRI->getLoc ().getAsASTNode <Expr>()));
632
- return std::nullopt;
661
+ if (original->getFunction ()->isSerialized ()) {
662
+ // When dealing with curry thunk, look at the function being wrapped
663
+ // inside implicit closure. If it has public visibility, the corresponding
664
+ // differentiability witness also has public visibility. It should be OK
665
+ // for implicit wrapper closure and its witness to have private linkage.
666
+ SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction (originalFn);
667
+ bool isWitnessPublic =
668
+ unwrappedFn == nullptr
669
+ ? hasPublicVisibility (minimalWitness->getLinkage ())
670
+ : hasPublicVisibility (unwrappedFn->getLinkage ());
671
+ if (!isWitnessPublic) {
672
+ enum { Inlinable = 0 , DefaultArgument = 1 };
673
+ unsigned fragileKind = Inlinable;
674
+ // FIXME: This is not a very robust way of determining if the function
675
+ // is a default argument. Also, we have not exhaustively listed all the
676
+ // kinds of fragility.
677
+ if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
678
+ fragileKind = DefaultArgument;
679
+ context.emitNondifferentiabilityError (
680
+ original, invoker, diag::autodiff_private_derivative_from_fragile,
681
+ fragileKind,
682
+ isa_and_nonnull<AbstractClosureExpr>(
683
+ originalFRI->getLoc ().getAsASTNode <Expr>()));
684
+ return std::nullopt;
685
+ }
633
686
}
634
687
// TODO(TF-482): Move generic requirement checking logic to
635
688
// `getExactDifferentiabilityWitness` and
@@ -1121,8 +1174,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
1121
1174
for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
1122
1175
AutoDiffDerivativeFunctionKind::VJP}) {
1123
1176
auto derivativeFnAndIndices = emitDerivativeFunctionReference (
1124
- * this , builder, desiredConfig, derivativeFnKind, origFnOperand,
1125
- invoker, newBuffersToDealloc);
1177
+ builder, desiredConfig, derivativeFnKind, origFnOperand, invoker ,
1178
+ newBuffersToDealloc);
1126
1179
// Show an error at the operator, highlight the argument, and show a note
1127
1180
// at the definition site of the argument.
1128
1181
if (!derivativeFnAndIndices)
0 commit comments