@@ -608,51 +608,75 @@ class AdjointGenerator
608
608
I.getOpcode () == CastInst::CastOps::PtrToInt)
609
609
return ;
610
610
611
- if (Mode == DerivativeMode::ReverseModePrimal)
611
+ switch (Mode) {
612
+ case DerivativeMode::ReverseModePrimal: {
612
613
return ;
614
+ }
615
+ case DerivativeMode::ReverseModeGradient:
616
+ case DerivativeMode::ReverseModeCombined: {
617
+ Value *orig_op0 = I.getOperand (0 );
618
+ Value *op0 = gutils->getNewFromOriginal (orig_op0);
613
619
614
- Value *orig_op0 = I. getOperand ( 0 );
615
- Value *op0 = gutils-> getNewFromOriginal (orig_op0 );
620
+ IRBuilder<> Builder2 (I. getParent () );
621
+ getReverseBuilder (Builder2 );
616
622
617
- IRBuilder<> Builder2 (I. getParent ());
618
- getReverseBuilder ( Builder2);
623
+ if (!gutils-> isConstantValue (orig_op0)) {
624
+ Value *dif = diffe (&I, Builder2);
619
625
620
- if (!gutils->isConstantValue (orig_op0)) {
621
- Value *dif = diffe (&I, Builder2);
622
-
623
- size_t size = 1 ;
624
- if (orig_op0->getType ()->isSized ())
625
- size = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
626
- orig_op0->getType ()) +
627
- 7 ) /
628
- 8 ;
629
- Type *FT = TR.addingType (size, orig_op0);
630
- if (!FT) {
631
- llvm::errs () << " " << *gutils->oldFunc << " \n " ;
632
- TR.dump ();
633
- llvm::errs () << " " << *orig_op0 << " \n " ;
626
+ size_t size = 1 ;
627
+ if (orig_op0->getType ()->isSized ())
628
+ size =
629
+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
630
+ orig_op0->getType ()) +
631
+ 7 ) /
632
+ 8 ;
633
+ Type *FT = TR.addingType (size, orig_op0);
634
+ if (!FT) {
635
+ llvm::errs () << " " << *gutils->oldFunc << " \n " ;
636
+ TR.dump ();
637
+ llvm::errs () << " " << *orig_op0 << " \n " ;
638
+ }
639
+ assert (FT);
640
+ if (I.getOpcode () == CastInst::CastOps::FPTrunc ||
641
+ I.getOpcode () == CastInst::CastOps::FPExt) {
642
+ addToDiffe (orig_op0, Builder2.CreateFPCast (dif, op0->getType ()),
643
+ Builder2, FT);
644
+ } else if (I.getOpcode () == CastInst::CastOps::BitCast) {
645
+ addToDiffe (orig_op0, Builder2.CreateBitCast (dif, op0->getType ()),
646
+ Builder2, FT);
647
+ } else if (I.getOpcode () == CastInst::CastOps::Trunc) {
648
+ // TODO CHECK THIS
649
+ auto trunced = Builder2.CreateZExt (dif, op0->getType ());
650
+ addToDiffe (orig_op0, trunced, Builder2, FT);
651
+ } else {
652
+ TR.dump ();
653
+ llvm::errs () << *I.getParent ()->getParent () << " \n "
654
+ << *I.getParent () << " \n " ;
655
+ llvm::errs () << " cannot handle above cast " << I << " \n " ;
656
+ report_fatal_error (" unknown instruction" );
657
+ }
634
658
}
635
- assert (FT);
636
- if (I.getOpcode () == CastInst::CastOps::FPTrunc ||
637
- I.getOpcode () == CastInst::CastOps::FPExt) {
638
- addToDiffe (orig_op0, Builder2.CreateFPCast (dif, op0->getType ()),
639
- Builder2, FT);
640
- } else if (I.getOpcode () == CastInst::CastOps::BitCast) {
641
- addToDiffe (orig_op0, Builder2.CreateBitCast (dif, op0->getType ()),
642
- Builder2, FT);
643
- } else if (I.getOpcode () == CastInst::CastOps::Trunc) {
644
- // TODO CHECK THIS
645
- auto trunced = Builder2.CreateZExt (dif, op0->getType ());
646
- addToDiffe (orig_op0, trunced, Builder2, FT);
659
+ setDiffe (&I, Constant::getNullValue (I.getType ()), Builder2);
660
+
661
+ break ;
662
+ }
663
+ case DerivativeMode::ForwardMode: {
664
+ Value *orig_op0 = I.getOperand (0 );
665
+
666
+ IRBuilder<> Builder2 (&I);
667
+ getForwardBuilder (Builder2);
668
+
669
+ if (!gutils->isConstantValue (orig_op0)) {
670
+ Value *dif = diffe (orig_op0, Builder2);
671
+ setDiffe (&I, Builder2.CreateCast (I.getOpcode (), dif, I.getType ()),
672
+ Builder2);
647
673
} else {
648
- TR.dump ();
649
- llvm::errs () << *I.getParent ()->getParent () << " \n "
650
- << *I.getParent () << " \n " ;
651
- llvm::errs () << " cannot handle above cast " << I << " \n " ;
652
- report_fatal_error (" unknown instruction" );
674
+ setDiffe (&I, Constant::getNullValue (I.getType ()), Builder2);
653
675
}
676
+
677
+ break ;
678
+ }
654
679
}
655
- setDiffe (&I, Constant::getNullValue (I.getType ()), Builder2);
656
680
}
657
681
658
682
void visitSelectInst (llvm::SelectInst &SI) {
0 commit comments