@@ -1166,45 +1166,76 @@ class AdjointGenerator
1166
1166
eraseIfUnused(IEI);
1167
1167
if (gutils->isConstantInstruction(&IEI))
1168
1168
return;
1169
- if (Mode == DerivativeMode::ReverseModePrimal)
1170
- return;
1171
1169
1172
- IRBuilder<> Builder2(IEI.getParent());
1173
- getReverseBuilder(Builder2);
1170
+ switch (Mode) {
1171
+ case DerivativeMode::ForwardMode: {
1172
+ IRBuilder<> Builder2(&IEI);
1173
+ getForwardBuilder(Builder2);
1174
1174
1175
- Value *dif1 = diffe(&IEI, Builder2);
1175
+ Value *orig_vector = IEI.getOperand(0);
1176
+ Value *orig_inserted = IEI.getOperand(1);
1177
+ Value *orig_index = IEI.getOperand(2);
1176
1178
1177
- Value *orig_op0 = IEI.getOperand(0);
1178
- Value *orig_op1 = IEI.getOperand(1);
1179
- Value *op1 = gutils->getNewFromOriginal(orig_op1);
1180
- Value *op2 = gutils->getNewFromOriginal(IEI.getOperand(2));
1179
+ Value *diff_inserted = gutils->isConstantValue(orig_inserted)
1180
+ ? ConstantFP::get(orig_inserted->getType(), 0)
1181
+ : diffe(orig_inserted, Builder2);
1181
1182
1182
- size_t size0 = 1;
1183
- if (orig_op0->getType()->isSized())
1184
- size0 = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1185
- orig_op0->getType()) +
1186
- 7) /
1187
- 8;
1188
- size_t size1 = 1;
1189
- if (orig_op1->getType()->isSized())
1190
- size1 = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1191
- orig_op1->getType()) +
1192
- 7) /
1193
- 8;
1183
+ Value *prediff =
1184
+ gutils->isConstantValue(orig_vector)
1185
+ ? diffe(orig_vector, Builder2)
1186
+ : ConstantVector::getNullValue(orig_vector->getType());
1187
+ auto dindex = Builder2.CreateInsertElement(
1188
+ prediff, diff_inserted, gutils->getNewFromOriginal(orig_index));
1189
+ setDiffe(&IEI, dindex, Builder2);
1194
1190
1195
- if (!gutils->isConstantValue(orig_op0))
1196
- addToDiffe(orig_op0,
1197
- Builder2.CreateInsertElement(
1198
- dif1, Constant::getNullValue(op1->getType()),
1199
- lookup(op2, Builder2)),
1200
- Builder2, TR.addingType(size0, orig_op0) );
1191
+ return;
1192
+ }
1193
+ case DerivativeMode::ReverseModeGradient:
1194
+ case DerivativeMode::ReverseModeCombined: {
1195
+ IRBuilder<> Builder2(IEI.getParent());
1196
+ getReverseBuilder(Builder2 );
1201
1197
1202
- if (!gutils->isConstantValue(orig_op1))
1203
- addToDiffe(orig_op1,
1204
- Builder2.CreateExtractElement(dif1, lookup(op2, Builder2)),
1205
- Builder2, TR.addingType(size1, orig_op1));
1198
+ Value *dif1 = diffe(&IEI, Builder2);
1199
+
1200
+ Value *orig_op0 = IEI.getOperand(0);
1201
+ Value *orig_op1 = IEI.getOperand(1);
1202
+ Value *op1 = gutils->getNewFromOriginal(orig_op1);
1203
+ Value *op2 = gutils->getNewFromOriginal(IEI.getOperand(2));
1206
1204
1207
- setDiffe(&IEI, Constant::getNullValue(IEI.getType()), Builder2);
1205
+ size_t size0 = 1;
1206
+ if (orig_op0->getType()->isSized())
1207
+ size0 =
1208
+ (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1209
+ orig_op0->getType()) +
1210
+ 7) /
1211
+ 8;
1212
+ size_t size1 = 1;
1213
+ if (orig_op1->getType()->isSized())
1214
+ size1 =
1215
+ (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1216
+ orig_op1->getType()) +
1217
+ 7) /
1218
+ 8;
1219
+
1220
+ if (!gutils->isConstantValue(orig_op0))
1221
+ addToDiffe(orig_op0,
1222
+ Builder2.CreateInsertElement(
1223
+ dif1, Constant::getNullValue(op1->getType()),
1224
+ lookup(op2, Builder2)),
1225
+ Builder2, TR.addingType(size0, orig_op0));
1226
+
1227
+ if (!gutils->isConstantValue(orig_op1))
1228
+ addToDiffe(orig_op1,
1229
+ Builder2.CreateExtractElement(dif1, lookup(op2, Builder2)),
1230
+ Builder2, TR.addingType(size1, orig_op1));
1231
+
1232
+ setDiffe(&IEI, Constant::getNullValue(IEI.getType()), Builder2);
1233
+ return;
1234
+ }
1235
+ case DerivativeMode::ReverseModePrimal: {
1236
+ return;
1237
+ }
1238
+ }
1208
1239
}
1209
1240
1210
1241
void visitShuffleVectorInst(llvm::ShuffleVectorInst &SVI) {
0 commit comments