diff --git a/include/swift/AST/SourceFile.h b/include/swift/AST/SourceFile.h index 9a0cce71f24d5..5ad7038a7661d 100644 --- a/include/swift/AST/SourceFile.h +++ b/include/swift/AST/SourceFile.h @@ -14,11 +14,13 @@ #define SWIFT_AST_SOURCEFILE_H #include "swift/AST/FileUnit.h" +#include "swift/AST/SynthesizedFileUnit.h" #include "swift/Basic/Debug.h" namespace swift { class PersistentParserState; +class SynthesizedFileUnit; /// A file containing Swift source code. /// @@ -135,6 +137,9 @@ class SourceFile final : public FileUnit { /// same module. mutable Identifier PrivateDiscriminator; + /// A synthesized file corresponding to this file, created on-demand. + SynthesizedFileUnit *SynthesizedFile = nullptr; + /// The root TypeRefinementContext for this SourceFile. /// /// This is set during type checking. @@ -447,6 +452,11 @@ class SourceFile final : public FileUnit { Identifier getPrivateDiscriminator() const { return PrivateDiscriminator; } Optional getBasicLocsForDecl(const Decl *D) const override; + /// Returns the synthesized file for this source file, if it exists. + SynthesizedFileUnit *getSynthesizedFile() const { return SynthesizedFile; }; + + SynthesizedFileUnit &getOrCreateSynthesizedFile(); + virtual bool walk(ASTWalker &walker) override; ReferencedNameTracker *getLegacyReferencedNameTracker() { diff --git a/include/swift/AST/SynthesizedFileUnit.h b/include/swift/AST/SynthesizedFileUnit.h index 747dc04d1f322..88e524a4419dd 100644 --- a/include/swift/AST/SynthesizedFileUnit.h +++ b/include/swift/AST/SynthesizedFileUnit.h @@ -18,8 +18,15 @@ namespace swift { -/// A container for synthesized module-level declarations. +class SourceFile; + +/// A container for synthesized declarations, attached to a `SourceFile`. +/// +/// Currently, only module-level synthesized declarations are supported. class SynthesizedFileUnit final : public FileUnit { + /// The parent source file. + SourceFile &SF; + /// Synthesized top level declarations. TinyPtrVector TopLevelDecls; @@ -29,9 +36,12 @@ class SynthesizedFileUnit final : public FileUnit { mutable Identifier PrivateDiscriminator; public: - SynthesizedFileUnit(ModuleDecl &M); + SynthesizedFileUnit(SourceFile &SF); ~SynthesizedFileUnit() = default; + /// Returns the parent source file. + SourceFile &getSourceFile() const { return SF; } + /// Add a synthesized top-level declaration. void addTopLevelDecl(ValueDecl *D) { TopLevelDecls.push_back(D); } diff --git a/include/swift/SILOptimizer/Utils/Differentiation/ADContext.h b/include/swift/SILOptimizer/Utils/Differentiation/ADContext.h index b2f84a4e02b52..96e7b97cf9674 100644 --- a/include/swift/SILOptimizer/Utils/Differentiation/ADContext.h +++ b/include/swift/SILOptimizer/Utils/Differentiation/ADContext.h @@ -67,9 +67,6 @@ class ADContext { /// Shared pass manager. SILPassManager &passManager; - /// A synthesized file unit. - SynthesizedFileUnit *synthesizedFile = nullptr; - /// The worklist (stack) of `differentiable_function` instructions to be /// processed. llvm::SmallVector @@ -126,9 +123,10 @@ class ADContext { SILPassManager &getPassManager() const { return passManager; } Lowering::TypeConverter &getTypeConverter() { return module.Types; } - /// Get or create a synthesized file for adding generated linear map structs - /// and branching trace enums. Used by `LinearMapInfo`. - SynthesizedFileUnit &getOrCreateSynthesizedFile(); + /// Get or create the synthesized file for the given `SILFunction`. + /// Used by `LinearMapInfo` for adding generated linear map struct and + /// branching trace enum declarations. + SynthesizedFileUnit &getOrCreateSynthesizedFile(SILFunction *original); /// Returns true if the `differentiable_function` instruction worklist is /// empty. diff --git a/lib/AST/Module.cpp b/lib/AST/Module.cpp index a3bf9c37aa769..a2616d0a285db 100644 --- a/lib/AST/Module.cpp +++ b/lib/AST/Module.cpp @@ -2569,7 +2569,6 @@ ASTScope &SourceFile::getScope() { return *Scope.get(); } - Identifier SourceFile::getDiscriminatorForPrivateValue(const ValueDecl *D) const { assert(D->getDeclContext()->getModuleScopeContext() == this); @@ -2611,6 +2610,14 @@ SourceFile::getDiscriminatorForPrivateValue(const ValueDecl *D) const { return PrivateDiscriminator; } +SynthesizedFileUnit &SourceFile::getOrCreateSynthesizedFile() { + if (SynthesizedFile) + return *SynthesizedFile; + SynthesizedFile = new (getASTContext()) SynthesizedFileUnit(*this); + getParentModule()->addFile(*SynthesizedFile); + return *SynthesizedFile; +} + TypeRefinementContext *SourceFile::getTypeRefinementContext() { return TRC; } @@ -2674,9 +2681,9 @@ SourceFile::lookupOpaqueResultType(StringRef MangledName) { // SynthesizedFileUnit Implementation //===----------------------------------------------------------------------===// -SynthesizedFileUnit::SynthesizedFileUnit(ModuleDecl &M) - : FileUnit(FileUnitKind::Synthesized, M) { - M.getASTContext().addDestructorCleanup(*this); +SynthesizedFileUnit::SynthesizedFileUnit(SourceFile &SF) + : FileUnit(FileUnitKind::Synthesized, *SF.getParentModule()), SF(SF) { + SF.getASTContext().addDestructorCleanup(*this); } Identifier diff --git a/lib/SILOptimizer/Utils/Differentiation/ADContext.cpp b/lib/SILOptimizer/Utils/Differentiation/ADContext.cpp index ad82e1520d19f..cc1b616ecd947 100644 --- a/lib/SILOptimizer/Utils/Differentiation/ADContext.cpp +++ b/lib/SILOptimizer/Utils/Differentiation/ADContext.cpp @@ -58,14 +58,22 @@ ADContext::ADContext(SILModuleTransform &transform) : transform(transform), module(*transform.getModule()), passManager(*transform.getPassManager()) {} -SynthesizedFileUnit &ADContext::getOrCreateSynthesizedFile() { - if (synthesizedFile) - return *synthesizedFile; - auto *moduleDecl = module.getSwiftModule(); - auto &ctx = moduleDecl->getASTContext(); - synthesizedFile = new (ctx) SynthesizedFileUnit(*moduleDecl); - moduleDecl->addFile(*synthesizedFile); - return *synthesizedFile; +/// Get the source file for the given `SILFunction`. +static SourceFile &getSourceFile(SILFunction *f) { + if (f->hasLocation()) + if (auto *declContext = f->getLocation().getAsDeclContext()) + if (auto *parentSourceFile = declContext->getParentSourceFile()) + return *parentSourceFile; + for (auto *file : f->getModule().getSwiftModule()->getFiles()) + if (auto *sourceFile = dyn_cast(file)) + return *sourceFile; + llvm_unreachable("Could not resolve SourceFile from SILFunction"); +} + +SynthesizedFileUnit & +ADContext::getOrCreateSynthesizedFile(SILFunction *original) { + auto &SF = getSourceFile(original); + return SF.getOrCreateSynthesizedFile(); } FuncDecl *ADContext::getPlusDecl() const { diff --git a/lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp b/lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp index c34ba90de8167..053aa3924f69d 100644 --- a/lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp +++ b/lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp @@ -58,7 +58,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, const DifferentiableActivityInfo &activityInfo) : kind(kind), original(original), derivative(derivative), activityInfo(activityInfo), indices(indices), - synthesizedFile(context.getOrCreateSynthesizedFile()), + synthesizedFile(context.getOrCreateSynthesizedFile(original)), typeConverter(context.getTypeConverter()) { generateDifferentiationDataStructures(context, derivative); } diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 3d08ac479e320..4ea8cb3e7858a 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -32,6 +32,7 @@ #include "swift/AST/ProtocolConformance.h" #include "swift/AST/RawComment.h" #include "swift/AST/SourceFile.h" +#include "swift/AST/SynthesizedFileUnit.h" #include "swift/AST/TypeCheckRequests.h" #include "swift/AST/TypeVisitor.h" #include "swift/Basic/Dwarf.h" @@ -59,8 +60,8 @@ #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/OnDiskHashTable.h" #include "llvm/Support/Path.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/SmallVectorMemoryBuffer.h" +#include "llvm/Support/raw_ostream.h" #include @@ -1903,7 +1904,7 @@ bool Serializer::isDeclXRef(const Decl *D) const { const DeclContext *topLevel = D->getDeclContext()->getModuleScopeContext(); if (topLevel->getParentModule() != M) return true; - if (!SF || topLevel == SF) + if (!SF || topLevel == SF || topLevel == SF->getSynthesizedFile()) return false; // Special-case for SIL generic parameter decls, which don't have a real // DeclContext. @@ -4976,6 +4977,8 @@ void Serializer::writeAST(ModuleOrSourceFile DC) { SmallVector Scratch; if (SF) { Scratch.push_back(SF); + if (auto *synthesizedFile = SF->getSynthesizedFile()) + Scratch.push_back(synthesizedFile); files = llvm::makeArrayRef(Scratch); } else { files = M->getFiles(); diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index 5f0b42360f9ad..8e2577c31c93a 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -23,6 +23,7 @@ #include "swift/AST/Module.h" #include "swift/AST/ParameterList.h" #include "swift/AST/PropertyWrappers.h" +#include "swift/AST/SynthesizedFileUnit.h" #include "swift/AST/TBDGenRequests.h" #include "swift/Basic/LLVM.h" #include "swift/ClangImporter/ClangImporter.h" @@ -1153,6 +1154,10 @@ GenerateTBDRequest::evaluate(Evaluator &evaluator, if (auto *singleFile = desc.getSingleFile()) { assert(M == singleFile->getParentModule() && "mismatched file and module"); visitFile(singleFile); + // Visit synthesized file, if it exists. + if (auto *SF = dyn_cast(singleFile)) + if (auto *synthesizedFile = SF->getSynthesizedFile()) + visitFile(synthesizedFile); } else { llvm::SmallVector Modules; Modules.push_back(M); diff --git a/test/AutoDiff/compiler_crashers_fixed/Inputs/tf1202-differentiability-witness-dead-function-elimination.swift b/test/AutoDiff/compiler_crashers_fixed/Inputs/tf1202-differentiability-witness-dead-function-elimination.swift new file mode 100644 index 0000000000000..252b13e0a9cf6 --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/Inputs/tf1202-differentiability-witness-dead-function-elimination.swift @@ -0,0 +1,9 @@ +import _Differentiation + +@inlinable +@differentiable(where T: Differentiable) +public func identity(_ x: T) -> T { x } + +public func foo(_ f: @differentiable (T) -> T = identity) -> T { + fatalError() +} diff --git a/test/AutoDiff/compiler_crashers_fixed/tf1202-differentiability-witness-dead-function-elimination.swift b/test/AutoDiff/compiler_crashers_fixed/tf1202-differentiability-witness-dead-function-elimination.swift new file mode 100644 index 0000000000000..e0564a9adfe31 --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/tf1202-differentiability-witness-dead-function-elimination.swift @@ -0,0 +1,17 @@ +// RUN: %empty-directory(%t) +// RUN: %target-build-swift -emit-module -module-name tf1202 -emit-module-path %t/tf1202.swiftmodule %S/Inputs/tf1202-differentiability-witness-dead-function-elimination.swift +// RUN: %target-build-swift -I%t -emit-module -O %s + +// TF-1202: test bug where DeadFunctionElimination eliminated the +// SILFunction for `func identity` even though a differentiability witness for it +// exists. This causes deserialization of this module to crash when +// trying to deserialize the differentiability witness because it can't find +// the original function `func identity`. + +// TF-1239: Test `SynthesizedFileUnit` serialization. + +import tf1202 + +func callit() -> Float { + return foo() +} diff --git a/test/AutoDiff/validation-test/Inputs/cross_module_differentiation_other.swift b/test/AutoDiff/validation-test/Inputs/cross_module_differentiation_other.swift new file mode 100644 index 0000000000000..3243bf2dffe73 --- /dev/null +++ b/test/AutoDiff/validation-test/Inputs/cross_module_differentiation_other.swift @@ -0,0 +1,14 @@ +import _Differentiation + +@differentiable +public func defaultArgument(_ x: Float) -> Float { + return x +} + +@differentiable +public func applyArgument( + _ x: Float, + _ f: @differentiable (Float) -> Float = defaultArgument +) -> Float { + return f(x) +} diff --git a/test/AutoDiff/validation-test/cross_module_differentiation.swift b/test/AutoDiff/validation-test/cross_module_differentiation.swift new file mode 100644 index 0000000000000..4c191829148ee --- /dev/null +++ b/test/AutoDiff/validation-test/cross_module_differentiation.swift @@ -0,0 +1,28 @@ +// RUN: %empty-directory(%t) +// RUN: %target-build-swift -working-directory %t -parse-as-library -emit-module -module-name cross_module_differentiation_other -emit-module-path %t/cross_module_differentiation_other.swiftmodule -emit-library -static %S/Inputs/cross_module_differentiation_other.swift +// RUN: %target-build-swift -I%t -L%t %s -o %t/a.out -lcross_module_differentiation_other +// RUN: %target-run %t/a.out +// REQUIRES: executable_test + +// TF-1025: Test differentiability witness linkage for `PublicNonABI` original functions. +// TF-1239: Test `SynthesizedFileUnit` TBDGen. + +import cross_module_differentiation_other +import _Differentiation +import StdlibUnittest + +var CrossModuleTests = TestSuite("E2ECrossModule") + +CrossModuleTests.test("differentiable function default argument") { + let actualGrad = gradient(at: 0) { applyArgument($0) } + let expectedGrad: Float = 1 + expectEqual(actualGrad, expectedGrad) +} + +CrossModuleTests.test("differentiable function specified default argument") { + let actualGrad = gradient(at: 0) { applyArgument($0, { $0 }) } + let expectedGrad: Float = 1 + expectEqual(actualGrad, expectedGrad) +} + +runAllTests()