Skip to content

Commit 3999b8b

Browse files
authored
Merge pull request #30912 from dan-zheng/fix-synthesized-file-unit
Fix SynthesizedFileUnit serialization and TBDGen issues.
2 parents 83e176b + c142fcf commit 3999b8b

File tree

12 files changed

+132
-23
lines changed

12 files changed

+132
-23
lines changed

include/swift/AST/SourceFile.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
#define SWIFT_AST_SOURCEFILE_H
1515

1616
#include "swift/AST/FileUnit.h"
17+
#include "swift/AST/SynthesizedFileUnit.h"
1718
#include "swift/Basic/Debug.h"
1819

1920
namespace swift {
2021

2122
class PersistentParserState;
23+
class SynthesizedFileUnit;
2224

2325
/// A file containing Swift source code.
2426
///
@@ -135,6 +137,9 @@ class SourceFile final : public FileUnit {
135137
/// same module.
136138
mutable Identifier PrivateDiscriminator;
137139

140+
/// A synthesized file corresponding to this file, created on-demand.
141+
SynthesizedFileUnit *SynthesizedFile = nullptr;
142+
138143
/// The root TypeRefinementContext for this SourceFile.
139144
///
140145
/// This is set during type checking.
@@ -447,6 +452,11 @@ class SourceFile final : public FileUnit {
447452
Identifier getPrivateDiscriminator() const { return PrivateDiscriminator; }
448453
Optional<BasicDeclLocs> getBasicLocsForDecl(const Decl *D) const override;
449454

455+
/// Returns the synthesized file for this source file, if it exists.
456+
SynthesizedFileUnit *getSynthesizedFile() const { return SynthesizedFile; };
457+
458+
SynthesizedFileUnit &getOrCreateSynthesizedFile();
459+
450460
virtual bool walk(ASTWalker &walker) override;
451461

452462
ReferencedNameTracker *getLegacyReferencedNameTracker() {

include/swift/AST/SynthesizedFileUnit.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,15 @@
1818

1919
namespace swift {
2020

21-
/// A container for synthesized module-level declarations.
21+
class SourceFile;
22+
23+
/// A container for synthesized declarations, attached to a `SourceFile`.
24+
///
25+
/// Currently, only module-level synthesized declarations are supported.
2226
class SynthesizedFileUnit final : public FileUnit {
27+
/// The parent source file.
28+
SourceFile &SF;
29+
2330
/// Synthesized top level declarations.
2431
TinyPtrVector<ValueDecl *> TopLevelDecls;
2532

@@ -29,9 +36,12 @@ class SynthesizedFileUnit final : public FileUnit {
2936
mutable Identifier PrivateDiscriminator;
3037

3138
public:
32-
SynthesizedFileUnit(ModuleDecl &M);
39+
SynthesizedFileUnit(SourceFile &SF);
3340
~SynthesizedFileUnit() = default;
3441

42+
/// Returns the parent source file.
43+
SourceFile &getSourceFile() const { return SF; }
44+
3545
/// Add a synthesized top-level declaration.
3646
void addTopLevelDecl(ValueDecl *D) { TopLevelDecls.push_back(D); }
3747

include/swift/SILOptimizer/Utils/Differentiation/ADContext.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ class ADContext {
6767
/// Shared pass manager.
6868
SILPassManager &passManager;
6969

70-
/// A synthesized file unit.
71-
SynthesizedFileUnit *synthesizedFile = nullptr;
72-
7370
/// The worklist (stack) of `differentiable_function` instructions to be
7471
/// processed.
7572
llvm::SmallVector<DifferentiableFunctionInst *, 32>
@@ -126,9 +123,10 @@ class ADContext {
126123
SILPassManager &getPassManager() const { return passManager; }
127124
Lowering::TypeConverter &getTypeConverter() { return module.Types; }
128125

129-
/// Get or create a synthesized file for adding generated linear map structs
130-
/// and branching trace enums. Used by `LinearMapInfo`.
131-
SynthesizedFileUnit &getOrCreateSynthesizedFile();
126+
/// Get or create the synthesized file for the given `SILFunction`.
127+
/// Used by `LinearMapInfo` for adding generated linear map struct and
128+
/// branching trace enum declarations.
129+
SynthesizedFileUnit &getOrCreateSynthesizedFile(SILFunction *original);
132130

133131
/// Returns true if the `differentiable_function` instruction worklist is
134132
/// empty.

lib/AST/Module.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2569,7 +2569,6 @@ ASTScope &SourceFile::getScope() {
25692569
return *Scope.get();
25702570
}
25712571

2572-
25732572
Identifier
25742573
SourceFile::getDiscriminatorForPrivateValue(const ValueDecl *D) const {
25752574
assert(D->getDeclContext()->getModuleScopeContext() == this);
@@ -2611,6 +2610,14 @@ SourceFile::getDiscriminatorForPrivateValue(const ValueDecl *D) const {
26112610
return PrivateDiscriminator;
26122611
}
26132612

2613+
SynthesizedFileUnit &SourceFile::getOrCreateSynthesizedFile() {
2614+
if (SynthesizedFile)
2615+
return *SynthesizedFile;
2616+
SynthesizedFile = new (getASTContext()) SynthesizedFileUnit(*this);
2617+
getParentModule()->addFile(*SynthesizedFile);
2618+
return *SynthesizedFile;
2619+
}
2620+
26142621
TypeRefinementContext *SourceFile::getTypeRefinementContext() {
26152622
return TRC;
26162623
}
@@ -2674,9 +2681,9 @@ SourceFile::lookupOpaqueResultType(StringRef MangledName) {
26742681
// SynthesizedFileUnit Implementation
26752682
//===----------------------------------------------------------------------===//
26762683

2677-
SynthesizedFileUnit::SynthesizedFileUnit(ModuleDecl &M)
2678-
: FileUnit(FileUnitKind::Synthesized, M) {
2679-
M.getASTContext().addDestructorCleanup(*this);
2684+
SynthesizedFileUnit::SynthesizedFileUnit(SourceFile &SF)
2685+
: FileUnit(FileUnitKind::Synthesized, *SF.getParentModule()), SF(SF) {
2686+
SF.getASTContext().addDestructorCleanup(*this);
26802687
}
26812688

26822689
Identifier

lib/SILOptimizer/Utils/Differentiation/ADContext.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,22 @@ ADContext::ADContext(SILModuleTransform &transform)
5858
: transform(transform), module(*transform.getModule()),
5959
passManager(*transform.getPassManager()) {}
6060

61-
SynthesizedFileUnit &ADContext::getOrCreateSynthesizedFile() {
62-
if (synthesizedFile)
63-
return *synthesizedFile;
64-
auto *moduleDecl = module.getSwiftModule();
65-
auto &ctx = moduleDecl->getASTContext();
66-
synthesizedFile = new (ctx) SynthesizedFileUnit(*moduleDecl);
67-
moduleDecl->addFile(*synthesizedFile);
68-
return *synthesizedFile;
61+
/// Get the source file for the given `SILFunction`.
62+
static SourceFile &getSourceFile(SILFunction *f) {
63+
if (f->hasLocation())
64+
if (auto *declContext = f->getLocation().getAsDeclContext())
65+
if (auto *parentSourceFile = declContext->getParentSourceFile())
66+
return *parentSourceFile;
67+
for (auto *file : f->getModule().getSwiftModule()->getFiles())
68+
if (auto *sourceFile = dyn_cast<SourceFile>(file))
69+
return *sourceFile;
70+
llvm_unreachable("Could not resolve SourceFile from SILFunction");
71+
}
72+
73+
SynthesizedFileUnit &
74+
ADContext::getOrCreateSynthesizedFile(SILFunction *original) {
75+
auto &SF = getSourceFile(original);
76+
return SF.getOrCreateSynthesizedFile();
6977
}
7078

7179
FuncDecl *ADContext::getPlusDecl() const {

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
5858
const DifferentiableActivityInfo &activityInfo)
5959
: kind(kind), original(original), derivative(derivative),
6060
activityInfo(activityInfo), indices(indices),
61-
synthesizedFile(context.getOrCreateSynthesizedFile()),
61+
synthesizedFile(context.getOrCreateSynthesizedFile(original)),
6262
typeConverter(context.getTypeConverter()) {
6363
generateDifferentiationDataStructures(context, derivative);
6464
}

lib/Serialization/Serialization.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "swift/AST/ProtocolConformance.h"
3333
#include "swift/AST/RawComment.h"
3434
#include "swift/AST/SourceFile.h"
35+
#include "swift/AST/SynthesizedFileUnit.h"
3536
#include "swift/AST/TypeCheckRequests.h"
3637
#include "swift/AST/TypeVisitor.h"
3738
#include "swift/Basic/Dwarf.h"
@@ -59,8 +60,8 @@
5960
#include "llvm/Support/MemoryBuffer.h"
6061
#include "llvm/Support/OnDiskHashTable.h"
6162
#include "llvm/Support/Path.h"
62-
#include "llvm/Support/raw_ostream.h"
6363
#include "llvm/Support/SmallVectorMemoryBuffer.h"
64+
#include "llvm/Support/raw_ostream.h"
6465

6566
#include <vector>
6667

@@ -1903,7 +1904,7 @@ bool Serializer::isDeclXRef(const Decl *D) const {
19031904
const DeclContext *topLevel = D->getDeclContext()->getModuleScopeContext();
19041905
if (topLevel->getParentModule() != M)
19051906
return true;
1906-
if (!SF || topLevel == SF)
1907+
if (!SF || topLevel == SF || topLevel == SF->getSynthesizedFile())
19071908
return false;
19081909
// Special-case for SIL generic parameter decls, which don't have a real
19091910
// DeclContext.
@@ -4976,6 +4977,8 @@ void Serializer::writeAST(ModuleOrSourceFile DC) {
49764977
SmallVector<const FileUnit *, 1> Scratch;
49774978
if (SF) {
49784979
Scratch.push_back(SF);
4980+
if (auto *synthesizedFile = SF->getSynthesizedFile())
4981+
Scratch.push_back(synthesizedFile);
49794982
files = llvm::makeArrayRef(Scratch);
49804983
} else {
49814984
files = M->getFiles();

lib/TBDGen/TBDGen.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "swift/AST/Module.h"
2424
#include "swift/AST/ParameterList.h"
2525
#include "swift/AST/PropertyWrappers.h"
26+
#include "swift/AST/SynthesizedFileUnit.h"
2627
#include "swift/AST/TBDGenRequests.h"
2728
#include "swift/Basic/LLVM.h"
2829
#include "swift/ClangImporter/ClangImporter.h"
@@ -1153,6 +1154,10 @@ GenerateTBDRequest::evaluate(Evaluator &evaluator,
11531154
if (auto *singleFile = desc.getSingleFile()) {
11541155
assert(M == singleFile->getParentModule() && "mismatched file and module");
11551156
visitFile(singleFile);
1157+
// Visit synthesized file, if it exists.
1158+
if (auto *SF = dyn_cast<SourceFile>(singleFile))
1159+
if (auto *synthesizedFile = SF->getSynthesizedFile())
1160+
visitFile(synthesizedFile);
11561161
} else {
11571162
llvm::SmallVector<ModuleDecl*, 4> Modules;
11581163
Modules.push_back(M);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import _Differentiation
2+
3+
@inlinable
4+
@differentiable(where T: Differentiable)
5+
public func identity<T>(_ x: T) -> T { x }
6+
7+
public func foo<T: Differentiable>(_ f: @differentiable (T) -> T = identity) -> T {
8+
fatalError()
9+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift -emit-module -module-name tf1202 -emit-module-path %t/tf1202.swiftmodule %S/Inputs/tf1202-differentiability-witness-dead-function-elimination.swift
3+
// RUN: %target-build-swift -I%t -emit-module -O %s
4+
5+
// TF-1202: test bug where DeadFunctionElimination eliminated the
6+
// SILFunction for `func identity<T>` even though a differentiability witness for it
7+
// exists. This causes deserialization of this module to crash when
8+
// trying to deserialize the differentiability witness because it can't find
9+
// the original function `func identity<T>`.
10+
11+
// TF-1239: Test `SynthesizedFileUnit` serialization.
12+
13+
import tf1202
14+
15+
func callit() -> Float {
16+
return foo()
17+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import _Differentiation
2+
3+
@differentiable
4+
public func defaultArgument(_ x: Float) -> Float {
5+
return x
6+
}
7+
8+
@differentiable
9+
public func applyArgument(
10+
_ x: Float,
11+
_ f: @differentiable (Float) -> Float = defaultArgument
12+
) -> Float {
13+
return f(x)
14+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift -working-directory %t -parse-as-library -emit-module -module-name cross_module_differentiation_other -emit-module-path %t/cross_module_differentiation_other.swiftmodule -emit-library -static %S/Inputs/cross_module_differentiation_other.swift
3+
// RUN: %target-build-swift -I%t -L%t %s -o %t/a.out -lcross_module_differentiation_other
4+
// RUN: %target-run %t/a.out
5+
// REQUIRES: executable_test
6+
7+
// TF-1025: Test differentiability witness linkage for `PublicNonABI` original functions.
8+
// TF-1239: Test `SynthesizedFileUnit` TBDGen.
9+
10+
import cross_module_differentiation_other
11+
import _Differentiation
12+
import StdlibUnittest
13+
14+
var CrossModuleTests = TestSuite("E2ECrossModule")
15+
16+
CrossModuleTests.test("differentiable function default argument") {
17+
let actualGrad = gradient(at: 0) { applyArgument($0) }
18+
let expectedGrad: Float = 1
19+
expectEqual(actualGrad, expectedGrad)
20+
}
21+
22+
CrossModuleTests.test("differentiable function specified default argument") {
23+
let actualGrad = gradient(at: 0) { applyArgument($0, { $0 }) }
24+
let expectedGrad: Float = 1
25+
expectEqual(actualGrad, expectedGrad)
26+
}
27+
28+
runAllTests()

0 commit comments

Comments
 (0)