Skip to content

[mlir][Transforms] Support 1:N mappings in ConversionValueMapping #116524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
addConversion([&](TypeDescType ty) {
return TypeDescType::get(convertType(ty.getOfTy()));
});
addArgumentMaterialization(materializeProcedure);
addSourceMaterialization(materializeProcedure);
addTargetMaterialization(materializeProcedure);
}
Expand Down
35 changes: 5 additions & 30 deletions mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure
type safety during the conversion process. There are several types of
materializations depending on the situation.

* Argument Materialization

- An argument materialization is used when converting the type of a block
argument during a [signature conversion](#region-signature-conversion).
The new block argument types are specified in a `SignatureConversion`
object. An original block argument can be converted into multiple
block arguments, which is not supported everywhere in the dialect
conversion. (E.g., adaptors support only a single replacement value for
each original value.) Therefore, an argument materialization is used to
convert potentially multiple new block arguments back into a single SSA
value. An argument materialization is also used when replacing an op
result with multiple values.

* Source Materialization

- A source materialization is used when a value was replaced with a value
Expand Down Expand Up @@ -343,17 +330,6 @@ class TypeConverter {
/// Materialization functions must be provided when a type conversion may
/// persist after the conversion has finished.

/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// with the old argument type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
argumentMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
}

/// This method registers a materialization that will be called when
/// converting a replacement value back to its original source type.
/// This is used when some uses of the original value persist beyond the main
Expand Down Expand Up @@ -406,12 +382,11 @@ done explicitly via a conversion pattern.
To convert the types of block arguments within a Region, a custom hook on the
`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
uses a provided type converter to apply type conversions to all blocks of a
given region. As noted above, the conversions performed by this method use the
argument materialization hook on the `TypeConverter`. This hook also takes an
optional `TypeConverter::SignatureConversion` parameter that applies a custom
conversion to the entry block of the region. The types of the entry block
arguments are often tied semantically to the operation, e.g.,
`func::FuncOp`, `AffineForOp`, etc.
given region. This hook also takes an optional
`TypeConverter::SignatureConversion` parameter that applies a custom conversion
to the entry block of the region. The types of the entry block arguments are
often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`,
etc.

To convert the signature of just one given block, the
`applySignatureConversion` hook can be used.
Expand Down
18 changes: 7 additions & 11 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ class TypeConverter {
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
///
/// Note: Argument materializations are used only with the 1:N dialect
/// conversion driver. The 1:N dialect conversion driver will be removed soon
/// and so will be argument materializations.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
Expand Down Expand Up @@ -880,15 +884,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
void replaceOp(Operation *op, Operation *newOp) override;

/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. If an original SSA value is replaced
/// by multiple SSA values (i.e., a value range has more than 1 element), the
/// conversion driver will insert an argument materialization to convert the
/// N SSA values back into 1 SSA value of the original type. The given
/// operation is erased.
///
/// Note: The argument materialization is a workaround until we have full 1:N
/// support in the dialect conversion. (It is going to disappear from both
/// `replaceOpWithMultiple` and `applySignatureConversion`.)
/// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);

/// PatternRewriter hook for erasing a dead operation. The uses of this
Expand Down Expand Up @@ -1285,8 +1281,8 @@ struct ConversionConfig {
// represented at the moment.
RewriterBase::Listener *listener = nullptr;

/// If set to "true", the dialect conversion attempts to build source/target/
/// argument materializations through the type converter API in lieu of
/// If set to "true", the dialect conversion attempts to build source/target
/// materializations through the type converter API in lieu of
/// "builtin.unrealized_conversion_cast ops". The conversion process fails if
/// at least one materialization could not be built.
///
Expand Down
16 changes: 3 additions & 13 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc,
const LLVMTypeConverter &converter) {
// An argument materialization must return a value of type
// A source materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
Value packed =
Expand All @@ -101,7 +101,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder,
MemRefType resultType,
ValueRange inputs, Location loc,
const LLVMTypeConverter &converter) {
// An argument materialization must return a value of type `resultType`,
// A source materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
Value packed =
Expand Down Expand Up @@ -234,19 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
.getResult(0);
});

// Argument materializations convert from the new block argument types
// Source materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type.
addArgumentMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
*this);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
});
addSourceMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType, ValueRange inputs,
Location loc) {
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {

converter.addSourceMaterialization(materializeAsUnrealizedCast);
converter.addTargetMaterialization(materializeAsUnrealizedCast);
converter.addArgumentMaterialization(materializeAsUnrealizedCast);
}

/// Get an unsigned integer or size data type corresponding to \p ty.
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
});

addSourceMaterialization(sourceMaterializationCallback);
addArgumentMaterialization(sourceMaterializationCallback);
}
};

Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
addConversion(convertQuantizedType);
addConversion(convertTensorType);

addArgumentMaterialization(materializeConversion);
addSourceMaterialization(materializeConversion);
addTargetMaterialization(materializeConversion);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {

// Required by scf.for 1:N type conversion.
addSourceMaterialization(materializeTuple);

// Required as a workaround until we have full 1:N support.
addArgumentMaterialization(materializeTuple);
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(

return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
};
typeConverter.addArgumentMaterialization(materializeCast);
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
Expand Down
Loading
Loading