@@ -358,22 +358,30 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
358
358
// Require differentiability results to conform to `Differentiable`.
359
359
SmallVector<SILResultInfo, 2 > originalResults;
360
360
getSemanticResults (originalFnTy, diffParamIndices, originalResults);
361
+ unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults ();
362
+ unsigned firstYieldResultIndex = originalFnTy->getNumResults () +
363
+ originalFnTy->getNumAutoDiffSemanticResultsParameters ();
361
364
for (unsigned resultIdx : diffResultIndices->getIndices ()) {
362
365
// Handle formal original result.
363
- if (resultIdx < originalFnTy-> getNumResults () ) {
366
+ if (resultIdx < firstSemanticParamResultIdx ) {
364
367
auto resultType = originalResults[resultIdx].getInterfaceType ();
365
368
addRequirement (resultType);
366
- continue ;
369
+ } else if (resultIdx < firstYieldResultIndex) {
370
+ // Handle original semantic result parameters.
371
+ auto resultParamIndex = resultIdx - originalFnTy->getNumResults ();
372
+ auto resultParamIt = std::next (
373
+ originalFnTy->getAutoDiffSemanticResultsParameters ().begin (),
374
+ resultParamIndex);
375
+ auto paramIndex =
376
+ std::distance (originalFnTy->getParameters ().begin (), &*resultParamIt);
377
+ addRequirement (originalFnTy->getParameters ()[paramIndex].getInterfaceType ());
378
+ } else {
379
+ // Handle formal original yields.
380
+ assert (originalFnTy->isCoroutine ());
381
+ assert (originalFnTy->getCoroutineKind () == SILCoroutineKind::YieldOnce);
382
+ auto yieldResultIndex = resultIdx - firstYieldResultIndex;
383
+ addRequirement (originalFnTy->getYields ()[yieldResultIndex].getInterfaceType ());
367
384
}
368
- // Handle original semantic result parameters.
369
- // FIXME: Constraint generic yields when we will start supporting them
370
- auto resultParamIndex = resultIdx - originalFnTy->getNumResults ();
371
- auto resultParamIt = std::next (
372
- originalFnTy->getAutoDiffSemanticResultsParameters ().begin (),
373
- resultParamIndex);
374
- auto paramIndex =
375
- std::distance (originalFnTy->getParameters ().begin (), &*resultParamIt);
376
- addRequirement (originalFnTy->getParameters ()[paramIndex].getInterfaceType ());
377
385
}
378
386
379
387
return buildGenericSignature (ctx, derivativeGenSig,
0 commit comments