@@ -362,7 +362,8 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
362
362
}
363
363
364
364
static CanGenericSignature buildDifferentiableGenericSignature (CanGenericSignature sig,
365
- CanType tanType) {
365
+ CanType tanType,
366
+ CanType origTypeOfAbstraction) {
366
367
if (!sig)
367
368
return sig;
368
369
@@ -390,6 +391,20 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
390
391
}
391
392
}
392
393
394
+ if (origTypeOfAbstraction) {
395
+ (void ) origTypeOfAbstraction.findIf ([&](Type t) -> bool {
396
+ if (auto *at = t->getAs <ArchetypeType>()) {
397
+ types.insert (at->getInterfaceType ()->getCanonicalType ());
398
+ for (auto *proto : at->getConformsTo ()) {
399
+ reqs.push_back (Requirement (RequirementKind::Conformance,
400
+ at->getInterfaceType (),
401
+ proto->getDeclaredInterfaceType ()));
402
+ }
403
+ }
404
+ return false ;
405
+ });
406
+ }
407
+
393
408
return evaluateOrDefault (
394
409
ctx.evaluator ,
395
410
AbstractGenericSignatureRequest{sig.getPointer (), {}, reqs},
@@ -427,14 +442,15 @@ static CanType getAutoDiffTangentTypeForLinearMap(
427
442
static CanSILFunctionType getAutoDiffDifferentialType (
428
443
SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
429
444
IndexSubset *resultIndices, LookupConformanceFn lookupConformance,
445
+ CanType origTypeOfAbstraction,
430
446
TypeConverter &TC) {
431
447
// Given the tangent type and the corresponding original parameter's
432
448
// convention, returns the tangent parameter's convention.
433
449
auto getTangentParameterConvention =
434
450
[&](CanType tanType,
435
451
ParameterConvention origParamConv) -> ParameterConvention {
436
452
auto sig = buildDifferentiableGenericSignature (
437
- originalFnTy->getSubstGenericSignature (), tanType);
453
+ originalFnTy->getSubstGenericSignature (), tanType, origTypeOfAbstraction );
438
454
439
455
tanType = tanType->getCanonicalType (sig);
440
456
AbstractionPattern pattern (sig, tanType);
@@ -462,7 +478,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
462
478
[&](CanType tanType,
463
479
ResultConvention origResConv) -> ResultConvention {
464
480
auto sig = buildDifferentiableGenericSignature (
465
- originalFnTy->getSubstGenericSignature (), tanType);
481
+ originalFnTy->getSubstGenericSignature (), tanType, origTypeOfAbstraction );
466
482
467
483
tanType = tanType->getCanonicalType (sig);
468
484
AbstractionPattern pattern (sig, tanType);
@@ -565,7 +581,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
565
581
static CanSILFunctionType getAutoDiffPullbackType (
566
582
SILFunctionType *originalFnTy, IndexSubset *parameterIndices,
567
583
IndexSubset *resultIndices, LookupConformanceFn lookupConformance,
568
- TypeConverter &TC) {
584
+ CanType origTypeOfAbstraction, TypeConverter &TC) {
569
585
auto &ctx = originalFnTy->getASTContext ();
570
586
SmallVector<GenericTypeParamType *, 4 > substGenericParams;
571
587
SmallVector<Requirement, 4 > substRequirements;
@@ -582,7 +598,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
582
598
[&](CanType tanType,
583
599
ResultConvention origResConv) -> ParameterConvention {
584
600
auto sig = buildDifferentiableGenericSignature (
585
- originalFnTy->getSubstGenericSignature (), tanType);
601
+ originalFnTy->getSubstGenericSignature (), tanType, origTypeOfAbstraction );
586
602
587
603
tanType = tanType->getCanonicalType (sig);
588
604
AbstractionPattern pattern (sig, tanType);
@@ -613,7 +629,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
613
629
[&](CanType tanType,
614
630
ParameterConvention origParamConv) -> ResultConvention {
615
631
auto sig = buildDifferentiableGenericSignature (
616
- originalFnTy->getSubstGenericSignature (), tanType);
632
+ originalFnTy->getSubstGenericSignature (), tanType, origTypeOfAbstraction );
617
633
618
634
tanType = tanType->getCanonicalType (sig);
619
635
AbstractionPattern pattern (sig, tanType);
@@ -780,7 +796,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
780
796
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
781
797
LookupConformanceFn lookupConformance,
782
798
CanGenericSignature derivativeFnInvocationGenSig,
783
- bool isReabstractionThunk) {
799
+ bool isReabstractionThunk,
800
+ CanType origTypeOfAbstraction) {
784
801
assert (parameterIndices);
785
802
assert (!parameterIndices->isEmpty () && " Parameter indices must not be empty" );
786
803
assert (resultIndices);
@@ -810,12 +827,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
810
827
case AutoDiffDerivativeFunctionKind::JVP:
811
828
closureType =
812
829
getAutoDiffDifferentialType (constrainedOriginalFnTy, parameterIndices,
813
- resultIndices, lookupConformance, TC);
830
+ resultIndices, lookupConformance,
831
+ origTypeOfAbstraction, TC);
814
832
break ;
815
833
case AutoDiffDerivativeFunctionKind::VJP:
816
834
closureType =
817
835
getAutoDiffPullbackType (constrainedOriginalFnTy, parameterIndices,
818
- resultIndices, lookupConformance, TC);
836
+ resultIndices, lookupConformance,
837
+ origTypeOfAbstraction, TC);
819
838
break ;
820
839
}
821
840
// Compute the derivative function parameters.
0 commit comments