Skip to content

Commit a1e138b

Browse files
authored
[AutoDiff] Implement cross-file lookup of derivatives (#58644)
Look-up for functions with @Derivative attributes defined in non-primary source files Fixes #55170
1 parent f387076 commit a1e138b

File tree

10 files changed

+75
-12
lines changed

10 files changed

+75
-12
lines changed

include/swift/AST/Attr.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,6 +1938,10 @@ class DerivativeAttr final
19381938
friend TrailingObjects;
19391939
friend class DerivativeAttrOriginalDeclRequest;
19401940

1941+
/// The declaration on which the `@derivative` attribute is declared.
1942+
/// May not be a valid declaration for `@derivative` attributes.
1943+
/// Resolved during parsing and deserialization.
1944+
Decl *OriginalDeclaration = nullptr;
19411945
/// The base type for the referenced original declaration. This field is
19421946
/// non-null only for parsed attributes that reference a qualified original
19431947
/// declaration. This field is not serialized; type-checking uses it to
@@ -1991,6 +1995,12 @@ class DerivativeAttr final
19911995
DeclNameRefWithLoc original,
19921996
IndexSubset *parameterIndices);
19931997

1998+
Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
1999+
2000+
/// Sets the original declaration on which this attribute is declared.
2001+
/// Should only be used by parsing and deserialization.
2002+
void setOriginalDeclaration(Decl *originalDeclaration);
2003+
19942004
TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
19952005
DeclNameRefWithLoc getOriginalFunctionName() const {
19962006
return OriginalFunctionName;

include/swift/AST/Decl.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6265,8 +6265,11 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
62656265
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
62666266

62676267
public:
6268-
/// Get all derivative function configurations.
6269-
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
6268+
/// Get all derivative function configurations. If `lookInNonPrimarySources`
6269+
/// is true then lookup is done in non-primary sources as well. Note that
6270+
/// such lookup might end in cycles if done during sema stages.
6271+
ArrayRef<AutoDiffConfig>
6272+
getDerivativeFunctionConfigurations(bool lookInNonPrimarySources = true);
62706273

62716274
/// Add the given derivative function configuration.
62726275
void addDerivativeFunctionConfiguration(const AutoDiffConfig &config);

lib/AST/Attr.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,6 +2107,13 @@ void DerivativeAttr::setOriginalFunctionResolver(
21072107
ResolverContextData = resolverContextData;
21082108
}
21092109

2110+
void DerivativeAttr::setOriginalDeclaration(Decl *originalDeclaration) {
2111+
assert(originalDeclaration && "Original declaration must be non-null");
2112+
assert(!OriginalDeclaration &&
2113+
"Original declaration cannot have already been set");
2114+
OriginalDeclaration = originalDeclaration;
2115+
}
2116+
21102117
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
21112118
SourceRange baseRange, TypeRepr *baseTypeRepr,
21122119
DeclNameRefWithLoc originalName,

lib/AST/Decl.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "swift/AST/ASTContext.h"
2121
#include "swift/AST/ASTWalker.h"
2222
#include "swift/AST/ASTMangler.h"
23+
#include "swift/AST/Attr.h"
2324
#include "swift/AST/CaptureInfo.h"
2425
#include "swift/AST/DiagnosticEngine.h"
2526
#include "swift/AST/DiagnosticsSema.h"
@@ -8311,7 +8312,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
83118312
}
83128313

83138314
ArrayRef<AutoDiffConfig>
8314-
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
8315+
AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimarySources) {
83158316
prepareDerivativeFunctionConfigurations();
83168317

83178318
// Resolve derivative function configurations from `@differentiable`
@@ -8334,6 +8335,37 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
83348335
ctx.loadDerivativeFunctionConfigurations(this, previousGeneration,
83358336
*DerivativeFunctionConfigs);
83368337
}
8338+
8339+
class DerivativeFinder : public ASTWalker {
8340+
const AbstractFunctionDecl *AFD;
8341+
public:
8342+
DerivativeFinder(const AbstractFunctionDecl *afd) : AFD(afd) {}
8343+
8344+
bool walkToDeclPre(Decl *D) override {
8345+
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
8346+
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
8347+
// Resolve derivative function configurations from `@derivative`
8348+
// attributes by type-checking them.
8349+
if (AFD->getName().matchesRef(
8350+
derAttr->getOriginalFunctionName().Name.getFullName())) {
8351+
(void)derAttr->getOriginalFunction(afd->getASTContext());
8352+
return false;
8353+
}
8354+
}
8355+
}
8356+
8357+
return true;
8358+
}
8359+
};
8360+
8361+
// Load derivative configurations from @derivative attributes defined in
8362+
// non-primary sources. Note that it might trigger lookup cycles if called
8363+
// from inside Sema stages.
8364+
if (lookInNonPrimarySources) {
8365+
DerivativeFinder finder(this);
8366+
getParent()->walkContext(finder);
8367+
}
8368+
83378369
return DerivativeFunctionConfigs->getArrayRef();
83388370
}
83398371

lib/Parse/ParseDecl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4457,6 +4457,8 @@ setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs,
44574457
Decl *D) {
44584458
for (auto *attr : attrs.getAttributes<DifferentiableAttr>())
44594459
const_cast<DifferentiableAttr *>(attr)->setOriginalDeclaration(D);
4460+
for (auto *attr : attrs.getAttributes<DerivativeAttr>())
4461+
const_cast<DerivativeAttr *>(attr)->setOriginalDeclaration(D);
44604462
}
44614463

44624464
/// Parse a single syntactic declaration and return a list of decl

lib/Sema/TypeCheckAttr.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4956,10 +4956,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
49564956
/// - Stores the attribute in `ASTContext::DerivativeAttrs`.
49574957
///
49584958
/// \returns true on error, false on success.
4959-
static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4960-
DerivativeAttr *attr) {
4959+
static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
49614960
// Note: Implementation must be idempotent because it may be called multiple
49624961
// times for the same attribute.
4962+
Decl *D = attr->getOriginalDeclaration();
4963+
auto &Ctx = D->getASTContext();
49634964
auto &diags = Ctx.Diags;
49644965
// `@derivative` attribute requires experimental differentiable programming
49654966
// to be enabled.
@@ -5372,13 +5373,18 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
53725373
}
53735374

53745375
void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
5375-
if (typeCheckDerivativeAttr(Ctx, D, attr))
5376+
if (typeCheckDerivativeAttr(attr))
53765377
attr->setInvalid();
53775378
}
53785379

53795380
AbstractFunctionDecl *
53805381
DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator,
53815382
DerivativeAttr *attr) const {
5383+
// Try to resolve the original function.
5384+
if (attr->isValid() && attr->OriginalFunction.isNull())
5385+
if (typeCheckDerivativeAttr(attr))
5386+
attr->setInvalid();
5387+
53825388
// If the typechecker has resolved the original function, return it.
53835389
if (auto *FD = attr->OriginalFunction.dyn_cast<AbstractFunctionDecl *>())
53845390
return FD;

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
379379
bool foundExactConfig = false;
380380
Optional<AutoDiffConfig> supersetConfig = None;
381381
for (auto witnessConfig :
382-
witnessAFD->getDerivativeFunctionConfigurations()) {
382+
witnessAFD->getDerivativeFunctionConfigurations(
383+
/*lookInNonPrimarySources*/ false)) {
383384
// All the witness's derivative generic requirements must be satisfied
384385
// by the requirement's derivative generic requirements OR by the
385386
// conditional conformance requirements.

lib/Serialization/Deserialization.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "ModuleFile.h"
1616
#include "ModuleFormat.h"
1717
#include "swift/AST/ASTContext.h"
18+
#include "swift/AST/Attr.h"
1819
#include "swift/AST/AutoDiff.h"
1920
#include "swift/AST/DiagnosticsSema.h"
2021
#include "swift/AST/Expr.h"
@@ -2590,6 +2591,10 @@ static void setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes(
25902591
diffAttr->setOriginalDeclaration(decl);
25912592
diffAttr->setParameterIndices(diffAttrParamIndicesMap[diffAttr]);
25922593
}
2594+
for (auto *attr : tempAttrs.getAttributes<DerivativeAttr>()) {
2595+
auto *derAttr = const_cast<DerivativeAttr *>(attr);
2596+
derAttr->setOriginalDeclaration(decl);
2597+
}
25932598
}
25942599

25952600
Decl *ModuleFile::getDecl(DeclID DID) {

lib/Serialization/Serialization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2780,7 +2780,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
27802780
auto abbrCode = S.DeclTypeAbbrCodes[DerivativeDeclAttrLayout::Code];
27812781
auto *attr = cast<DerivativeAttr>(DA);
27822782
auto &ctx = S.getASTContext();
2783-
assert(attr->getOriginalFunction(ctx) &&
2783+
assert(attr->getOriginalFunction(ctx) && attr->getOriginalDeclaration() &&
27842784
"`@derivative` attribute should have original declaration set "
27852785
"during construction or parsing");
27862786
auto origDeclNameRef = attr->getOriginalFunctionName();

test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@ func crossFileDifferentiableAttr<T: Protocol>(
1313
}
1414

1515
// TF-1272: Test original function with registered derivatives in other files.
16-
// FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other
17-
// files.
1816
@differentiable(reverse)
1917
func crossFileDerivativeAttr<T: Protocol>(
2018
_ input: T
2119
) -> T {
22-
// expected-error @+2 {{expression is not differentiable}}
23-
// expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}}
20+
// No error expected
2421
return input.identityDerivativeAttr()
2522
}
2623

0 commit comments

Comments
 (0)