Skip to content

Commit 90ec31a

Browse files
committed
[AutoDiff] Compute derivative types using requirements from archetypes.
Resolves rdar://84213107 and partially resolves rdar://82549134.
1 parent 2ccf3a2 commit 90ec31a

File tree

6 files changed

+67
-18
lines changed

6 files changed

+67
-18
lines changed

include/swift/AST/Types.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -4653,7 +4653,9 @@ class SILFunctionType final
46534653
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
46544654
LookupConformanceFn lookupConformance,
46554655
CanGenericSignature derivativeFunctionGenericSignature = nullptr,
4656-
bool isReabstractionThunk = false);
4656+
bool isReabstractionThunk = false,
4657+
CanType origTypeOfAbstraction = CanType());
4658+
46574659

46584660
/// Returns the type of the transpose function for the given parameter
46594661
/// indices, transpose function generic signature (optional), and other

lib/SIL/IR/SILFunctionType.cpp

+28-9
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,8 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
362362
}
363363

364364
static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignature sig,
365-
CanType tanType) {
365+
CanType tanType,
366+
CanType origTypeOfAbstraction) {
366367
if (!sig)
367368
return sig;
368369

@@ -390,6 +391,20 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
390391
}
391392
}
392393

394+
if (origTypeOfAbstraction) {
395+
(void) origTypeOfAbstraction.findIf([&](Type t) -> bool {
396+
if (auto *at = t->getAs<ArchetypeType>()) {
397+
types.insert(at->getInterfaceType()->getCanonicalType());
398+
for (auto *proto : at->getConformsTo()) {
399+
reqs.push_back(Requirement(RequirementKind::Conformance,
400+
at->getInterfaceType(),
401+
proto->getDeclaredInterfaceType()));
402+
}
403+
}
404+
return false;
405+
});
406+
}
407+
393408
return evaluateOrDefault(
394409
ctx.evaluator,
395410
AbstractGenericSignatureRequest{sig.getPointer(), {}, reqs},
@@ -427,14 +442,15 @@ static CanType getAutoDiffTangentTypeForLinearMap(
427442
static CanSILFunctionType getAutoDiffDifferentialType(
428443
SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
429444
IndexSubset *resultIndices, LookupConformanceFn lookupConformance,
445+
CanType origTypeOfAbstraction,
430446
TypeConverter &TC) {
431447
// Given the tangent type and the corresponding original parameter's
432448
// convention, returns the tangent parameter's convention.
433449
auto getTangentParameterConvention =
434450
[&](CanType tanType,
435451
ParameterConvention origParamConv) -> ParameterConvention {
436452
auto sig = buildDifferentiableGenericSignature(
437-
originalFnTy->getSubstGenericSignature(), tanType);
453+
originalFnTy->getSubstGenericSignature(), tanType, origTypeOfAbstraction);
438454

439455
tanType = tanType->getCanonicalType(sig);
440456
AbstractionPattern pattern(sig, tanType);
@@ -462,7 +478,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
462478
[&](CanType tanType,
463479
ResultConvention origResConv) -> ResultConvention {
464480
auto sig = buildDifferentiableGenericSignature(
465-
originalFnTy->getSubstGenericSignature(), tanType);
481+
originalFnTy->getSubstGenericSignature(), tanType, origTypeOfAbstraction);
466482

467483
tanType = tanType->getCanonicalType(sig);
468484
AbstractionPattern pattern(sig, tanType);
@@ -565,7 +581,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
565581
static CanSILFunctionType getAutoDiffPullbackType(
566582
SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
567583
IndexSubset *resultIndices, LookupConformanceFn lookupConformance,
568-
TypeConverter &TC) {
584+
CanType origTypeOfAbstraction, TypeConverter &TC) {
569585
auto &ctx = originalFnTy->getASTContext();
570586
SmallVector<GenericTypeParamType *, 4> substGenericParams;
571587
SmallVector<Requirement, 4> substRequirements;
@@ -582,7 +598,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
582598
[&](CanType tanType,
583599
ResultConvention origResConv) -> ParameterConvention {
584600
auto sig = buildDifferentiableGenericSignature(
585-
originalFnTy->getSubstGenericSignature(), tanType);
601+
originalFnTy->getSubstGenericSignature(), tanType, origTypeOfAbstraction);
586602

587603
tanType = tanType->getCanonicalType(sig);
588604
AbstractionPattern pattern(sig, tanType);
@@ -613,7 +629,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
613629
[&](CanType tanType,
614630
ParameterConvention origParamConv) -> ResultConvention {
615631
auto sig = buildDifferentiableGenericSignature(
616-
originalFnTy->getSubstGenericSignature(), tanType);
632+
originalFnTy->getSubstGenericSignature(), tanType, origTypeOfAbstraction);
617633

618634
tanType = tanType->getCanonicalType(sig);
619635
AbstractionPattern pattern(sig, tanType);
@@ -780,7 +796,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
780796
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
781797
LookupConformanceFn lookupConformance,
782798
CanGenericSignature derivativeFnInvocationGenSig,
783-
bool isReabstractionThunk) {
799+
bool isReabstractionThunk,
800+
CanType origTypeOfAbstraction) {
784801
assert(parameterIndices);
785802
assert(!parameterIndices->isEmpty() && "Parameter indices must not be empty");
786803
assert(resultIndices);
@@ -810,12 +827,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
810827
case AutoDiffDerivativeFunctionKind::JVP:
811828
closureType =
812829
getAutoDiffDifferentialType(constrainedOriginalFnTy, parameterIndices,
813-
resultIndices, lookupConformance, TC);
830+
resultIndices, lookupConformance,
831+
origTypeOfAbstraction, TC);
814832
break;
815833
case AutoDiffDerivativeFunctionKind::VJP:
816834
closureType =
817835
getAutoDiffPullbackType(constrainedOriginalFnTy, parameterIndices,
818-
resultIndices, lookupConformance, TC);
836+
resultIndices, lookupConformance,
837+
origTypeOfAbstraction, TC);
819838
break;
820839
}
821840
// Compute the derivative function parameters.

lib/SIL/IR/TypeLowering.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -331,19 +331,23 @@ namespace {
331331
CanSILFunctionType type, AbstractionPattern origType) {
332332
auto &M = TC.M;
333333
auto origTy = type->getWithoutDifferentiability();
334-
// Pass the `AbstractionPattern` generic signature to
335-
// `SILFunctionType:getAutoDiffDerivativeFunctionType` for correct type
336-
// lowering.
334+
// Pass the original type of abstraction pattern to
335+
// `SILFunctionType:getAutoDiffDerivativeFunctionType` to get the
336+
// necessary generic requirements.
337+
auto origTypeOfAbstraction =
338+
origType.hasGenericSignature() ? origType.getType() : CanType();
337339
auto jvpTy = origTy->getAutoDiffDerivativeFunctionType(
338340
type->getDifferentiabilityParameterIndices(),
339341
type->getDifferentiabilityResultIndices(),
340342
AutoDiffDerivativeFunctionKind::JVP, TC,
341-
LookUpConformanceInModule(&M), CanGenericSignature());
343+
LookUpConformanceInModule(&M), CanGenericSignature(),
344+
false, origTypeOfAbstraction);
342345
auto vjpTy = origTy->getAutoDiffDerivativeFunctionType(
343346
type->getDifferentiabilityParameterIndices(),
344347
type->getDifferentiabilityResultIndices(),
345348
AutoDiffDerivativeFunctionKind::VJP, TC,
346-
LookUpConformanceInModule(&M), CanGenericSignature());
349+
LookUpConformanceInModule(&M), CanGenericSignature(),
350+
false, origTypeOfAbstraction);
347351
RecursiveProperties props;
348352
props.addSubobject(classifyType(origType, origTy, TC, Expansion));
349353
props.addSubobject(classifyType(origType, jvpTy, TC, Expansion));

test/AutoDiff/SILOptimizer/semantic_member_accessors_sil.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -emit-sil -Xllvm -sil-print-after=differentiation %s -module-name null -o /dev/null -requirement-machine=off 2>&1 | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-sil -Xllvm -sil-print-after=differentiation %s -module-name null -o /dev/null 2>&1 | %FileCheck %s
22

33
// Test differentiation of semantic member accessors:
44
// - Stored property accessors.

test/AutoDiff/compiler_crashers/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: %empty-directory(%t)
2-
// RUN: not --crash %target-build-swift -emit-module -module-name pr32302 -emit-module-path %t/pr32302.swiftmodule -swift-version 5 -c %S/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift -Xfrontend -requirement-machine=off
2+
// RUN: not --crash %target-build-swift -emit-module -module-name pr32302 -emit-module-path %t/pr32302.swiftmodule -swift-version 5 -c %S/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift
33
// XFAIL: *
44

55
// pr32302 / pr32343 / pr38745 : reproduce assert with _Differentiation where
@@ -28,7 +28,7 @@ extension Differentiable {
2828
// GenericTypeParamDecl has incorrect depth
2929
// Please submit a bug report (https://swift.org/contributing/#reporting-bugs) and include the project and the crash backtrace.
3030
// Stack dump:
31-
// 0. Program arguments: /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/bin/swift-frontend -frontend -merge-modules -emit-module /tmp/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth-acc95c.swiftmodule -parse-as-library -disable-diagnostic-passes -disable-sil-perf-optzns -target x86_64-unknown-linux-gnu -warn-on-potentially-unavailable-enum-case -disable-objc-interop -module-cache-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/swift-test-results/x86_64-unknown-linux-gnu/clang-module-cache -swift-version 5 -define-availability "SwiftStdlib 5.5:macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0" -requirement-machine=off -emit-module-doc-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftdoc -emit-module-source-info-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftsourceinfo -module-name pr32302 -o /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftmodule
31+
// 0. Program arguments: /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/bin/swift-frontend -frontend -merge-modules -emit-module /tmp/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth-acc95c.swiftmodule -parse-as-library -disable-diagnostic-passes -disable-sil-perf-optzns -target x86_64-unknown-linux-gnu -warn-on-potentially-unavailable-enum-case -disable-objc-interop -module-cache-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/swift-test-results/x86_64-unknown-linux-gnu/clang-module-cache -swift-version 5 -define-availability "SwiftStdlib 5.5:macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0" -emit-module-doc-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftdoc -emit-module-source-info-path /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftsourceinfo -module-name pr32302 -o /work/software/swift-stocktoolchain/build/ds/swift-linux-x86_64/test-linux-x86_64/AutoDiff/compiler_crashers/Output/pr32302-autodiff-generictypeparamdecl-has-incorrect-depth.swift.tmp/pr32302.swiftmodule
3232
// 1. Swift version 5.6-dev (LLVM ba0b85f590c1ba2, Swift 319b3e64aaeb252)
3333
// 2. Compiling with the current language version
3434
// 3. While verifying GenericTypeParamDecl 'τ_1_0' (in module 'pr32302')
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: %target-build-swift %s
2+
3+
import _Differentiation
4+
5+
public protocol Layer {
6+
associatedtype Input: Differentiable
7+
associatedtype Output: Differentiable
8+
func callAsFunction(_ input: Input) -> Output
9+
}
10+
11+
public class Function<Input: Differentiable, Output: Differentiable>: Layer {
12+
public typealias Body = @differentiable(reverse) (Input) -> Output
13+
14+
@noDerivative public let body: Body
15+
16+
public init(_ body: @escaping Body) {
17+
self.body = body
18+
}
19+
20+
@differentiable(reverse)
21+
public func callAsFunction(_ input: Input) -> Output {
22+
body(input)
23+
}
24+
}

0 commit comments

Comments
 (0)