Skip to content

Commit bcfca8b

Browse files
committed
Ensure we are consuming pullback arguments on unwind path
1 parent 4a805c4 commit bcfca8b

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,33 @@ class VJPCloner::Implementation final
305305
joinElements(directResults, Builder, loc));
306306
}
307307

308+
void visitUnwindInst(UnwindInst *ui) {
309+
Builder.setCurrentDebugScope(getOpScope(ui->getDebugScope()));
310+
auto loc = ui->getLoc();
311+
auto *origExit = ui->getParent();
312+
313+
// Consume unused pullback values
314+
if (borrowedPullbackContextValue) {
315+
auto *pbTupleVal = buildPullbackValueTupleValue(ui);
316+
// Initialize the top-level subcontext buffer with the top-level pullback
317+
// tuple.
318+
auto addr = emitProjectTopLevelSubcontext(
319+
Builder, loc, borrowedPullbackContextValue, pbTupleVal->getType());
320+
Builder.createStore(
321+
loc, pbTupleVal, addr,
322+
pbTupleVal->getType().isTrivial(*pullback) ?
323+
StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init);
324+
325+
Builder.createEndBorrow(loc, borrowedPullbackContextValue);
326+
Builder.emitDestroyValueOperation(loc, pullbackContextValue);
327+
} else {
328+
for (SILValue val : getPullbackValues(origExit))
329+
Builder.emitDestroyValueOperation(loc, val);
330+
}
331+
332+
Builder.createUnwind(loc);
333+
}
334+
308335
void visitBranchInst(BranchInst *bi) {
309336
Builder.setCurrentDebugScope(getOpScope(bi->getDebugScope()));
310337
// Build pullback struct value for original block.

0 commit comments

Comments
 (0)