@@ -438,9 +438,11 @@ class VJPCloner::Implementation final
438
438
LLVM_DEBUG (getADDebugStream () << " VJP-transforming:\n " << *eai << ' \n ' );
439
439
440
440
FullApplySite fai (token->getDefiningInstruction ());
441
- auto vjpResult = getBuilder () .createEndApply (loc, token, fai.getType ());
441
+ auto vjpResult = builder .createEndApply (loc, token, fai.getType ());
442
442
LLVM_DEBUG (getADDebugStream () << " Created end_apply\n " << *vjpResult);
443
443
444
+ builder.emitDestroyValueOperation (loc, fai.getCallee ());
445
+
444
446
// Checkpoint the pullback.
445
447
SmallVector<SILValue, 8 > vjpDirectResults;
446
448
extractAllElements (vjpResult, getBuilder (), vjpDirectResults);
@@ -603,7 +605,7 @@ class VJPCloner::Implementation final
603
605
auto *vjpCall = getBuilder ().createBeginApply (loc, vjpValue, SubstitutionMap (),
604
606
vjpArgs, bai->getApplyOptions ());
605
607
LLVM_DEBUG (getADDebugStream () << " Applied vjp function\n " << *vjpCall);
606
- builder. emitDestroyValueOperation (loc, vjpValue);
608
+ // Note that vjpValue is destroyed after end_apply
607
609
608
610
// Store all the results (yields and token) to the value map.
609
611
assert (bai->getNumResults () == vjpCall->getNumResults ());
0 commit comments