Skip to content

Commit 971f187

Browse files
committed
Correctly consume vjp value
1 parent 29b04b9 commit 971f187

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,11 @@ class VJPCloner::Implementation final
438438
LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *eai << '\n');
439439

440440
FullApplySite fai(token->getDefiningInstruction());
441-
auto vjpResult = getBuilder().createEndApply(loc, token, fai.getType());
441+
auto vjpResult = builder.createEndApply(loc, token, fai.getType());
442442
LLVM_DEBUG(getADDebugStream() << "Created end_apply\n" << *vjpResult);
443443

444+
builder.emitDestroyValueOperation(loc, fai.getCallee());
445+
444446
// Checkpoint the pullback.
445447
SmallVector<SILValue, 8> vjpDirectResults;
446448
extractAllElements(vjpResult, getBuilder(), vjpDirectResults);
@@ -603,7 +605,7 @@ class VJPCloner::Implementation final
603605
auto *vjpCall = getBuilder().createBeginApply(loc, vjpValue, SubstitutionMap(),
604606
vjpArgs, bai->getApplyOptions());
605607
LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall);
606-
builder.emitDestroyValueOperation(loc, vjpValue);
608+
// Note that vjpValue is destroyed after end_apply
607609

608610
// Store all the results (yields and token) to the value map.
609611
assert(bai->getNumResults() == vjpCall->getNumResults());

0 commit comments

Comments
 (0)