Skip to content

Commit 8c961fa

Browse files
committed
[AutoDiff] Support custom derivatives for @_alwaysEmitIntoClient functions
TODO: - Multi file module (function and derivative in different files) - Single file module: fix crash if we have function with `@_alwaysEmitIntoClient` and its derivative w/o this attribute Fixes swiftlang#54445
1 parent 09d122a commit 8c961fa

File tree

4 files changed

+62
-6
lines changed

4 files changed

+62
-6
lines changed

lib/SIL/IR/Linker.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,21 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
159159
// HiddenExternal linkage when they are declarations, then they
160160
// become Shared after the body has been deserialized.
161161
// So try deserializing HiddenExternal functions too.
162-
if (linkage == SILLinkage::HiddenExternal)
163-
return deserializeAndPushToWorklist(F);
164-
162+
if (linkage == SILLinkage::HiddenExternal) {
163+
deserializeAndPushToWorklist(F);
164+
if (F->getDeclRef().getLinkage(ForDefinition) != SILLinkage::PublicNonABI)
165+
return;
166+
for (SILDifferentiabilityWitness &witness :
167+
Mod.getDifferentiabilityWitnesses()) {
168+
if (witness.getOriginalFunction() != F)
169+
continue;
170+
Mod.getSILLoader()->lookupDifferentiabilityWitness(witness.getKey());
171+
deserializeAndPushToWorklist(witness.getJVP());
172+
deserializeAndPushToWorklist(witness.getVJP());
173+
}
174+
return;
175+
}
176+
165177
// Update the linkage of the function in case it's different in the serialized
166178
// SIL than derived from the AST. This can be the case with cross-module-
167179
// optimizations.

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,15 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
537537
"SILGen should create differentiability witnesses for all function "
538538
"definitions with explicit differentiable attributes");
539539

540+
bool isOrigAlwaysEmitIntoClient =
541+
original->getDeclRef().getLinkage(ForDefinition) ==
542+
SILLinkage::PublicNonABI;
540543
return SILDifferentiabilityWitness::createDeclaration(
541-
module, SILLinkage::PublicExternal, original, kind,
542-
minimalConfig->parameterIndices, minimalConfig->resultIndices,
543-
minimalConfig->derivativeGenericSignature);
544+
module,
545+
isOrigAlwaysEmitIntoClient ? SILLinkage::PublicNonABI
546+
: SILLinkage::PublicExternal,
547+
original, kind, minimalConfig->parameterIndices,
548+
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
544549
}
545550

546551
} // end namespace autodiff
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import _Differentiation
2+
3+
@_alwaysEmitIntoClient
4+
public func f(_ x: Float) -> Float {
5+
x
6+
}
7+
8+
@derivative(of: f)
9+
@_alwaysEmitIntoClient
10+
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
11+
(x, { 42 * $0 })
12+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// REQUIRES: executable_test
2+
// RUN: %empty-directory(%t)
3+
4+
// RUN: %target-build-swift-dylib(%t/%target-library-name(SingleFileModule)) %S/Inputs/always_emit_into_client/SingleFileModule/file.swift \
5+
// RUN: -emit-module -emit-module-path %t/SingleFileModule.swiftmodule -module-name SingleFileModule
6+
// RUN: %target-build-swift -I%t -L%t %s -lm -lSingleFileModule -o %t/a.out %target-rpath(%t)
7+
// RUN: %target-run %t/a.out
8+
9+
// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s
10+
11+
import StdlibUnittest
12+
13+
import SingleFileModule
14+
15+
import _Differentiation
16+
17+
var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient")
18+
19+
AlwaysEmitIntoClientTests.test("registration") {
20+
expectEqual(42, gradient(at: 0, of: f))
21+
expectEqual(42, gradient(at: 1, of: f))
22+
expectEqual(42, gradient(at: 2, of: f))
23+
}
24+
25+
runAllTests()
26+
27+
// CHECK: @"16SingleFileModule1fyS2fFWJrSpSr" = weak_odr hidden global { ptr, ptr } { ptr @"$s16SingleFileModule1fyS2fFTJfSpSr", ptr @"$s16SingleFileModule1fyS2fFTJrSpSr" }, align 8

0 commit comments

Comments
 (0)