Skip to content

Commit 10e61f3

Browse files
authored
Forward mode cast inst (#194)
* implemented forward mode visitCastInst
1 parent d816645 commit 10e61f3

File tree

3 files changed

+105
-37
lines changed

3 files changed

+105
-37
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -608,51 +608,75 @@ class AdjointGenerator
608608
I.getOpcode() == CastInst::CastOps::PtrToInt)
609609
return;
610610

611-
if (Mode == DerivativeMode::ReverseModePrimal)
611+
switch (Mode) {
612+
case DerivativeMode::ReverseModePrimal: {
612613
return;
614+
}
615+
case DerivativeMode::ReverseModeGradient:
616+
case DerivativeMode::ReverseModeCombined: {
617+
Value *orig_op0 = I.getOperand(0);
618+
Value *op0 = gutils->getNewFromOriginal(orig_op0);
613619

614-
Value *orig_op0 = I.getOperand(0);
615-
Value *op0 = gutils->getNewFromOriginal(orig_op0);
620+
IRBuilder<> Builder2(I.getParent());
621+
getReverseBuilder(Builder2);
616622

617-
IRBuilder<> Builder2(I.getParent());
618-
getReverseBuilder(Builder2);
623+
if (!gutils->isConstantValue(orig_op0)) {
624+
Value *dif = diffe(&I, Builder2);
619625

620-
if (!gutils->isConstantValue(orig_op0)) {
621-
Value *dif = diffe(&I, Builder2);
622-
623-
size_t size = 1;
624-
if (orig_op0->getType()->isSized())
625-
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
626-
orig_op0->getType()) +
627-
7) /
628-
8;
629-
Type *FT = TR.addingType(size, orig_op0);
630-
if (!FT) {
631-
llvm::errs() << " " << *gutils->oldFunc << "\n";
632-
TR.dump();
633-
llvm::errs() << " " << *orig_op0 << "\n";
626+
size_t size = 1;
627+
if (orig_op0->getType()->isSized())
628+
size =
629+
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
630+
orig_op0->getType()) +
631+
7) /
632+
8;
633+
Type *FT = TR.addingType(size, orig_op0);
634+
if (!FT) {
635+
llvm::errs() << " " << *gutils->oldFunc << "\n";
636+
TR.dump();
637+
llvm::errs() << " " << *orig_op0 << "\n";
638+
}
639+
assert(FT);
640+
if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
641+
I.getOpcode() == CastInst::CastOps::FPExt) {
642+
addToDiffe(orig_op0, Builder2.CreateFPCast(dif, op0->getType()),
643+
Builder2, FT);
644+
} else if (I.getOpcode() == CastInst::CastOps::BitCast) {
645+
addToDiffe(orig_op0, Builder2.CreateBitCast(dif, op0->getType()),
646+
Builder2, FT);
647+
} else if (I.getOpcode() == CastInst::CastOps::Trunc) {
648+
// TODO CHECK THIS
649+
auto trunced = Builder2.CreateZExt(dif, op0->getType());
650+
addToDiffe(orig_op0, trunced, Builder2, FT);
651+
} else {
652+
TR.dump();
653+
llvm::errs() << *I.getParent()->getParent() << "\n"
654+
<< *I.getParent() << "\n";
655+
llvm::errs() << "cannot handle above cast " << I << "\n";
656+
report_fatal_error("unknown instruction");
657+
}
634658
}
635-
assert(FT);
636-
if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
637-
I.getOpcode() == CastInst::CastOps::FPExt) {
638-
addToDiffe(orig_op0, Builder2.CreateFPCast(dif, op0->getType()),
639-
Builder2, FT);
640-
} else if (I.getOpcode() == CastInst::CastOps::BitCast) {
641-
addToDiffe(orig_op0, Builder2.CreateBitCast(dif, op0->getType()),
642-
Builder2, FT);
643-
} else if (I.getOpcode() == CastInst::CastOps::Trunc) {
644-
// TODO CHECK THIS
645-
auto trunced = Builder2.CreateZExt(dif, op0->getType());
646-
addToDiffe(orig_op0, trunced, Builder2, FT);
659+
setDiffe(&I, Constant::getNullValue(I.getType()), Builder2);
660+
661+
break;
662+
}
663+
case DerivativeMode::ForwardMode: {
664+
Value *orig_op0 = I.getOperand(0);
665+
666+
IRBuilder<> Builder2(&I);
667+
getForwardBuilder(Builder2);
668+
669+
if (!gutils->isConstantValue(orig_op0)) {
670+
Value *dif = diffe(orig_op0, Builder2);
671+
setDiffe(&I, Builder2.CreateCast(I.getOpcode(), dif, I.getType()),
672+
Builder2);
647673
} else {
648-
TR.dump();
649-
llvm::errs() << *I.getParent()->getParent() << "\n"
650-
<< *I.getParent() << "\n";
651-
llvm::errs() << "cannot handle above cast " << I << "\n";
652-
report_fatal_error("unknown instruction");
674+
setDiffe(&I, Constant::getNullValue(I.getType()), Builder2);
653675
}
676+
677+
break;
678+
}
654679
}
655-
setDiffe(&I, Constant::getNullValue(I.getType()), Builder2);
656680
}
657681

658682
void visitSelectInst(llvm::SelectInst &SI) {
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -adce -instsimplify -S | FileCheck %s
2+
3+
define double @tester(double %x) {
4+
entry:
5+
%y = bitcast double %x to i64
6+
%z = bitcast i64 %y to double
7+
ret double %z
8+
}
9+
10+
define double @test_derivative(double %x) {
11+
entry:
12+
%0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0)
13+
ret double %0
14+
}
15+
16+
declare double @__enzyme_fwddiff(double (double)*, ...)
17+
18+
; CHECK: define internal {{(dso_local )?}}{ double } @diffetester(double %x, double %"x'") {
19+
; CHECK-NEXT: entry:
20+
; CHECK-NEXT: %0 = insertvalue { double } undef, double %"x'", 0
21+
; CHECK-NEXT: ret { double } %0
22+
; CHECK-NEXT: }
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -adce -instsimplify -S | FileCheck %s
2+
3+
define double @tester(float %x) {
4+
entry:
5+
%y = fpext float %x to double
6+
ret double %y
7+
}
8+
9+
define double @test_derivative(float %x) {
10+
entry:
11+
%0 = tail call double (double (float)*, ...) @__enzyme_fwddiff(double (float)* nonnull @tester, float %x, float 1.0)
12+
ret double %0
13+
}
14+
15+
declare double @__enzyme_fwddiff(double (float)*, ...)
16+
17+
; CHECK: define internal {{(dso_local )?}}{ double } @diffetester(float %x, float %"x'") {
18+
; CHECK-NEXT: entry:
19+
; CHECK-NEXT: %0 = fpext float %"x'" to double
20+
; CHECK-NEXT: %1 = insertvalue { double } undef, double %0, 0
21+
; CHECK-NEXT: ret { double } %1
22+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)