diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 1a49d388c2267..6bfcd68a38d91 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1938,6 +1938,10 @@ class DerivativeAttr final friend TrailingObjects; friend class DerivativeAttrOriginalDeclRequest; + /// The declaration on which the `@derivative` attribute is declared. + /// May not be a valid declaration for `@derivative` attributes. + /// Resolved during parsing and deserialization. + Decl *OriginalDeclaration = nullptr; /// The base type for the referenced original declaration. This field is /// non-null only for parsed attributes that reference a qualified original /// declaration. This field is not serialized; type-checking uses it to @@ -1991,6 +1995,12 @@ class DerivativeAttr final DeclNameRefWithLoc original, IndexSubset *parameterIndices); + Decl *getOriginalDeclaration() const { return OriginalDeclaration; } + + /// Sets the original declaration on which this attribute is declared. + /// Should only be used by parsing and deserialization. + void setOriginalDeclaration(Decl *originalDeclaration); + TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; } DeclNameRefWithLoc getOriginalFunctionName() const { return OriginalFunctionName; diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index e86fff132de50..573a640eb2164 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -6265,8 +6265,11 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr; public: - /// Get all derivative function configurations. - ArrayRef getDerivativeFunctionConfigurations(); + /// Get all derivative function configurations. If `lookInNonPrimarySources` + /// is true then lookup is done in non-primary sources as well. Note that + /// such lookup might end in cycles if done during sema stages. + ArrayRef + getDerivativeFunctionConfigurations(bool lookInNonPrimarySources = true); /// Add the given derivative function configuration. void addDerivativeFunctionConfiguration(const AutoDiffConfig &config); diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index 427e94c4336fb..8b98bb33fdc4e 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -2107,6 +2107,13 @@ void DerivativeAttr::setOriginalFunctionResolver( ResolverContextData = resolverContextData; } +void DerivativeAttr::setOriginalDeclaration(Decl *originalDeclaration) { + assert(originalDeclaration && "Original declaration must be non-null"); + assert(!OriginalDeclaration && + "Original declaration cannot have already been set"); + OriginalDeclaration = originalDeclaration; +} + TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, TypeRepr *baseTypeRepr, DeclNameRefWithLoc originalName, diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 81cd76ae1b5dc..2300941ae3b3a 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -20,6 +20,7 @@ #include "swift/AST/ASTContext.h" #include "swift/AST/ASTWalker.h" #include "swift/AST/ASTMangler.h" +#include "swift/AST/Attr.h" #include "swift/AST/CaptureInfo.h" #include "swift/AST/DiagnosticEngine.h" #include "swift/AST/DiagnosticsSema.h" @@ -8310,7 +8311,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() { } ArrayRef -AbstractFunctionDecl::getDerivativeFunctionConfigurations() { +AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimarySources) { prepareDerivativeFunctionConfigurations(); // Resolve derivative function configurations from `@differentiable` @@ -8333,6 +8334,37 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() { ctx.loadDerivativeFunctionConfigurations(this, previousGeneration, *DerivativeFunctionConfigs); } + + class DerivativeFinder : public ASTWalker { + const AbstractFunctionDecl *AFD; + public: + DerivativeFinder(const AbstractFunctionDecl *afd) : AFD(afd) {} + + bool walkToDeclPre(Decl *D) override { + if (auto *afd = dyn_cast(D)) { + for (auto *derAttr : afd->getAttrs().getAttributes()) { + // Resolve derivative function configurations from `@derivative` + // attributes by type-checking them. + if (AFD->getName().matchesRef( + derAttr->getOriginalFunctionName().Name.getFullName())) { + (void)derAttr->getOriginalFunction(afd->getASTContext()); + return false; + } + } + } + + return true; + } + }; + + // Load derivative configurations from @derivative attributes defined in + // non-primary sources. Note that it might trigger lookup cycles if called + // from inside Sema stages. + if (lookInNonPrimarySources) { + DerivativeFinder finder(this); + getParent()->walkContext(finder); + } + return DerivativeFunctionConfigs->getArrayRef(); } diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 1b78f4e7ca892..547a06d9f9f5d 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -4455,6 +4455,8 @@ setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs, Decl *D) { for (auto *attr : attrs.getAttributes()) const_cast(attr)->setOriginalDeclaration(D); + for (auto *attr : attrs.getAttributes()) + const_cast(attr)->setOriginalDeclaration(D); } /// Parse a single syntactic declaration and return a list of decl diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 109d911f1b7ef..1c252905c65ba 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -4949,10 +4949,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { /// - Stores the attribute in `ASTContext::DerivativeAttrs`. /// /// \returns true on error, false on success. -static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, - DerivativeAttr *attr) { +static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { // Note: Implementation must be idempotent because it may be called multiple // times for the same attribute. + Decl *D = attr->getOriginalDeclaration(); + auto &Ctx = D->getASTContext(); auto &diags = Ctx.Diags; // `@derivative` attribute requires experimental differentiable programming // to be enabled. @@ -5365,13 +5366,18 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, } void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) { - if (typeCheckDerivativeAttr(Ctx, D, attr)) + if (typeCheckDerivativeAttr(attr)) attr->setInvalid(); } AbstractFunctionDecl * DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator, DerivativeAttr *attr) const { + // Try to resolve the original function. + if (attr->isValid() && attr->OriginalFunction.isNull()) + if (typeCheckDerivativeAttr(attr)) + attr->setInvalid(); + // If the typechecker has resolved the original function, return it. if (auto *FD = attr->OriginalFunction.dyn_cast()) return FD; diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index bddbb877d9217..6c3071a8db904 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -379,7 +379,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, bool foundExactConfig = false; Optional supersetConfig = None; for (auto witnessConfig : - witnessAFD->getDerivativeFunctionConfigurations()) { + witnessAFD->getDerivativeFunctionConfigurations( + /*lookInNonPrimarySources*/ false)) { // All the witness's derivative generic requirements must be satisfied // by the requirement's derivative generic requirements OR by the // conditional conformance requirements. diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 9a6068e02478f..9d6c3dbf0de2c 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -15,6 +15,7 @@ #include "ModuleFile.h" #include "ModuleFormat.h" #include "swift/AST/ASTContext.h" +#include "swift/AST/Attr.h" #include "swift/AST/AutoDiff.h" #include "swift/AST/DiagnosticsSema.h" #include "swift/AST/Expr.h" @@ -2590,6 +2591,10 @@ static void setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes( diffAttr->setOriginalDeclaration(decl); diffAttr->setParameterIndices(diffAttrParamIndicesMap[diffAttr]); } + for (auto *attr : tempAttrs.getAttributes()) { + auto *derAttr = const_cast(attr); + derAttr->setOriginalDeclaration(decl); + } } Decl *ModuleFile::getDecl(DeclID DID) { diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index c4ab6747b39e3..25aba2a6c18bc 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2780,7 +2780,7 @@ class Serializer::DeclSerializer : public DeclVisitor { auto abbrCode = S.DeclTypeAbbrCodes[DerivativeDeclAttrLayout::Code]; auto *attr = cast(DA); auto &ctx = S.getASTContext(); - assert(attr->getOriginalFunction(ctx) && + assert(attr->getOriginalFunction(ctx) && attr->getOriginalDeclaration() && "`@derivative` attribute should have original declaration set " "during construction or parsing"); auto origDeclNameRef = attr->getOriginalFunctionName(); diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift index 54e13f9c0d918..539e6f1c34d2e 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift @@ -13,14 +13,11 @@ func crossFileDifferentiableAttr( } // TF-1272: Test original function with registered derivatives in other files. -// FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other -// files. @differentiable(reverse) func crossFileDerivativeAttr( _ input: T ) -> T { - // expected-error @+2 {{expression is not differentiable}} - // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}} + // No error expected return input.identityDerivativeAttr() }