diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 0fb3ff1f3c744..009617413a906 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -4299,6 +4299,9 @@ NOTE(derivative_attr_fix_access,none, "mark the derivative function as " "'%select{private|fileprivate|internal|package|@usableFromInline|@usableFromInline}0' " "to match the original function", (AccessLevel)) +ERROR(derivative_attr_always_emit_into_client_mismatch,none, + "either both or none of derivative and original function must have " + "@alwaysEmitIntoClient attribute", ()) ERROR(derivative_attr_static_method_mismatch_original,none, "unexpected derivative function declaration; " "%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method", diff --git a/lib/SIL/IR/Linker.cpp b/lib/SIL/IR/Linker.cpp index 7496d49d9ef50..8b181dfd5cf30 100644 --- a/lib/SIL/IR/Linker.cpp +++ b/lib/SIL/IR/Linker.cpp @@ -159,9 +159,23 @@ void SILLinkerVisitor::maybeAddFunctionToWorklist( // HiddenExternal linkage when they are declarations, then they // become Shared after the body has been deserialized. // So try deserializing HiddenExternal functions too. - if (linkage == SILLinkage::HiddenExternal) - return deserializeAndPushToWorklist(F); - + if (linkage == SILLinkage::HiddenExternal) { + deserializeAndPushToWorklist(F); + if (!F->markedAsAlwaysEmitIntoClient()) + return; + // For @_alwaysEmitIntoClient functions, we need to lookup its + // differentiability witness and, if present, ask SILLoader to obtain its + // definition. Otherwise, a linker error would occur due to undefined + // reference to these symbols. + for (SILDifferentiabilityWitness *witness : + F->getModule().lookUpDifferentiabilityWitnessesForFunction( + F->getName())) { + F->getModule().getSILLoader()->lookupDifferentiabilityWitness( + witness->getKey()); + } + return; + } + // Update the linkage of the function in case it's different in the serialized // SIL than derived from the AST. This can be the case with cross-module- // optimizations. diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 51e787fefebe7..ee73b37008275 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -1400,14 +1400,19 @@ void SILGenModule::emitDifferentiabilityWitness( auto *diffWitness = M.lookUpDifferentiabilityWitness(key); if (!diffWitness) { // Differentiability witnesses have the same linkage as the original - // function, stripping external. - auto linkage = stripExternalFromLinkage(originalFunction->getLinkage()); + // function, stripping external. For @_alwaysEmitIntoClient original + // functions, force PublicNonABI linkage of the differentiability witness so + // we can serialize it (the original function itself might be HiddenExternal + // in this case if we only have declaration without definition). + auto linkage = + originalFunction->markedAsAlwaysEmitIntoClient() + ? SILLinkage::PublicNonABI + : stripExternalFromLinkage(originalFunction->getLinkage()); diffWitness = SILDifferentiabilityWitness::createDefinition( M, linkage, originalFunction, diffKind, silConfig.parameterIndices, silConfig.resultIndices, config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr, - /*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()), - attr); + /*isSerialized*/ hasPublicVisibility(linkage), attr); } // Set derivative function in differentiability witness. diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index ca89af6d83ec1..3100c43668ace 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -6281,8 +6281,14 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk( auto loc = customDerivativeFn->getLocation(); SILGenFunctionBuilder fb(*this); // Derivative thunks have the same linkage as the original function, stripping - // external. - auto linkage = stripExternalFromLinkage(originalFn->getLinkage()); + // external. For @_alwaysEmitIntoClient original functions, force PublicNonABI + // linkage of derivative thunks so we can serialize them (the original + // function itself might be HiddenExternal in this case if we only have + // declaration without definition). + auto linkage = originalFn->markedAsAlwaysEmitIntoClient() + ? SILLinkage::PublicNonABI + : stripExternalFromLinkage(originalFn->getLinkage()); + auto *thunk = fb.getOrCreateFunction( loc, name, linkage, thunkFnTy, IsBare, IsNotTransparent, customDerivativeFn->getSerializedKind(), diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index ee92aad8894dc..fd4adfd979f0a 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -538,9 +538,14 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( "definitions with explicit differentiable attributes"); return SILDifferentiabilityWitness::createDeclaration( - module, SILLinkage::PublicExternal, original, kind, - minimalConfig->parameterIndices, minimalConfig->resultIndices, - minimalConfig->derivativeGenericSignature); + module, + // Witness for @_alwaysEmitIntoClient original function must be emitted, + // otherwise a linker error would occur due to undefined reference to the + // witness symbol. + original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI + : SILLinkage::PublicExternal, + original, kind, minimalConfig->parameterIndices, + minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature); } } // end namespace autodiff diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 6276cc3a7c089..f7dda0a195146 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -911,10 +911,14 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness( // We can generate empty JVP / VJP for functions available externally. These // functions have the same linkage as the original ones sans `external` - // flag. Important exception here hidden_external functions as they are - // serializable but corresponding hidden ones would be not and the SIL - // verifier will fail. Patch `serializeFunctions` for this case. - if (orig->getLinkage() == SILLinkage::HiddenExternal) + // flag. Important exception here hidden_external non-@_alwaysEmitIntoClient + // functions as they are serializable but corresponding hidden ones would be + // not and the SIL verifier will fail. Patch `serializeFunctions` for this + // case. For @_alwaysEmitIntoClient original functions (which might be + // HiddenExternal if we only have declaration without definition), we want + // derivatives to be serialized and do not patch `serializeFunctions`. + if (orig->getLinkage() == SILLinkage::HiddenExternal && + !orig->markedAsAlwaysEmitIntoClient()) serializeFunctions = IsNotSerialized; // If the JVP doesn't exist, need to synthesize it. diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 41a01b99b8ac2..47c3f80a0ec82 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -6825,6 +6825,13 @@ static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { return true; } + if (originalAFD->getAttrs().hasAttribute() != + derivative->getAttrs().hasAttribute()) { + diags.diagnose(derivative->getLoc(), + diag::derivative_attr_always_emit_into_client_mismatch); + return true; + } + // Get the resolved differentiability parameter indices. auto *resolvedDiffParamIndices = attr->getParameterIndices(); diff --git a/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb b/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb index a1e10887ee17f..904768cfd2588 100644 --- a/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb +++ b/stdlib/public/Differentiation/SIMDDifferentiation.swift.gyb @@ -405,9 +405,6 @@ where } } -// FIXME(TF-1103): Derivative registration does not yet support -// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`. -/* extension SIMD where Self: Differentiable, @@ -417,6 +414,7 @@ where TangentVector == Self { @inlinable + @_alwaysEmitIntoClient @derivative(of: sum) func _vjpSum() -> ( value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector @@ -425,6 +423,7 @@ where } @inlinable + @_alwaysEmitIntoClient @derivative(of: sum) func _jvpSum() -> ( value: Scalar, differential: (TangentVector) -> Scalar.TangentVector @@ -432,7 +431,6 @@ where return (sum(), { v in Scalar.TangentVector(v.sum()) }) } } -*/ extension SIMD where diff --git a/test/AutoDiff/SILGen/nil_coalescing.swift b/test/AutoDiff/SILGen/nil_coalescing.swift index 027367e60d196..079d2d4747147 100644 --- a/test/AutoDiff/SILGen/nil_coalescing.swift +++ b/test/AutoDiff/SILGen/nil_coalescing.swift @@ -1,8 +1,9 @@ -// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify %s | %FileCheck %s +/// Note: -primary-file prevents non_abi->shared linkage change in `removeSerializedFlagFromAllFunctions` +// RUN: %target-swift-frontend -Xllvm -sil-print-types -emit-sil -verify -primary-file %s | %FileCheck %s import _Differentiation -// CHECK: sil @test_nil_coalescing +// CHECK: sil non_abi @test_nil_coalescing // CHECK: bb0(%{{.*}} : $*T, %[[ARG_OPT:.*]] : $*Optional, %[[ARG_PB:.*]] : // CHECK: $@noescape @callee_guaranteed @substituted <τ_0_0> () -> (@out τ_0_0, @error any Error) for ): // CHECK: %[[ALLOC_OPT:.*]] = alloc_stack [lexical] $Optional @@ -15,7 +16,7 @@ import _Differentiation // @_silgen_name("test_nil_coalescing") @derivative(of: ??) -@usableFromInline +@_alwaysEmitIntoClient func nilCoalescing(optional: T?, defaultValue: @autoclosure () throws -> T) rethrows -> (value: T, pullback: (T.TangentVector) -> Optional.TangentVector) { diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 9740ab1fc589c..4fdebeddd5276 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -1062,6 +1062,15 @@ func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullb fatalError() } +func internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x } +@_alwaysEmitIntoClient +@derivative(of: internal_original_alwaysemitintoclient_derivative_error) +// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}} +func _internal_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + fatalError() +} + +@_alwaysEmitIntoClient func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x } @_alwaysEmitIntoClient @derivative(of: internal_original_alwaysemitintoclient_derivative) @@ -1084,6 +1093,15 @@ package func _package_original_inlinable_derivative(_ x: Float) -> (value: Float fatalError() } +@_alwaysEmitIntoClient +package func package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> Float { x } +@derivative(of: package_original_alwaysemitintoclient_derivative_error) +// expected-error @+1 {{either both or none of derivative and original function must have @alwaysEmitIntoClient attribute}} +package func _package_original_alwaysemitintoclient_derivative_error(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + fatalError() +} + +@_alwaysEmitIntoClient package func package_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x } @_alwaysEmitIntoClient @derivative(of: package_original_alwaysemitintoclient_derivative) diff --git a/test/AutoDiff/stdlib/simd.swift b/test/AutoDiff/stdlib/simd.swift index 04dc27e078c38..32046fa3a2072 100644 --- a/test/AutoDiff/stdlib/simd.swift +++ b/test/AutoDiff/stdlib/simd.swift @@ -19,9 +19,6 @@ SIMDTests.test("init(repeating:)") { expectEqual(8, pb1(g)) } -// FIXME(TF-1103): Derivative registration does not yet support -// `@_alwaysEmitIntoClient` original functions. -/* SIMDTests.test("Sum") { let a = SIMD4(1, 2, 3, 4) @@ -32,7 +29,6 @@ SIMDTests.test("Sum") { expectEqual(10, val1) expectEqual(SIMD4(3, 3, 3, 3), pb1(3)) } -*/ SIMDTests.test("Identity") { let a = SIMD4(1, 2, 3, 4) @@ -289,9 +285,6 @@ SIMDTests.test("Generics") { expectEqual(SIMD3(5, 10, 15), val4) expectEqual((SIMD3(5, 5, 5), 6), pb4(g)) - // FIXME(TF-1103): Derivative registration does not yet support - // `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`. - /* func testSum(x: SIMDType) -> Scalar where SIMDType.Scalar == Scalar, SIMDType : Differentiable, @@ -304,7 +297,6 @@ SIMDTests.test("Generics") { let (val5, pb5) = valueWithPullback(at: a, of: simd3Sum) expectEqual(6, val5) expectEqual(SIMD3(7, 7, 7), pb5(7)) - */ } runAllTests() diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiFileModule/file1.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiFileModule/file1.swift new file mode 100644 index 0000000000000..c808c0ef666ed --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiFileModule/file1.swift @@ -0,0 +1,4 @@ +@_alwaysEmitIntoClient +public func f(_ x: Float) -> Float { + x +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiFileModule/file2.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiFileModule/file2.swift new file mode 100644 index 0000000000000..1e047d39bdf82 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiFileModule/file2.swift @@ -0,0 +1,7 @@ +import _Differentiation + +@derivative(of: f) +@_alwaysEmitIntoClient +public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + (x, { 42 * $0 }) +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModule/file1.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModule/file1.swift new file mode 100644 index 0000000000000..c808c0ef666ed --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModule/file1.swift @@ -0,0 +1,4 @@ +@_alwaysEmitIntoClient +public func f(_ x: Float) -> Float { + x +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModule/file2.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModule/file2.swift new file mode 100644 index 0000000000000..8feda0e57f92c --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModule/file2.swift @@ -0,0 +1,8 @@ +import MultiModule1 +import _Differentiation + +@derivative(of: f) +@_alwaysEmitIntoClient +public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + (x, { 42 * $0 }) +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file1.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file1.swift new file mode 100644 index 0000000000000..066e1bd2eeed1 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file1.swift @@ -0,0 +1,16 @@ +import _Differentiation + +public protocol Protocol { + var x : Float {get set} + init() +} + +extension Protocol { + public init(_ val: Float) { + self.init() + x = val + } + + @_alwaysEmitIntoClient + public func sum() -> Float { x } +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file2.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file2.swift new file mode 100644 index 0000000000000..e559da2345baa --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file2.swift @@ -0,0 +1,20 @@ +import MultiModuleProtocol1 +import _Differentiation + +extension Protocol where Self: Differentiable, Self.TangentVector == Self { + @_alwaysEmitIntoClient + @derivative(of: sum) + public func _vjpSum() -> ( + value: Float, pullback: (Float) -> Self.TangentVector + ) { + (value: self.x, pullback: { Self.TangentVector(42 * $0) }) + } + + @_alwaysEmitIntoClient + @derivative(of: sum) + public func _jvpSum() -> ( + value: Float, differential: (Self.TangentVector) -> Float + ) { + (value: self.x, differential: { 42 * $0.x }) + } +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file3.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file3.swift new file mode 100644 index 0000000000000..4da6021fa69d0 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleProtocol/file3.swift @@ -0,0 +1,24 @@ +import MultiModuleProtocol1 +import MultiModuleProtocol2 +import _Differentiation + +public struct Struct : Protocol { + private var _x : Float + public var x : Float { + get { _x } + set { _x = newValue } + } + public init() { _x = 0 } +} + +extension Struct : AdditiveArithmetic { + public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) } + public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) } + public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x } + public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x } + public static var zero: Self { Self(0) } +} + +extension Struct : Differentiable { + public typealias TangentVector = Self +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file1.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file1.swift new file mode 100644 index 0000000000000..5f4fa36bebf53 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file1.swift @@ -0,0 +1,23 @@ +public struct Struct { + public var x : Float + public typealias TangentVector = Self + public init() { x = 0 } +} + +extension Struct { + public init(_ val: Float) { + self.init() + x = val + } + + @_alwaysEmitIntoClient + public func sum() -> Float { x } +} + +extension Struct : AdditiveArithmetic { + public static func +(lhs: Self, rhs: Self) -> Self { Self(lhs.x + rhs.x) } + public static func -(lhs: Self, rhs: Self) -> Self { Self(lhs.x - rhs.x) } + public static func +=(a: inout Self, b: Self) { a.x = a.x + b.x } + public static func -=(a: inout Self, b: Self) { a.x = a.x - b.x } + public static var zero: Self { Self(0) } +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2.swift new file mode 100644 index 0000000000000..05a63c26de0bc --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2.swift @@ -0,0 +1,20 @@ +import MultiModuleStruct1 +import _Differentiation + +extension Struct : Differentiable { + @_alwaysEmitIntoClient + @derivative(of: sum) + public func _vjpSum() -> ( + value: Float, pullback: (Float) -> Self.TangentVector + ) { + (value: self.x, pullback: { Self.TangentVector(42 * $0) }) + } + + @_alwaysEmitIntoClient + @derivative(of: sum) + public func _jvpSum() -> ( + value: Float, differential: (Self.TangentVector) -> Float + ) { + (value: self.x, differential: { 42 * $0.x }) + } +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2_no_jvp.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2_no_jvp.swift new file mode 100644 index 0000000000000..d45a36a13752a --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2_no_jvp.swift @@ -0,0 +1,12 @@ +import MultiModuleStruct1 +import _Differentiation + +extension Struct : Differentiable { + @_alwaysEmitIntoClient + @derivative(of: sum) + public func _vjpSum() -> ( + value: Float, pullback: (Float) -> Self.TangentVector + ) { + (value: self.x, pullback: { Self.TangentVector(42 * $0) }) + } +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2_no_vjp.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2_no_vjp.swift new file mode 100644 index 0000000000000..7aaf91d5e607a --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/MultiModuleStruct/file2_no_vjp.swift @@ -0,0 +1,12 @@ +import MultiModuleStruct1 +import _Differentiation + +extension Struct : Differentiable { + @_alwaysEmitIntoClient + @derivative(of: sum) + public func _jvpSum() -> ( + value: Float, differential: (Self.TangentVector) -> Float + ) { + (value: self.x, differential: { 42 * $0.x }) + } +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/Inputs/SingleFileModule/file.swift b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/SingleFileModule/file.swift new file mode 100644 index 0000000000000..6916329c88f03 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/Inputs/SingleFileModule/file.swift @@ -0,0 +1,12 @@ +import _Differentiation + +@_alwaysEmitIntoClient +public func f(_ x: Float) -> Float { + x +} + +@derivative(of: f) +@_alwaysEmitIntoClient +public func df(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + (x, { 42 * $0 }) +} diff --git a/test/AutoDiff/validation-test/always_emit_into_client/multi_file.swift b/test/AutoDiff/validation-test/always_emit_into_client/multi_file.swift new file mode 100644 index 0000000000000..4464c85f02fca --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/multi_file.swift @@ -0,0 +1,28 @@ +// REQUIRES: executable_test +// RUN: %empty-directory(%t) + +/// Note: we build just a module without a library since it would not contain any exported +/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient. +// RUN: %target-build-swift %S/Inputs/MultiFileModule/file1.swift %S/Inputs/MultiFileModule/file2.swift \ +// RUN: -emit-module -emit-module-path %t/MultiFileModule.swiftmodule -module-name MultiFileModule + +// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t) +// RUN: %target-run %t/a.out + +// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s + +import MultiFileModule +import StdlibUnittest +import _Differentiation + +var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient") + +AlwaysEmitIntoClientTests.test("registration") { + expectEqual(42, gradient(at: 0, of: f)) + expectEqual(42, gradient(at: 1, of: f)) + expectEqual(42, gradient(at: 2, of: f)) +} + +runAllTests() + +// CHECK: @"15MultiFileModule1fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s15MultiFileModule1fyS2fFTJfSpSr", ptr @"$s15MultiFileModule1fyS2fFTJrSpSr" } diff --git a/test/AutoDiff/validation-test/always_emit_into_client/multi_module.swift b/test/AutoDiff/validation-test/always_emit_into_client/multi_module.swift new file mode 100644 index 0000000000000..fa156c73f532d --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/multi_module.swift @@ -0,0 +1,31 @@ +// REQUIRES: executable_test +// RUN: %empty-directory(%t) + +/// Note: we build just modules without libraries since they would not contain any exported +/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient. +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModule1)) %S/Inputs/MultiModule/file1.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModule1.swiftmodule -module-name MultiModule1 +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModule2)) %S/Inputs/MultiModule/file2.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModule2.swiftmodule -module-name MultiModule2 -I%t %target-rpath(%t) + +// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t) +// RUN: %target-run %t/a.out + +// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s + +import MultiModule1 +import MultiModule2 +import StdlibUnittest +import _Differentiation + +var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient") + +AlwaysEmitIntoClientTests.test("registration") { + expectEqual(42, gradient(at: 0, of: f)) + expectEqual(42, gradient(at: 1, of: f)) + expectEqual(42, gradient(at: 2, of: f)) +} + +runAllTests() + +// CHECK: @"12MultiModule11fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s12MultiModule11fyS2fFTJfSpSr", ptr @"$s12MultiModule11fyS2fFTJrSpSr" } diff --git a/test/AutoDiff/validation-test/always_emit_into_client/multi_module_protocol.swift b/test/AutoDiff/validation-test/always_emit_into_client/multi_module_protocol.swift new file mode 100644 index 0000000000000..708a55f7fcc92 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/multi_module_protocol.swift @@ -0,0 +1,44 @@ +// REQUIRES: executable_test +// RUN: %empty-directory(%t) + +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleProtocol1)) %S/Inputs/MultiModuleProtocol/file1.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModuleProtocol1.swiftmodule -module-name MultiModuleProtocol1 + +/// Note: we build just a module without a library since it would not contain any exported +/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient. +// RUN: %target-build-swift %S/Inputs/MultiModuleProtocol/file2.swift -emit-module -emit-module-path %t/MultiModuleProtocol2.swiftmodule \ +// RUN: -module-name MultiModuleProtocol2 -I%t -lMultiModuleProtocol1 %target-rpath(%t) + +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleProtocol3)) %S/Inputs/MultiModuleProtocol/file3.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModuleProtocol3.swiftmodule -module-name MultiModuleProtocol3 -I%t -L%t -lMultiModuleProtocol1 %target-rpath(%t) + +/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`. +/// It wraps `Protocol.sum` that has custom JVP defined in MultiModuleProtocol2, so we can test it. +// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \ +// RUN: -I%t -L%t %s -lMultiModuleProtocol1 -lMultiModuleProtocol3 -o %t/a.out %target-rpath(%t) +// RUN: %target-run %t/a.out + +// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s + +import MultiModuleProtocol1 +import MultiModuleProtocol2 +import MultiModuleProtocol3 +import StdlibUnittest +import _Differentiation + +var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient") + +AlwaysEmitIntoClientTests.test("registration") { + func foo(x: T) -> Float + where T: Differentiable, T.TangentVector == T { x.sum() } + expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1)) + expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1)) + expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1)) + expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1))) + expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1))) + expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1))) +} + +runAllTests() + +// CHECK: @"20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlTJfSpSr", ptr @"$s20MultiModuleProtocol18ProtocolPAAE3sumSfyFAaBRz16_Differentiation14DifferentiableRz13TangentVectorAeFPQzRszlTJrSpSr" } diff --git a/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct.swift b/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct.swift new file mode 100644 index 0000000000000..2125d40fa5554 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct.swift @@ -0,0 +1,36 @@ +// REQUIRES: executable_test +// RUN: %empty-directory(%t) + +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1 +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2)) %S/Inputs/MultiModuleStruct/file2.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2.swiftmodule -module-name MultiModuleStruct2 -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t) + +/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`. +/// It wraps `Struct.sum` that has custom JVP defined in MultiModuleStruct2, so we can test it. +// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \ +// RUN: -I%t -L%t %s -lMultiModuleStruct1 -lMultiModuleStruct2 -o %t/a.out %target-rpath(%t) +// RUN: %target-run %t/a.out + +// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s + +import MultiModuleStruct1 +import MultiModuleStruct2 +import StdlibUnittest +import _Differentiation + +var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient") + +AlwaysEmitIntoClientTests.test("registration") { + func foo(x: Struct) -> Float { x.sum() } + expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1)) + expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1)) + expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1)) + expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1))) + expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1))) + expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1))) +} + +runAllTests() + +// CHECK: @"18MultiModuleStruct16StructV3sumSfyFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJfSpSr", ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJrSpSr" } diff --git a/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct_no_jvp.swift b/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct_no_jvp.swift new file mode 100644 index 0000000000000..05e62e12c534f --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct_no_jvp.swift @@ -0,0 +1,37 @@ +// REQUIRES: executable_test +// RUN: %empty-directory(%t) + +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1 +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2NoJVP)) %S/Inputs/MultiModuleStruct/file2_no_jvp.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2NoJVP.swiftmodule -module-name MultiModuleStruct2NoJVP -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t) + +/// Note: we enable forward-mode differentiation to automatically generate JVP for `foo`. +/// It wraps `Struct.sum` that has custom JVP defined in MultiModuleStruct2, so we can test it. +// RUN: %target-build-swift -Xfrontend -enable-experimental-forward-mode-differentiation \ +// RUN: -I%t -L%t %s -lMultiModuleStruct1 -lMultiModuleStruct2NoJVP -o %t/a.out %target-rpath(%t) +// RUN: %target-run %t/a.out + +// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s + +import MultiModuleStruct1 +import MultiModuleStruct2NoJVP +import StdlibUnittest +import _Differentiation + +var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient") + +AlwaysEmitIntoClientTests.test("registration") { + func foo(x: Struct) -> Float { x.sum() } + expectEqual(Struct(42), pullback(at: Struct(0), of: foo)(1)) + expectEqual(Struct(42), pullback(at: Struct(1), of: foo)(1)) + expectEqual(Struct(42), pullback(at: Struct(2), of: foo)(1)) + /// Custom JVP for Struct.sum is not provided, a JVP causing fatal error is emitted. + expectCrash{differential(at: Struct(0), of: foo)(Struct(1))} + expectCrash{differential(at: Struct(1), of: foo)(Struct(1))} + expectCrash{differential(at: Struct(2), of: foo)(Struct(1))} +} + +runAllTests() + +// CHECK: @"18MultiModuleStruct16StructV3sumSfyFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJfSpSr", ptr @"$s18MultiModuleStruct16StructV3sumSfyFTJrSpSr" } diff --git a/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct_no_vjp.swift b/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct_no_vjp.swift new file mode 100644 index 0000000000000..a76ab04308b52 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/multi_module_struct_no_vjp.swift @@ -0,0 +1,26 @@ +// REQUIRES: executable_test +// RUN: %empty-directory(%t) + +// RUN: %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct1)) %S/Inputs/MultiModuleStruct/file1.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModuleStruct1.swiftmodule -module-name MultiModuleStruct1 +// RUN: not %target-build-swift-dylib(%t/%target-library-name(MultiModuleStruct2NoVJP)) %S/Inputs/MultiModuleStruct/file2_no_vjp.swift \ +// RUN: -emit-module -emit-module-path %t/MultiModuleStruct2NoVJP.swiftmodule -module-name MultiModuleStruct2NoVJP -I%t -L%t -lMultiModuleStruct1 %target-rpath(%t) 2>&1 | \ +// RUN: %FileCheck %s + +// CHECK: file2_no_vjp.swift:6:4: error: function is not differentiable + +import MultiModuleStruct1 +import MultiModuleStruct2 +import StdlibUnittest +import _Differentiation + +var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient") + +AlwaysEmitIntoClientTests.test("registration") { + func foo(x: Struct) -> Float { x.sum() } + expectEqual(42, differential(at: Struct(0), of: foo)(Struct(1))) + expectEqual(42, differential(at: Struct(1), of: foo)(Struct(1))) + expectEqual(42, differential(at: Struct(2), of: foo)(Struct(1))) +} + +runAllTests() diff --git a/test/AutoDiff/validation-test/always_emit_into_client/single_file.swift b/test/AutoDiff/validation-test/always_emit_into_client/single_file.swift new file mode 100644 index 0000000000000..7d293232b5414 --- /dev/null +++ b/test/AutoDiff/validation-test/always_emit_into_client/single_file.swift @@ -0,0 +1,28 @@ +// REQUIRES: executable_test +// RUN: %empty-directory(%t) + +/// Note: we build just a module without a library since it would not contain any exported +/// symbols because all the functions in the module are marked as @_alwaysEmitIntoClient. +// RUN: %target-build-swift %S/Inputs/SingleFileModule/file.swift -emit-module \ +// RUN: -emit-module-path %t/SingleFileModule.swiftmodule -module-name SingleFileModule + +// RUN: %target-build-swift -I%t %s -o %t/a.out %target-rpath(%t) +// RUN: %target-run %t/a.out + +// RUN: %target-build-swift -I%t %s -emit-ir | %FileCheck %s + +import SingleFileModule +import StdlibUnittest +import _Differentiation + +var AlwaysEmitIntoClientTests = TestSuite("AlwaysEmitIntoClient") + +AlwaysEmitIntoClientTests.test("registration") { + expectEqual(42, gradient(at: 0, of: f)) + expectEqual(42, gradient(at: 1, of: f)) + expectEqual(42, gradient(at: 2, of: f)) +} + +runAllTests() + +// CHECK: @"16SingleFileModule1fyS2fFWJrSpSr" = weak_odr hidden {{()|local_unnamed_addr }}global { ptr, ptr } { ptr @"$s16SingleFileModule1fyS2fFTJfSpSr", ptr @"$s16SingleFileModule1fyS2fFTJrSpSr" }