Skip to content

Commit 6bb9def

Browse files
committed
[AutoDiff] Support custom derivatives for @_alwaysEmitIntoClient functions
Fixes #54445
1 parent e8fe5c3 commit 6bb9def

File tree

30 files changed

+467
-31
lines changed

30 files changed

+467
-31
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4299,6 +4299,9 @@ NOTE(derivative_attr_fix_access,none,
42994299
"mark the derivative function as "
43004300
"'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' "
43014301
"to match the original function", (AccessLevel))
4302+
ERROR(derivative_attr_always_emit_into_client_mismatch,none,
4303+
"either both or none of derivative and original function must have "
4304+
"@alwaysEmitIntoClient attribute", ())
43024305
ERROR(derivative_attr_static_method_mismatch_original,none,
43034306
"unexpected derivative function declaration; "
43044307
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",

lib/SIL/IR/Linker.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,23 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist(
159159
// HiddenExternal linkage when they are declarations, then they
160160
// become Shared after the body has been deserialized.
161161
// So try deserializing HiddenExternal functions too.
162-
if (linkage == SILLinkage::HiddenExternal)
163-
return deserializeAndPushToWorklist(F);
164-
162+
if (linkage == SILLinkage::HiddenExternal) {
163+
deserializeAndPushToWorklist(F);
164+
if (!F->markedAsAlwaysEmitIntoClient())
165+
return;
166+
// For @_alwaysEmitIntoClient functions, we need to lookup its
167+
// differentiability witness and, if present, ask SILLoader to obtain its
168+
// definition. Otherwise, a linker error would occur due to undefined
169+
// reference to these symbols.
170+
for (SILDifferentiabilityWitness *witness :
171+
F->getModule().lookUpDifferentiabilityWitnessesForFunction(
172+
F->getName())) {
173+
F->getModule().getSILLoader()->lookupDifferentiabilityWitness(
174+
witness->getKey());
175+
}
176+
return;
177+
}
178+
165179
// Update the linkage of the function in case it's different in the serialized
166180
// SIL than derived from the AST. This can be the case with cross-module-
167181
// optimizations.

lib/SILGen/SILGen.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,14 +1400,19 @@ void SILGenModule::emitDifferentiabilityWitness(
14001400
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
14011401
if (!diffWitness) {
14021402
// Differentiability witnesses have the same linkage as the original
1403-
// function, stripping external.
1404-
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
1403+
// function, stripping external. For @_alwaysEmitIntoClient original
1404+
// functions, force PublicNonABI linkage of the differentiability witness so
1405+
// we can serialize it (the original function itself might be HiddenExternal
1406+
// in this case if we only have declaration without definition).
1407+
auto linkage =
1408+
originalFunction->markedAsAlwaysEmitIntoClient()
1409+
? SILLinkage::PublicNonABI
1410+
: stripExternalFromLinkage(originalFunction->getLinkage());
14051411
diffWitness = SILDifferentiabilityWitness::createDefinition(
14061412
M, linkage, originalFunction, diffKind, silConfig.parameterIndices,
14071413
silConfig.resultIndices, config.derivativeGenericSignature,
14081414
/*jvp*/ nullptr, /*vjp*/ nullptr,
1409-
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
1410-
attr);
1415+
/*isSerialized*/ hasPublicVisibility(linkage), attr);
14111416
}
14121417

14131418
// Set derivative function in differentiability witness.

lib/SILGen/SILGenPoly.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6281,8 +6281,14 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
62816281
auto loc = customDerivativeFn->getLocation();
62826282
SILGenFunctionBuilder fb(*this);
62836283
// Derivative thunks have the same linkage as the original function, stripping
6284-
// external.
6285-
auto linkage = stripExternalFromLinkage(originalFn->getLinkage());
6284+
// external. For @_alwaysEmitIntoClient original functions, force PublicNonABI
6285+
// linkage of derivative thunks so we can serialize them (the original
6286+
// function itself might be HiddenExternal in this case if we only have
6287+
// declaration without definition).
6288+
auto linkage = originalFn->markedAsAlwaysEmitIntoClient()
6289+
? SILLinkage::PublicNonABI
6290+
: stripExternalFromLinkage(originalFn->getLinkage());
6291+
62866292
auto *thunk = fb.getOrCreateFunction(
62876293
loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent,
62886294
customDerivativeFn->getSerializedKind(),

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
538538
"definitions with explicit differentiable attributes");
539539

540540
return SILDifferentiabilityWitness::createDeclaration(
541-
module, SILLinkage::PublicExternal, original, kind,
542-
minimalConfig->parameterIndices, minimalConfig->resultIndices,
543-
minimalConfig->derivativeGenericSignature);
541+
module,
542+
// Witness for @_alwaysEmitIntoClient original function must be emitted,
543+
// otherwise a linker error would occur due to undefined reference to the
544+
// witness symbol.
545+
original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
546+
: SILLinkage::PublicExternal,
547+
original, kind, minimalConfig->parameterIndices,
548+
minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
544549
}
545550

546551
} // end namespace autodiff

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -911,10 +911,14 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
911911

912912
// We can generate empty JVP / VJP for functions available externally. These
913913
// functions have the same linkage as the original ones sans `external`
914-
// flag. Important exception here hidden_external functions as they are
915-
// serializable but corresponding hidden ones would be not and the SIL
916-
// verifier will fail. Patch `serializeFunctions` for this case.
917-
if (orig->getLinkage() == SILLinkage::HiddenExternal)
914+
// flag. Important exception here hidden_external non-@_alwaysEmitIntoClient
915+
// functions as they are serializable but corresponding hidden ones would be
916+
// not and the SIL verifier will fail. Patch `serializeFunctions` for this
917+
// case. For @_alwaysEmitIntoClient original functions (which might be
918+
// HiddenExternal if we only have declaration without definition), we want
919+
// derivatives to be serialized and do not patch `serializeFunctions`.
920+
if (orig->getLinkage() == SILLinkage::HiddenExternal &&
921+
!orig->markedAsAlwaysEmitIntoClient())
918922
serializeFunctions = IsNotSerialized;
919923

920924
// If the JVP doesn't exist, need to synthesize it.

lib/Sema/TypeCheckAttr.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6825,6 +6825,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) {
68256825
return true;
68266826
}
68276827

6828+
if (originalAFD->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>() !=
6829+
derivative->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>()) {
6830+
diags.diagnose(derivative->getLoc(),
6831+
diag::derivative_attr_always_emit_into_client_mismatch);
6832+
return true;
6833+
}
6834+
68286835
// Get the resolved differentiability parameter indices.
68296836
auto *resolvedDiffParamIndices = attr->getParameterIndices();
68306837

stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,6 @@ where
405405
}
406406
}
407407

408-
// FIXME(TF-1103): Derivative registration does not yet support
409-
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
410-
/*
411408
extension SIMD
412409
where
413410
Self: Differentiable,
@@ -417,6 +414,7 @@ where
417414
TangentVector == Self
418415
{
419416
@inlinable
417+
@_alwaysEmitIntoClient
420418
@derivative(of: sum)
421419
func _vjpSum() -> (
422420
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
@@ -425,14 +423,14 @@ where
425423
}
426424

427425
@inlinable
426+
@_alwaysEmitIntoClient
428427
@derivative(of: sum)
429428
func _jvpSum() -> (
430429
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
431430
) {
432431
return (sum(), { v in Scalar.TangentVector(v.sum()) })
433432
}
434433
}
435-
*/
436434

437435
extension SIMD
438436
where

test/AutoDiff/SILGen/nil_coalescing.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s
1+
/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions`
2+
// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s
23

34
import _Differentiation
45

5-
// CHECK: sil @test_nil_coalescing
6+
// CHECK: sil non_abi @test_nil_coalescing
67
// CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional<T>, %[[ARG_PB:.*]] :
78
// CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for <T>):
89
// CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional<T>
@@ -15,7 +16,7 @@ import _Differentiation
1516
//
1617
@_silgen_name("test_nil_coalescing")
1718
@derivative(of: ??)
18-
@usableFromInline
19+
@_alwaysEmitIntoClient
1920
func nilCoalescing<T: Differentiable>(optional: T?, defaultValue: @autoclosure () throws -> T)
2021
rethrows -> (value: T, pullback: (T.TangentVector) -> Optional<T>.TangentVector)
2122
{

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb
10621062
fatalError()
10631063
}
10641064

1065+
func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1066+
@_alwaysEmitIntoClient
1067+
@derivative(of: internal_original_alwaysemitintoclient_derivative_error)
1068+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1069+
func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1070+
fatalError()
1071+
}
1072+
1073+
@_alwaysEmitIntoClient
10651074
func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10661075
@_alwaysEmitIntoClient
10671076
@derivative(of: internal_original_alwaysemitintoclient_derivative)
@@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float
10841093
fatalError()
10851094
}
10861095

1096+
@_alwaysEmitIntoClient
1097+
package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x }
1098+
@derivative(of: package_original_alwaysemitintoclient_derivative_error)
1099+
// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}}
1100+
package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
1101+
fatalError()
1102+
}
1103+
1104+
@_alwaysEmitIntoClient
10871105
package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x }
10881106
@_alwaysEmitIntoClient
10891107
@derivative(of: package_original_alwaysemitintoclient_derivative)

0 commit comments

Comments
 (0)