Skip to content

Commit ffa5f81

Browse files
committed
[AutoDiff] Support custom derivatives for @_alwaysEmitIntoClient functions
Fixes #54445
1 parent e8fe5c3 commit ffa5f81

30 files changed

+467
-31
lines changed

include/swift/AST/DiagnosticsSema.def

+3
Original file line numberDiff line numberDiff line change
@@ -4299,6 +4299,9 @@ NOTE(derivative_attr_fix_access,none,
42994299
"mark the derivative function as "
43004300
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
43014301
"to match the original function", (AccessLevel))
4302+
ERROR(derivative_attr_always_emit_into_client_mismatch,none,
4303+
"either both or none of derivative and original function must have "
4304+
"@alwaysEmitIntoClient attribute", ())
43024305
ERROR(derivative_attr_static_method_mismatch_original,none,
43034306
"unexpected derivative function declaration; "
43044307
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",

lib/SIL/IR/Linker.cpp

+17-3
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,23 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
159159
// HiddenExternal linkage when they are declarations, then they
160160
// become Shared after the body has been deserialized.
161161
// So try deserializing HiddenExternal functions too.
162-
if (linkage == SILLinkage::HiddenExternal)
163-
return deserializeAndPushToWorklist(F);
164-
162+
if (linkage == SILLinkage::HiddenExternal) {
163+
deserializeAndPushToWorklist(F);
164+
if (!F->markedAsAlwaysEmitIntoClient())
165+
return;
166+
// For @_alwaysEmitIntoClient functions, we need to lookup its
167+
// differentiability witness and, if present, ask SILLoader to obtain its
168+
// definition. Otherwise, a linker error would occur due to undefined
169+
// reference to these symbols.
170+
for (SILDifferentiabilityWitness *witness :
171+
F->getModule().lookUpDifferentiabilityWitnessesForFunction(
172+
F->getName())) {
173+
F->getModule().getSILLoader()->lookupDifferentiabilityWitness(
174+
witness->getKey());
175+
}
176+
return;
177+
}
178+
165179
// Update the linkage of the function in case it's different in the serialized
166180
// SIL than derived from the AST. This can be the case with cross-module-
167181
// optimizations.

lib/SILGen/SILGen.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -1400,14 +1400,19 @@ void SILGenModule::emitDifferentiabilityWitness(
14001400
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
14011401
if (!diffWitness) {
14021402
// Differentiability witnesses have the same linkage as the original
1403-
// function, stripping external.
1404-
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
1403+
// function, stripping external. For @_alwaysEmitIntoClient original
1404+
// functions, force PublicNonABI linkage of the differentiability witness so
1405+
// we can serialize it (the original function itself might be HiddenExternal
1406+
// in this case if we only have declaration without definition).
1407+
auto linkage =
1408+
originalFunction->markedAsAlwaysEmitIntoClient()
1409+
? SILLinkage::PublicNonABI
1410+
: stripExternalFromLinkage(originalFunction->getLinkage());
14051411
diffWitness = SILDifferentiabilityWitness::createDefinition(
14061412
M, linkage, originalFunction, diffKind, silConfig.parameterIndices,
14071413
silConfig.resultIndices, config.derivativeGenericSignature,
14081414
/*jvp*/ nullptr, /*vjp*/ nullptr,
1409-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
1410-
attr);
1415+
/*isSerialized*/ hasPublicVisibility(linkage), attr);
14111416
}
14121417

14131418
// Set derivative function in differentiability witness.

lib/SILGen/SILGenPoly.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -6281,8 +6281,14 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
62816281
auto loc = customDerivativeFn->getLocation();
62826282
SILGenFunctionBuilder fb(*this);
62836283
// Derivative thunks have the same linkage as the original function, stripping
6284-
// external.
6285-
auto linkage = stripExternalFromLinkage(originalFn->getLinkage());
6284+
// external. For @_alwaysEmitIntoClient original functions, force PublicNonABI
6285+
// linkage of derivative thunks so we can serialize them (the original
6286+
// function itself might be HiddenExternal in this case if we only have
6287+
// declaration without definition).
6288+
auto linkage = originalFn->markedAsAlwaysEmitIntoClient()
6289+
? SILLinkage::PublicNonABI
6290+
: stripExternalFromLinkage(originalFn->getLinkage());
6291+
62866292
auto *thunk = fb.getOrCreateFunction(
62876293
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent,
62886294
customDerivativeFn->getSerializedKind(),

lib/SILOptimizer/Differentiation/Common.cpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
538538
"definitions with explicit differentiable attributes");
539539

540540
return SILDifferentiabilityWitness::createDeclaration(
541-
module, SILLinkage::PublicExternal, original, kind,
542-
minimalConfig->parameterIndices, minimalConfig->resultIndices,
543-
minimalConfig->derivativeGenericSignature);
541+
module,
542+
// Witness for @_alwaysEmitIntoClient original function must be emitted,
543+
// otherwise a linker error would occur due to undefined reference to the
544+
// witness symbol.
545+
original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
546+
: SILLinkage::PublicExternal,
547+
original, kind, minimalConfig->parameterIndices,
548+
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
544549
}
545550

546551
} // end namespace autodiff

lib/SILOptimizer/Mandatory/Differentiation.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -911,10 +911,14 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
911911

912912
// We can generate empty JVP / VJP for functions available externally. These
913913
// functions have the same linkage as the original ones sans `external`
914-
// flag. Important exception here hidden_external functions as they are
915-
// serializable but corresponding hidden ones would be not and the SIL
916-
// verifier will fail. Patch `serializeFunctions` for this case.
917-
if (orig->getLinkage() == SILLinkage::HiddenExternal)
914+
// flag. Important exception here hidden_external non-@_alwaysEmitIntoClient
915+
// functions as they are serializable but corresponding hidden ones would be
916+
// not and the SIL verifier will fail. Patch `serializeFunctions` for this
917+
// case. For @_alwaysEmitIntoClient original functions (which might be
918+
// HiddenExternal if we only have declaration without definition), we want
919+
// derivatives to be serialized and do not patch `serializeFunctions`.
920+
if (orig->getLinkage() == SILLinkage::HiddenExternal &&
921+
!orig->markedAsAlwaysEmitIntoClient())
918922
serializeFunctions = IsNotSerialized;
919923

920924
// If the JVP doesn't exist, need to synthesize it.

lib/Sema/TypeCheckAttr.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -6825,6 +6825,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
68256825
return true;
68266826
}
68276827

6828+
if (originalAFD->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>() !=
6829+
derivative->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>()) {
6830+
diags.diagnose(derivative->getLoc(),
6831+
diag::derivative_attr_always_emit_into_client_mismatch);
6832+
return true;
6833+
}
6834+
68286835
// Get the resolved differentiability parameter indices.
68296836
auto *resolvedDiffParamIndices = attr->getParameterIndices();
68306837

stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb

+2-4
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,6 @@ where
405405
}
406406
}
407407

408-
// FIXME(TF-1103): Derivative registration does not yet support
409-
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
410-
/*
411408
extension SIMD
412409
where
413410
Self: Differentiable,
@@ -417,6 +414,7 @@ where
417414
TangentVector == Self
418415
{
419416
@inlinable
417+
@_alwaysEmitIntoClient
420418
@derivative(of: sum)
421419
func _vjpSum() -> (
422420
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
@@ -425,14 +423,14 @@ where
425423
}
426424

427425
@inlinable
426+
@_alwaysEmitIntoClient
428427
@derivative(of: sum)
429428
func _jvpSum() -> (
430429
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
431430
) {
432431
return (sum(), { v in Scalar.TangentVector(v.sum()) })
433432
}
434433
}
435-
*/
436434

437435
extension SIMD
438436
where

test/AutoDiff/SILGen/nil_coalescing.swift

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s
1+
/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions`
2+
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s
23

34
import _Differentiation
45

5-
// CHECK: sil @test_nil_coalescing
6+
// CHECK: sil non_abi @test_nil_coalescing
67
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
78
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
89
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
@@ -15,7 +16,7 @@ import _Differentiation
1516
//
1617
@_silgen_name("test_nil_coalescing")
1718
@derivative(of: ??)
18-
@usableFromInline
19+
@_alwaysEmitIntoClient
1920
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
2021
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
2122
{

test/AutoDiff/Sema/derivative_attr_type_checking.swift

+18
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb
10621062
fatalError()
10631063
}
10641064

1065+
func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1066+
@_alwaysEmitIntoClient
1067+
@derivative(of: internal_original_alwaysemitintoclient_derivative_error)
1068+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1069+
func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1070+
fatalError()
1071+
}
1072+
1073+
@_alwaysEmitIntoClient
10651074
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10661075
@_alwaysEmitIntoClient
10671076
@derivative(of: internal_original_alwaysemitintoclient_derivative)
@@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float
10841093
fatalError()
10851094
}
10861095

1096+
@_alwaysEmitIntoClient
1097+
package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1098+
@derivative(of: package_original_alwaysemitintoclient_derivative_error)
1099+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1100+
package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1101+
fatalError()
1102+
}
1103+
1104+
@_alwaysEmitIntoClient
10871105
package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10881106
@_alwaysEmitIntoClient
10891107
@derivative(of: package_original_alwaysemitintoclient_derivative)

test/AutoDiff/stdlib/simd.swift

-8
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ SIMDTests.test("init(repeating:)") {
1919
expectEqual(8, pb1(g))
2020
}
2121

22-
// FIXME(TF-1103): Derivative registration does not yet support
23-
// `@_alwaysEmitIntoClient` original functions.
24-
/*
2522
SIMDTests.test("Sum") {
2623
let a = SIMD4<Float>(1, 2, 3, 4)
2724

@@ -32,7 +29,6 @@ SIMDTests.test("Sum") {
3229
expectEqual(10, val1)
3330
expectEqual(SIMD4<Float>(3, 3, 3, 3), pb1(3))
3431
}
35-
*/
3632

3733
SIMDTests.test("Identity") {
3834
let a = SIMD4<Float>(1, 2, 3, 4)
@@ -289,9 +285,6 @@ SIMDTests.test("Generics") {
289285
expectEqual(SIMD3<Double>(5, 10, 15), val4)
290286
expectEqual((SIMD3<Double>(5, 5, 5), 6), pb4(g))
291287

292-
// FIXME(TF-1103): Derivative registration does not yet support
293-
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
294-
/*
295288
func testSum<Scalar, SIMDType: SIMD>(x: SIMDType) -> Scalar
296289
where SIMDType.Scalar == Scalar,
297290
SIMDType : Differentiable,
@@ -304,7 +297,6 @@ SIMDTests.test("Generics") {
304297
let (val5, pb5) = valueWithPullback(at: a, of: simd3Sum)
305298
expectEqual(6, val5)
306299
expectEqual(SIMD3<Double>(7, 7, 7), pb5(7))
307-
*/
308300
}
309301

310302
runAllTests()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@_alwaysEmitIntoClient
2+
public func f(_ x: Float) -> Float {
3+
x
4+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import _Differentiation
2+
3+
@derivative(of: f)
4+
@_alwaysEmitIntoClient
5+
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
6+
(x, { 42 * $0 })
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@_alwaysEmitIntoClient
2+
public func f(_ x: Float) -> Float {
3+
x
4+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import MultiModule1
2+
import _Differentiation
3+
4+
@derivative(of: f)
5+
@_alwaysEmitIntoClient
6+
public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
7+
(x, { 42 * $0 })
8+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import _Differentiation
2+
3+
public protocol Protocol {
4+
var x : Float {get set}
5+
init()
6+
}
7+
8+
extension Protocol {
9+
public init(_ val: Float) {
10+
self.init()
11+
x = val
12+
}
13+
14+
@_alwaysEmitIntoClient
15+
public func sum() -> Float { x }
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import MultiModuleProtocol1
2+
import _Differentiation
3+
4+
extension Protocol where Self: Differentiable, Self.TangentVector == Self {
5+
@_alwaysEmitIntoClient
6+
@derivative(of: sum)
7+
public func _vjpSum() -> (
8+
value: Float, pullback: (Float) -> Self.TangentVector
9+
) {
10+
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
11+
}
12+
13+
@_alwaysEmitIntoClient
14+
@derivative(of: sum)
15+
public func _jvpSum() -> (
16+
value: Float, differential: (Self.TangentVector) -> Float
17+
) {
18+
(value: self.x, differential: { 42 * $0.x })
19+
}
20+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import MultiModuleProtocol1
2+
import MultiModuleProtocol2
3+
import _Differentiation
4+
5+
public struct Struct : Protocol {
6+
private var _x : Float
7+
public var x : Float {
8+
get { _x }
9+
set { _x = newValue }
10+
}
11+
public init() { _x = 0 }
12+
}
13+
14+
extension Struct : AdditiveArithmetic {
15+
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
16+
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
17+
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
18+
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
19+
public static var zero: Self { Self(0) }
20+
}
21+
22+
extension Struct : Differentiable {
23+
public typealias TangentVector = Self
24+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
public struct Struct {
2+
public var x : Float
3+
public typealias TangentVector = Self
4+
public init() { x = 0 }
5+
}
6+
7+
extension Struct {
8+
public init(_ val: Float) {
9+
self.init()
10+
x = val
11+
}
12+
13+
@_alwaysEmitIntoClient
14+
public func sum() -> Float { x }
15+
}
16+
17+
extension Struct : AdditiveArithmetic {
18+
public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) }
19+
public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) }
20+
public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x }
21+
public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x }
22+
public static var zero: Self { Self(0) }
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import MultiModuleStruct1
2+
import _Differentiation
3+
4+
extension Struct : Differentiable {
5+
@_alwaysEmitIntoClient
6+
@derivative(of: sum)
7+
public func _vjpSum() -> (
8+
value: Float, pullback: (Float) -> Self.TangentVector
9+
) {
10+
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
11+
}
12+
13+
@_alwaysEmitIntoClient
14+
@derivative(of: sum)
15+
public func _jvpSum() -> (
16+
value: Float, differential: (Self.TangentVector) -> Float
17+
) {
18+
(value: self.x, differential: { 42 * $0.x })
19+
}
20+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import MultiModuleStruct1
2+
import _Differentiation
3+
4+
extension Struct : Differentiable {
5+
@_alwaysEmitIntoClient
6+
@derivative(of: sum)
7+
public func _vjpSum() -> (
8+
value: Float, pullback: (Float) -> Self.TangentVector
9+
) {
10+
(value: self.x, pullback: { Self.TangentVector(42 * $0) })
11+
}
12+
}

0 commit comments

Comments
 (0)