-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][Transforms] Merge 1:1 and 1:N type converters #113032
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Merge 1:1 and 1:N type converters #113032
Conversation
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir-sparse Author: Matthias Springer (matthias-springer) ChangesThe 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures. 1:1 target materializations (producing a single The 1:N type converter is removed. Note for LLVM integration: If you are using the Depends on #113031. Full diff: https://github.com/llvm/llvm-project/pull/113032.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
//===----------------------------------------------------------------------===//
/// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
SparseIterationTypeConverter();
};
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
/// conversion has finished.
///
/// Note: Target materializations may optionally accept an additional Type
- /// parameter, which is the original type of the SSA value.
+ /// parameter, which is the original type of the SSA value. Furthermore `T`
+ /// can be a TypeRange; in that case, the function must return a
+ /// SmallVector<Value>.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
/// will be invoked with: outputType = "t3", inputs = "v2",
// originalType = "t1". Note that the original type "t1" cannot be recovered
/// from just "t3" and "v2"; that's why the originalType parameter exists.
+ ///
+ /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+ /// that case the materialization produces a SmallVector<Value>.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
Value materializeTargetConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs,
Type originalType = {}) const;
+ SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+ Location loc,
+ TypeRange resultType,
+ ValueRange inputs,
+ Type originalType = {}) const;
/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
/// The signature of the callback used to materialize a target conversion.
///
/// Arguments: builder, result type, inputs, location, original type
- using TargetMaterializationCallbackFn =
- std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+ using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+ OpBuilder &, TypeRange, ValueRange, Location, Type)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
/// callback.
///
/// With callback of form:
- /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+ /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+ /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Type resultType, ValueRange inputs,
- Location loc, Type originalType) -> Value {
- if (T derivedType = dyn_cast<T>(resultType))
- return callback(builder, derivedType, inputs, loc, originalType);
- return Value();
+ OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc, Type originalType) -> SmallVector<Value> {
+ SmallVector<Value> result;
+ if constexpr (std::is_same<T, TypeRange>::value) {
+ // This is a 1:N target materialization. Return the produces values
+ // directly.
+ result = callback(builder, resultTypes, inputs, loc, originalType);
+ } else {
+ // This is a 1:1 target materialization. Invoke it only if the result
+ // type class of the callback matches the requested result type.
+ if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+ // 1:1 materializations produce single values, but we store 1:N
+ // target materialization functions in the type converter. Wrap the
+ // result value in a SmallVector<Value>.
+ std::optional<Value> val =
+ callback(builder, derivedType, inputs, loc, originalType);
+ if (val.has_value() && *val)
+ result.push_back(*val);
+ }
+ }
+ return result;
};
}
/// With callback of form:
- /// `Value(OpBuilder &, T, ValueRange, Location)`
+ /// - Value(OpBuilder &, T, ValueRange, Location)
+ /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
wrapTargetMaterialization(FnT &&callback) const {
return wrapTargetMaterialization<T>(
[callback = std::forward<FnT>(callback)](
- OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
- Type originalType) -> Value {
- return callback(builder, resultType, inputs, loc);
+ OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+ Type originalType) {
+ return callback(builder, resultTypes, inputs, loc);
});
}
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
namespace mlir {
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
- /// Callback that expresses user-provided materialization logic from the given
- /// value to N values of the given types. This is useful for expressing target
- /// materializations for 1:N type conversions, which materialize one value in
- /// a source type as N values in target types.
- using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
-
- /// Creates the mapping of the given range of original types to target types
- /// of the conversion and stores that mapping in the given (signature)
- /// conversion. This function simply calls
- /// `TypeConverter::convertSignatureArgs` and exists here with a different
- /// name to reflect the broader semantic.
- LogicalResult computeTypeMapping(TypeRange types,
- SignatureConversion &result) const {
- return convertSignatureArgs(types, result);
- }
-
- /// Applies one of the user-provided 1:N target materializations. If several
- /// exists, they are tried out in the reverse order in which they have been
- /// added until the first one succeeds. If none succeeds, the functions
- /// returns `std::nullopt`.
- std::optional<SmallVector<Value>>
- materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
-
- /// Adds a 1:N target materialization to the converter. Such materializations
- /// build IR that converts N values with target types into 1 value of the
- /// source type.
- void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
- oneToNTargetMaterializations.emplace_back(std::move(callback));
- }
-
-private:
- SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
/// Stores a 1:N mapping of types and provides several useful accessors. This
/// class extends `SignatureConversion`, which already supports 1:N type
/// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
/// not fail if some ops or types remain unconverted (i.e., the conversion is
/// only "partial").
LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns);
/// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
- OneToNTypeConverter converter;
+ TypeConverter converter;
RewritePatternSet patterns(context);
converter.addConversion([](Type type) { return type; });
converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs,
Type originalType) const {
+ SmallVector<Value> result = materializeTargetConversion(
+ builder, loc, TypeRange(resultType), inputs, originalType);
+ if (result.empty())
+ return nullptr;
+ assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+ "produced 1:N materialization");
+ return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+ OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+ Type originalType) const {
for (const TargetMaterializationCallbackFn &fn :
- llvm::reverse(targetMaterializations))
- if (Value result = fn(builder, resultType, inputs, loc, originalType))
- return result;
- return nullptr;
+ llvm::reverse(targetMaterializations)) {
+ SmallVector<Value> result =
+ fn(builder, resultTypes, inputs, loc, originalType);
+ if (result.empty())
+ continue;
+ return result;
+ }
+ return {};
}
std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
using namespace llvm;
using namespace mlir;
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
- Location loc,
- TypeRange resultTypes,
- Value input) const {
- for (const OneToNMaterializationCallbackFn &fn :
- llvm::reverse(oneToNTargetMaterializations)) {
- if (std::optional<SmallVector<Value>> result =
- fn(builder, resultTypes, input, loc))
- return *result;
- }
- return std::nullopt;
-}
-
TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
TypeRange convertedTypes = getConvertedTypes();
if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
LogicalResult
OneToNConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+ auto *typeConverter = getTypeConverter();
// Construct conversion mapping for results.
Operation::result_type_range originalResultTypes = op->getResultTypes();
OneToNTypeMapping resultMapping(originalResultTypes);
- if (failed(typeConverter->computeTypeMapping(originalResultTypes,
- resultMapping)))
+ if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+ resultMapping)))
return failure();
// Construct conversion mapping for operands.
Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
OneToNTypeMapping operandMapping(originalOperandTypes);
- if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
- operandMapping)))
+ if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+ operandMapping)))
return failure();
// Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
// inserted by this pass are annotated with a string attribute that also
// documents which kind of the cast (source, argument, or target).
LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns) {
#ifndef NDEBUG
// Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
// Target materialization.
assert(!areOperandTypesLegal && areResultsTypesLegal &&
operands.size() == 1 && "found unexpected target cast");
- std::optional<SmallVector<Value>> maybeResults =
- typeConverter.materializeTargetConversion(
- rewriter, castOp->getLoc(), resultTypes, operands.front());
- if (!maybeResults) {
+ materializedResults = typeConverter.materializeTargetConversion(
+ rewriter, castOp->getLoc(), resultTypes, operands.front());
+ if (materializedResults.empty()) {
emitError(castOp->getLoc())
<< "failed to create target materialization";
return failure();
}
- materializedResults = maybeResults.value();
} else {
// Source and argument materializations.
assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
const OneToNTypeMapping &resultMapping,
ValueRange convertedOperands) const override {
auto funcOp = cast<FunctionOpInterface>(op);
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+ auto *typeConverter = getTypeConverter();
// Construct mapping for function arguments.
OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
- if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
- argumentMapping)))
+ if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+ argumentMapping)))
return failure();
// Construct mapping for function results.
OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
- if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
- funcResultMapping)))
+ if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+ funcResultMapping)))
return failure();
// Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
///
/// This function has been copied (with small adaptions) from
/// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
- Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+ TypeRange resultTypes,
+ ValueRange inputs,
+ Location loc) {
+ if (inputs.size() != 1)
+ return {};
+ Value input = inputs.front();
+
TupleType inputType = dyn_cast<TupleType>(input.getType());
if (!inputType)
return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
auto *context = &getContext();
// Assemble type converter.
- OneToNTypeConverter typeConverter;
+ TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+ // Test the other target materialization variant that takes the original type
+ // as additional argument. This materialization function always fails.
+ typeConverter.addTargetMaterialization(
+ [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc, Type originalType) -> SmallVector<Value> { return {}; });
// Assemble patterns.
RewritePatternSet patterns(context);
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures. 1:1 target materializations (producing a single The 1:N type converter is removed. Note for LLVM integration: If you are using the Depends on #113031. Full diff: https://github.com/llvm/llvm-project/pull/113032.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
//===----------------------------------------------------------------------===//
/// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
SparseIterationTypeConverter();
};
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
/// conversion has finished.
///
/// Note: Target materializations may optionally accept an additional Type
- /// parameter, which is the original type of the SSA value.
+ /// parameter, which is the original type of the SSA value. Furthermore `T`
+ /// can be a TypeRange; in that case, the function must return a
+ /// SmallVector<Value>.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
/// will be invoked with: outputType = "t3", inputs = "v2",
// originalType = "t1". Note that the original type "t1" cannot be recovered
/// from just "t3" and "v2"; that's why the originalType parameter exists.
+ ///
+ /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+ /// that case the materialization produces a SmallVector<Value>.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
Value materializeTargetConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs,
Type originalType = {}) const;
+ SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+ Location loc,
+ TypeRange resultType,
+ ValueRange inputs,
+ Type originalType = {}) const;
/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
/// The signature of the callback used to materialize a target conversion.
///
/// Arguments: builder, result type, inputs, location, original type
- using TargetMaterializationCallbackFn =
- std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+ using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+ OpBuilder &, TypeRange, ValueRange, Location, Type)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
/// callback.
///
/// With callback of form:
- /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+ /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+ /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Type resultType, ValueRange inputs,
- Location loc, Type originalType) -> Value {
- if (T derivedType = dyn_cast<T>(resultType))
- return callback(builder, derivedType, inputs, loc, originalType);
- return Value();
+ OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc, Type originalType) -> SmallVector<Value> {
+ SmallVector<Value> result;
+ if constexpr (std::is_same<T, TypeRange>::value) {
+ // This is a 1:N target materialization. Return the produces values
+ // directly.
+ result = callback(builder, resultTypes, inputs, loc, originalType);
+ } else {
+ // This is a 1:1 target materialization. Invoke it only if the result
+ // type class of the callback matches the requested result type.
+ if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+ // 1:1 materializations produce single values, but we store 1:N
+ // target materialization functions in the type converter. Wrap the
+ // result value in a SmallVector<Value>.
+ std::optional<Value> val =
+ callback(builder, derivedType, inputs, loc, originalType);
+ if (val.has_value() && *val)
+ result.push_back(*val);
+ }
+ }
+ return result;
};
}
/// With callback of form:
- /// `Value(OpBuilder &, T, ValueRange, Location)`
+ /// - Value(OpBuilder &, T, ValueRange, Location)
+ /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
wrapTargetMaterialization(FnT &&callback) const {
return wrapTargetMaterialization<T>(
[callback = std::forward<FnT>(callback)](
- OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
- Type originalType) -> Value {
- return callback(builder, resultType, inputs, loc);
+ OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+ Type originalType) {
+ return callback(builder, resultTypes, inputs, loc);
});
}
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
namespace mlir {
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
- /// Callback that expresses user-provided materialization logic from the given
- /// value to N values of the given types. This is useful for expressing target
- /// materializations for 1:N type conversions, which materialize one value in
- /// a source type as N values in target types.
- using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
-
- /// Creates the mapping of the given range of original types to target types
- /// of the conversion and stores that mapping in the given (signature)
- /// conversion. This function simply calls
- /// `TypeConverter::convertSignatureArgs` and exists here with a different
- /// name to reflect the broader semantic.
- LogicalResult computeTypeMapping(TypeRange types,
- SignatureConversion &result) const {
- return convertSignatureArgs(types, result);
- }
-
- /// Applies one of the user-provided 1:N target materializations. If several
- /// exists, they are tried out in the reverse order in which they have been
- /// added until the first one succeeds. If none succeeds, the functions
- /// returns `std::nullopt`.
- std::optional<SmallVector<Value>>
- materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
-
- /// Adds a 1:N target materialization to the converter. Such materializations
- /// build IR that converts N values with target types into 1 value of the
- /// source type.
- void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
- oneToNTargetMaterializations.emplace_back(std::move(callback));
- }
-
-private:
- SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
/// Stores a 1:N mapping of types and provides several useful accessors. This
/// class extends `SignatureConversion`, which already supports 1:N type
/// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
/// not fail if some ops or types remain unconverted (i.e., the conversion is
/// only "partial").
LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns);
/// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
- OneToNTypeConverter converter;
+ TypeConverter converter;
RewritePatternSet patterns(context);
converter.addConversion([](Type type) { return type; });
converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs,
Type originalType) const {
+ SmallVector<Value> result = materializeTargetConversion(
+ builder, loc, TypeRange(resultType), inputs, originalType);
+ if (result.empty())
+ return nullptr;
+ assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+ "produced 1:N materialization");
+ return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+ OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+ Type originalType) const {
for (const TargetMaterializationCallbackFn &fn :
- llvm::reverse(targetMaterializations))
- if (Value result = fn(builder, resultType, inputs, loc, originalType))
- return result;
- return nullptr;
+ llvm::reverse(targetMaterializations)) {
+ SmallVector<Value> result =
+ fn(builder, resultTypes, inputs, loc, originalType);
+ if (result.empty())
+ continue;
+ return result;
+ }
+ return {};
}
std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
using namespace llvm;
using namespace mlir;
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
- Location loc,
- TypeRange resultTypes,
- Value input) const {
- for (const OneToNMaterializationCallbackFn &fn :
- llvm::reverse(oneToNTargetMaterializations)) {
- if (std::optional<SmallVector<Value>> result =
- fn(builder, resultTypes, input, loc))
- return *result;
- }
- return std::nullopt;
-}
-
TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
TypeRange convertedTypes = getConvertedTypes();
if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
LogicalResult
OneToNConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+ auto *typeConverter = getTypeConverter();
// Construct conversion mapping for results.
Operation::result_type_range originalResultTypes = op->getResultTypes();
OneToNTypeMapping resultMapping(originalResultTypes);
- if (failed(typeConverter->computeTypeMapping(originalResultTypes,
- resultMapping)))
+ if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+ resultMapping)))
return failure();
// Construct conversion mapping for operands.
Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
OneToNTypeMapping operandMapping(originalOperandTypes);
- if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
- operandMapping)))
+ if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+ operandMapping)))
return failure();
// Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
// inserted by this pass are annotated with a string attribute that also
// documents which kind of the cast (source, argument, or target).
LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns) {
#ifndef NDEBUG
// Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
// Target materialization.
assert(!areOperandTypesLegal && areResultsTypesLegal &&
operands.size() == 1 && "found unexpected target cast");
- std::optional<SmallVector<Value>> maybeResults =
- typeConverter.materializeTargetConversion(
- rewriter, castOp->getLoc(), resultTypes, operands.front());
- if (!maybeResults) {
+ materializedResults = typeConverter.materializeTargetConversion(
+ rewriter, castOp->getLoc(), resultTypes, operands.front());
+ if (materializedResults.empty()) {
emitError(castOp->getLoc())
<< "failed to create target materialization";
return failure();
}
- materializedResults = maybeResults.value();
} else {
// Source and argument materializations.
assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
const OneToNTypeMapping &resultMapping,
ValueRange convertedOperands) const override {
auto funcOp = cast<FunctionOpInterface>(op);
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+ auto *typeConverter = getTypeConverter();
// Construct mapping for function arguments.
OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
- if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
- argumentMapping)))
+ if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+ argumentMapping)))
return failure();
// Construct mapping for function results.
OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
- if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
- funcResultMapping)))
+ if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+ funcResultMapping)))
return failure();
// Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
///
/// This function has been copied (with small adaptions) from
/// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
- Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+ TypeRange resultTypes,
+ ValueRange inputs,
+ Location loc) {
+ if (inputs.size() != 1)
+ return {};
+ Value input = inputs.front();
+
TupleType inputType = dyn_cast<TupleType>(input.getType());
if (!inputType)
return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
auto *context = &getContext();
// Assemble type converter.
- OneToNTypeConverter typeConverter;
+ TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+ // Test the other target materialization variant that takes the original type
+ // as additional argument. This materialization function always fails.
+ typeConverter.addTargetMaterialization(
+ [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc, Type originalType) -> SmallVector<Value> { return {}; });
// Assemble patterns.
RewritePatternSet patterns(context);
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThe 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures. 1:1 target materializations (producing a single The 1:N type converter is removed. Note for LLVM integration: If you are using the Depends on #113031. Full diff: https://github.com/llvm/llvm-project/pull/113032.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
//===----------------------------------------------------------------------===//
/// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
SparseIterationTypeConverter();
};
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..eb7da67c1bb995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
/// conversion has finished.
///
/// Note: Target materializations may optionally accept an additional Type
- /// parameter, which is the original type of the SSA value.
+ /// parameter, which is the original type of the SSA value. Furthermore `T`
+ /// can be a TypeRange; in that case, the function must return a
+ /// SmallVector<Value>.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
/// will be invoked with: outputType = "t3", inputs = "v2",
// originalType = "t1". Note that the original type "t1" cannot be recovered
/// from just "t3" and "v2"; that's why the originalType parameter exists.
+ ///
+ /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+ /// that case the materialization produces a SmallVector<Value>.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
Value materializeTargetConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs,
Type originalType = {}) const;
+ SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+ Location loc,
+ TypeRange resultType,
+ ValueRange inputs,
+ Type originalType = {}) const;
/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
@@ -341,8 +351,8 @@ class TypeConverter {
/// The signature of the callback used to materialize a target conversion.
///
/// Arguments: builder, result type, inputs, location, original type
- using TargetMaterializationCallbackFn =
- std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+ using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+ OpBuilder &, TypeRange, ValueRange, Location, Type)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
/// callback.
///
/// With callback of form:
- /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+ /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+ /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Type resultType, ValueRange inputs,
- Location loc, Type originalType) -> Value {
- if (T derivedType = dyn_cast<T>(resultType))
- return callback(builder, derivedType, inputs, loc, originalType);
- return Value();
+ OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc, Type originalType) -> SmallVector<Value> {
+ SmallVector<Value> result;
+ if constexpr (std::is_same<T, TypeRange>::value) {
+ // This is a 1:N target materialization. Return the produces values
+ // directly.
+ result = callback(builder, resultTypes, inputs, loc, originalType);
+ } else {
+ // This is a 1:1 target materialization. Invoke it only if the result
+ // type class of the callback matches the requested result type.
+ if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+ // 1:1 materializations produce single values, but we store 1:N
+ // target materialization functions in the type converter. Wrap the
+ // result value in a SmallVector<Value>.
+ std::optional<Value> val =
+ callback(builder, derivedType, inputs, loc, originalType);
+ if (val.has_value() && *val)
+ result.push_back(*val);
+ }
+ }
+ return result;
};
}
/// With callback of form:
- /// `Value(OpBuilder &, T, ValueRange, Location)`
+ /// - Value(OpBuilder &, T, ValueRange, Location)
+ /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
wrapTargetMaterialization(FnT &&callback) const {
return wrapTargetMaterialization<T>(
[callback = std::forward<FnT>(callback)](
- OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
- Type originalType) -> Value {
- return callback(builder, resultType, inputs, loc);
+ OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+ Type originalType) {
+ return callback(builder, resultTypes, inputs, loc);
});
}
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
namespace mlir {
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
- /// Callback that expresses user-provided materialization logic from the given
- /// value to N values of the given types. This is useful for expressing target
- /// materializations for 1:N type conversions, which materialize one value in
- /// a source type as N values in target types.
- using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
-
- /// Creates the mapping of the given range of original types to target types
- /// of the conversion and stores that mapping in the given (signature)
- /// conversion. This function simply calls
- /// `TypeConverter::convertSignatureArgs` and exists here with a different
- /// name to reflect the broader semantic.
- LogicalResult computeTypeMapping(TypeRange types,
- SignatureConversion &result) const {
- return convertSignatureArgs(types, result);
- }
-
- /// Applies one of the user-provided 1:N target materializations. If several
- /// exists, they are tried out in the reverse order in which they have been
- /// added until the first one succeeds. If none succeeds, the functions
- /// returns `std::nullopt`.
- std::optional<SmallVector<Value>>
- materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
-
- /// Adds a 1:N target materialization to the converter. Such materializations
- /// build IR that converts N values with target types into 1 value of the
- /// source type.
- void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
- oneToNTargetMaterializations.emplace_back(std::move(callback));
- }
-
-private:
- SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
/// Stores a 1:N mapping of types and provides several useful accessors. This
/// class extends `SignatureConversion`, which already supports 1:N type
/// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
/// not fail if some ops or types remain unconverted (i.e., the conversion is
/// only "partial").
LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns);
/// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
- OneToNTypeConverter converter;
+ TypeConverter converter;
RewritePatternSet patterns(context);
converter.addConversion([](Type type) { return type; });
converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs,
Type originalType) const {
+ SmallVector<Value> result = materializeTargetConversion(
+ builder, loc, TypeRange(resultType), inputs, originalType);
+ if (result.empty())
+ return nullptr;
+ assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+ "produced 1:N materialization");
+ return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+ OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+ Type originalType) const {
for (const TargetMaterializationCallbackFn &fn :
- llvm::reverse(targetMaterializations))
- if (Value result = fn(builder, resultType, inputs, loc, originalType))
- return result;
- return nullptr;
+ llvm::reverse(targetMaterializations)) {
+ SmallVector<Value> result =
+ fn(builder, resultTypes, inputs, loc, originalType);
+ if (result.empty())
+ continue;
+ return result;
+ }
+ return {};
}
std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
using namespace llvm;
using namespace mlir;
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
- Location loc,
- TypeRange resultTypes,
- Value input) const {
- for (const OneToNMaterializationCallbackFn &fn :
- llvm::reverse(oneToNTargetMaterializations)) {
- if (std::optional<SmallVector<Value>> result =
- fn(builder, resultTypes, input, loc))
- return *result;
- }
- return std::nullopt;
-}
-
TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
TypeRange convertedTypes = getConvertedTypes();
if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
LogicalResult
OneToNConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+ auto *typeConverter = getTypeConverter();
// Construct conversion mapping for results.
Operation::result_type_range originalResultTypes = op->getResultTypes();
OneToNTypeMapping resultMapping(originalResultTypes);
- if (failed(typeConverter->computeTypeMapping(originalResultTypes,
- resultMapping)))
+ if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+ resultMapping)))
return failure();
// Construct conversion mapping for operands.
Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
OneToNTypeMapping operandMapping(originalOperandTypes);
- if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
- operandMapping)))
+ if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+ operandMapping)))
return failure();
// Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
// inserted by this pass are annotated with a string attribute that also
// documents which kind of the cast (source, argument, or target).
LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns) {
#ifndef NDEBUG
// Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
// Target materialization.
assert(!areOperandTypesLegal && areResultsTypesLegal &&
operands.size() == 1 && "found unexpected target cast");
- std::optional<SmallVector<Value>> maybeResults =
- typeConverter.materializeTargetConversion(
- rewriter, castOp->getLoc(), resultTypes, operands.front());
- if (!maybeResults) {
+ materializedResults = typeConverter.materializeTargetConversion(
+ rewriter, castOp->getLoc(), resultTypes, operands.front());
+ if (materializedResults.empty()) {
emitError(castOp->getLoc())
<< "failed to create target materialization";
return failure();
}
- materializedResults = maybeResults.value();
} else {
// Source and argument materializations.
assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
const OneToNTypeMapping &resultMapping,
ValueRange convertedOperands) const override {
auto funcOp = cast<FunctionOpInterface>(op);
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+ auto *typeConverter = getTypeConverter();
// Construct mapping for function arguments.
OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
- if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
- argumentMapping)))
+ if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+ argumentMapping)))
return failure();
// Construct mapping for function results.
OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
- if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
- funcResultMapping)))
+ if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+ funcResultMapping)))
return failure();
// Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
///
/// This function has been copied (with small adaptions) from
/// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
- Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+ TypeRange resultTypes,
+ ValueRange inputs,
+ Location loc) {
+ if (inputs.size() != 1)
+ return {};
+ Value input = inputs.front();
+
TupleType inputType = dyn_cast<TupleType>(input.getType());
if (!inputType)
return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
auto *context = &getContext();
// Assemble type converter.
- OneToNTypeConverter typeConverter;
+ TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+ // Test the other target materialization variant that takes the original type
+ // as additional argument. This materialization function always fails.
+ typeConverter.addTargetMaterialization(
+ [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc, Type originalType) -> SmallVector<Value> { return {}; });
// Assemble patterns.
RewritePatternSet patterns(context);
|
8fc8147
to
7ac8aa2
Compare
2fe3e2d
to
59a9bbb
Compare
216d522
to
44db7a9
Compare
Co-authored-by: Markus Böck <[email protected]>
Co-authored-by: Markus Böck <[email protected]>
eaff670
to
43edbab
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/5467 Here is the relevant piece of the build log for the reference
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Nice progress!
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032,tThe dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.)
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032,tThe dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.)
…114192) The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.) Note for LLVM integration: Register 1:N target materializations on the type converter instead of "decompose value conversions" on the `ValueDecomposer`.
…lvm#114192) The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since llvm#113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.) Note for LLVM integration: Register 1:N target materializations on the type converter instead of "decompose value conversions" on the `ValueDecomposer`.
The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures. 1:1 target materializations (producing a single `Value`) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing a `SmallVector<Value>`). Internally, all target materializations are stored as 1:N materializations. The 1:N type converter is removed. Note for LLVM integration: If you are using the `OneToNTypeConverter`, simply switch all occurrences to `TypeConverter`. --------- Co-authored-by: Markus Böck <[email protected]>
…lvm#114192) The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since llvm#113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.) Note for LLVM integration: Register 1:N target materializations on the type converter instead of "decompose value conversions" on the `ValueDecomposer`.
The 1:N type converter derived from the 1:1 type converter and extends it with 1:N target materializations. This commit merges the two type converters and stores 1:N target materializations in the 1:1 type converter. This is in preparation of merging the 1:1 and 1:N dialect conversion infrastructures.
1:1 target materializations (producing a single
Value
) will remain valid. An additional API is added to the type converter to register 1:N target materializations (producing aSmallVector<Value>
). Internally, all target materializations are stored as 1:N materializations.The 1:N type converter is removed.
Note for LLVM integration: If you are using the
OneToNTypeConverter
, simply switch all occurrences toTypeConverter
.