@@ -676,7 +676,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
676
676
UnresolvedMaterializationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
677
677
UnrealizedConversionCastOp op,
678
678
const TypeConverter *converter,
679
- MaterializationKind kind, Type originalType);
679
+ MaterializationKind kind, Type originalType,
680
+ Value mappedValue);
680
681
681
682
static bool classof (const IRRewrite *rewrite) {
682
683
return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -710,6 +711,10 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
710
711
// / The original type of the SSA value. Only used for target
711
712
// / materializations.
712
713
Type originalType;
714
+
715
+ // / The value in the conversion value mapping that is being replaced by the
716
+ // / results of this unresolved materialization.
717
+ Value mappedValue;
713
718
};
714
719
} // namespace
715
720
@@ -814,10 +819,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
814
819
815
820
// / Build an unresolved materialization operation given an output type and set
816
821
// / of input operands.
822
+ // /
823
+ // / If `valueToMap` is set to a non-null Value, then that value is mapped to
824
+ // / the result of the unresolved materialization in the conversion value
825
+ // / mapping.
817
826
Value buildUnresolvedMaterialization (MaterializationKind kind,
818
827
OpBuilder::InsertPoint ip, Location loc,
819
- ValueRange inputs, Type outputType ,
820
- Type originalType,
828
+ Value valueToMap, ValueRange inputs ,
829
+ Type outputType, Type originalType,
821
830
const TypeConverter *converter);
822
831
823
832
// / Build an N:1 materialization for the given original value that was
@@ -1068,19 +1077,19 @@ void CreateOperationRewrite::rollback() {
1068
1077
1069
1078
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite (
1070
1079
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1071
- const TypeConverter *converter, MaterializationKind kind, Type originalType)
1080
+ const TypeConverter *converter, MaterializationKind kind, Type originalType,
1081
+ Value mappedValue)
1072
1082
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1073
- converterAndKind(converter, kind), originalType(originalType) {
1083
+ converterAndKind(converter, kind), originalType(originalType),
1084
+ mappedValue(mappedValue) {
1074
1085
assert ((!originalType || kind == MaterializationKind::Target) &&
1075
1086
" original type is valid only for target materializations" );
1076
1087
rewriterImpl.unresolvedMaterializations [op] = this ;
1077
1088
}
1078
1089
1079
1090
void UnresolvedMaterializationRewrite::rollback () {
1080
- if (getMaterializationKind () == MaterializationKind::Target) {
1081
- for (Value input : op->getOperands ())
1082
- rewriterImpl.mapping .erase (input);
1083
- }
1091
+ if (mappedValue)
1092
+ rewriterImpl.mapping .erase (mappedValue);
1084
1093
rewriterImpl.unresolvedMaterializations .erase (getOperation ());
1085
1094
op->erase ();
1086
1095
}
@@ -1176,10 +1185,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
1176
1185
// source materialization was created yet.
1177
1186
Value castValue = buildUnresolvedMaterialization (
1178
1187
MaterializationKind::Target, computeInsertPoint (newOperand),
1179
- operandLoc,
1180
- /* inputs=*/ newOperand, /* outputType=*/ desiredType,
1181
- /* originalType=*/ origType, currentTypeConverter);
1182
- mapping.map (newOperand, castValue);
1188
+ operandLoc, /* valueToMap=*/ newOperand, /* inputs=*/ newOperand,
1189
+ /* outputType=*/ desiredType, /* originalType=*/ origType,
1190
+ currentTypeConverter);
1183
1191
newOperand = castValue;
1184
1192
}
1185
1193
remapped.push_back (newOperand);
@@ -1293,12 +1301,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1293
1301
if (!inputMap) {
1294
1302
// This block argument was dropped and no replacement value was provided.
1295
1303
// Materialize a replacement value "out of thin air".
1296
- Value repl = buildUnresolvedMaterialization (
1304
+ buildUnresolvedMaterialization (
1297
1305
MaterializationKind::Source,
1298
1306
OpBuilder::InsertPoint (newBlock, newBlock->begin ()), origArg.getLoc (),
1299
- /* inputs=*/ ValueRange (),
1307
+ /* valueToMap= */ origArg, /* inputs=*/ ValueRange (),
1300
1308
/* outputType=*/ origArgType, /* originalType=*/ Type (), converter);
1301
- mapping.map (origArg, repl);
1302
1309
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1303
1310
continue ;
1304
1311
}
@@ -1342,23 +1349,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1342
1349
// / of input operands.
1343
1350
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization (
1344
1351
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1345
- ValueRange inputs, Type outputType, Type originalType,
1352
+ Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
1346
1353
const TypeConverter *converter) {
1347
1354
assert ((!originalType || kind == MaterializationKind::Target) &&
1348
1355
" original type is valid only for target materializations" );
1349
1356
1350
1357
// Avoid materializing an unnecessary cast.
1351
- if (inputs.size () == 1 && inputs.front ().getType () == outputType)
1358
+ if (inputs.size () == 1 && inputs.front ().getType () == outputType) {
1359
+ if (valueToMap)
1360
+ mapping.map (valueToMap, inputs.front ());
1352
1361
return inputs.front ();
1362
+ }
1353
1363
1354
1364
// Create an unresolved materialization. We use a new OpBuilder to avoid
1355
1365
// tracking the materialization like we do for other operations.
1356
1366
OpBuilder builder (outputType.getContext ());
1357
1367
builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
1358
1368
auto convertOp =
1359
1369
builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1370
+ if (valueToMap)
1371
+ mapping.map (valueToMap, convertOp.getResult (0 ));
1360
1372
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1361
- originalType);
1373
+ originalType, valueToMap );
1362
1374
return convertOp.getResult (0 );
1363
1375
}
1364
1376
@@ -1367,11 +1379,10 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
1367
1379
Value originalValue, const TypeConverter *converter) {
1368
1380
// Insert argument materialization back to the original type.
1369
1381
Type originalType = originalValue.getType ();
1370
- Value argMat =
1371
- buildUnresolvedMaterialization (MaterializationKind::Argument, ip, loc,
1372
- /* inputs=*/ replacements, originalType,
1373
- /* originalType=*/ Type (), converter);
1374
- mapping.map (originalValue, argMat);
1382
+ Value argMat = buildUnresolvedMaterialization (
1383
+ MaterializationKind::Argument, ip, loc, /* valueToMap=*/ originalValue,
1384
+ /* inputs=*/ replacements, originalType, /* originalType=*/ Type (),
1385
+ converter);
1375
1386
1376
1387
// Insert target materialization to the legalized type.
1377
1388
Type legalOutputType;
@@ -1387,11 +1398,11 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
1387
1398
legalOutputType = replacements[0 ].getType ();
1388
1399
}
1389
1400
if (legalOutputType && legalOutputType != originalType) {
1390
- Value targetMat = buildUnresolvedMaterialization (
1391
- MaterializationKind::Target, computeInsertPoint (argMat), loc,
1392
- /* inputs =*/ argMat, /* outputType =*/ legalOutputType ,
1393
- /* originalType =*/ originalType, converter);
1394
- mapping. map (argMat, targetMat );
1401
+ buildUnresolvedMaterialization (MaterializationKind::Target,
1402
+ computeInsertPoint (argMat), loc,
1403
+ /* valueToMap =*/ argMat, /* inputs =*/ argMat ,
1404
+ /* outputType =*/ legalOutputType,
1405
+ /* originalType= */ originalType, converter );
1395
1406
}
1396
1407
}
1397
1408
@@ -1425,9 +1436,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
1425
1436
}
1426
1437
Value castValue = buildUnresolvedMaterialization (
1427
1438
MaterializationKind::Source, computeInsertPoint (repl), value.getLoc (),
1428
- /* inputs=*/ repl, /* outputType=*/ value.getType (),
1439
+ /* valueToMap= */ value, /* inputs=*/ repl, /* outputType=*/ value.getType (),
1429
1440
/* originalType=*/ Type (), converter);
1430
- mapping.map (value, castValue);
1431
1441
return castValue;
1432
1442
}
1433
1443
@@ -1480,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
1480
1490
// Materialize a replacement value "out of thin air".
1481
1491
Value sourceMat = buildUnresolvedMaterialization (
1482
1492
MaterializationKind::Source, computeInsertPoint (result),
1483
- result.getLoc (), /* inputs=*/ ValueRange (),
1493
+ result.getLoc (), /* valueToMap= */ Value (), /* inputs=*/ ValueRange (),
1484
1494
/* outputType=*/ result.getType (), /* originalType=*/ Type (),
1485
1495
currentTypeConverter);
1486
1496
repl.push_back (sourceMat);
0 commit comments