Skip to content

Commit 4a805c4

Browse files
committed
Take into account generic types of yields while adding Differentiability requirements
1 parent 971f187 commit 4a805c4

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

lib/AST/AutoDiff.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -358,22 +358,30 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
358358
// Require differentiability results to conform to `Differentiable`.
359359
SmallVector<SILResultInfo, 2> originalResults;
360360
getSemanticResults(originalFnTy, diffParamIndices, originalResults);
361+
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
362+
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
363+
originalFnTy->getNumAutoDiffSemanticResultsParameters();
361364
for (unsigned resultIdx : diffResultIndices->getIndices()) {
362365
// Handle formal original result.
363-
if (resultIdx < originalFnTy->getNumResults()) {
366+
if (resultIdx < firstSemanticParamResultIdx) {
364367
auto resultType = originalResults[resultIdx].getInterfaceType();
365368
addRequirement(resultType);
366-
continue;
369+
} else if (resultIdx < firstYieldResultIndex) {
370+
// Handle original semantic result parameters.
371+
auto resultParamIndex = resultIdx - originalFnTy->getNumResults();
372+
auto resultParamIt = std::next(
373+
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
374+
resultParamIndex);
375+
auto paramIndex =
376+
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
377+
addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType());
378+
} else {
379+
// Handle formal original yields.
380+
assert(originalFnTy->isCoroutine());
381+
assert(originalFnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce);
382+
auto yieldResultIndex = resultIdx - firstYieldResultIndex;
383+
addRequirement(originalFnTy->getYields()[yieldResultIndex].getInterfaceType());
367384
}
368-
// Handle original semantic result parameters.
369-
// FIXME: Constraint generic yields when we will start supporting them
370-
auto resultParamIndex = resultIdx - originalFnTy->getNumResults();
371-
auto resultParamIt = std::next(
372-
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
373-
resultParamIndex);
374-
auto paramIndex =
375-
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
376-
addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType());
377385
}
378386

379387
return buildGenericSignature(ctx, derivativeGenSig,

0 commit comments

Comments
 (0)