Skip to content

Commit 44db7a9

Browse files
[mlir][Transforms] Merge 1:1 and 1:N type converters
1 parent cb46662 commit 44db7a9

File tree

7 files changed

+93
-98
lines changed

7 files changed

+93
-98
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
150150
//===----------------------------------------------------------------------===//
151151

152152
/// Type converter for iter_space and iterator.
153-
struct SparseIterationTypeConverter : public OneToNTypeConverter {
153+
struct SparseIterationTypeConverter : public TypeConverter {
154154
SparseIterationTypeConverter();
155155
};
156156

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ class TypeConverter {
173173
/// conversion has finished.
174174
///
175175
/// Note: Target materializations may optionally accept an additional Type
176-
/// parameter, which is the original type of the SSA value.
176+
/// parameter, which is the original type of the SSA value. Furthermore, `T`
177+
/// can be a TypeRange; in that case, the function must return a
178+
/// SmallVector<Value>.
177179

178180
/// This method registers a materialization that will be called when
179181
/// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
210212
/// will be invoked with: outputType = "t3", inputs = "v2",
211213
// originalType = "t1". Note that the original type "t1" cannot be recovered
212214
/// from just "t3" and "v2"; that's why the originalType parameter exists.
215+
///
216+
/// Note: During a 1:N conversion, the result types can be a TypeRange. In
217+
/// that case the materialization produces a SmallVector<Value>.
213218
template <typename FnT, typename T = typename llvm::function_traits<
214219
std::decay_t<FnT>>::template arg_t<1>>
215220
void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
316321
Value materializeTargetConversion(OpBuilder &builder, Location loc,
317322
Type resultType, ValueRange inputs,
318323
Type originalType = {}) const;
324+
SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
325+
Location loc,
326+
TypeRange resultType,
327+
ValueRange inputs,
328+
Type originalType = {}) const;
319329

320330
/// Convert an attribute present `attr` from within the type `type` using
321331
/// the registered conversion functions. If no applicable conversion has been
@@ -340,9 +350,9 @@ class TypeConverter {
340350

341351
/// The signature of the callback used to materialize a target conversion.
342352
///
343-
/// Arguments: builder, result type, inputs, location, original type
344-
using TargetMaterializationCallbackFn =
345-
std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
353+
/// Arguments: builder, result types, inputs, location, original type
354+
using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
355+
OpBuilder &, TypeRange, ValueRange, Location, Type)>;
346356

347357
/// The signature of the callback used to convert a type attribute.
348358
using TypeAttributeConversionCallbackFn =
@@ -409,32 +419,50 @@ class TypeConverter {
409419
/// callback.
410420
///
411421
/// With callback of form:
412-
/// `Value(OpBuilder &, T, ValueRange, Location, Type)`
422+
/// - Value(OpBuilder &, T, ValueRange, Location, Type)
423+
/// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
413424
template <typename T, typename FnT>
414425
std::enable_if_t<
415426
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
416427
TargetMaterializationCallbackFn>
417428
wrapTargetMaterialization(FnT &&callback) const {
418429
return [callback = std::forward<FnT>(callback)](
419-
OpBuilder &builder, Type resultType, ValueRange inputs,
420-
Location loc, Type originalType) -> Value {
421-
if (T derivedType = dyn_cast<T>(resultType))
422-
return callback(builder, derivedType, inputs, loc, originalType);
423-
return Value();
430+
OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
431+
Location loc, Type originalType) -> SmallVector<Value> {
432+
SmallVector<Value> result;
433+
if constexpr (std::is_same<T, TypeRange>::value) {
434+
// This is a 1:N target materialization. Return the produces values
435+
// directly.
436+
result = callback(builder, resultTypes, inputs, loc, originalType);
437+
} else {
438+
// This is a 1:1 target materialization. Invoke it only if the result
439+
// type class of the callback matches the requested result type.
440+
if (T derivedType = dyn_cast<T>(resultTypes.front())) {
441+
// 1:1 materializations produce single values, but we store 1:N
442+
// target materialization functions in the type converter. Wrap the
443+
// result value in a SmallVector<Value>.
444+
std::optional<Value> val =
445+
callback(builder, derivedType, inputs, loc, originalType);
446+
if (val.has_value() && *val)
447+
result.push_back(*val);
448+
}
449+
}
450+
return result;
424451
};
425452
}
426453
/// With callback of form:
427-
/// `Value(OpBuilder &, T, ValueRange, Location)`
454+
/// - Value(OpBuilder &, T, ValueRange, Location)
455+
/// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
428456
template <typename T, typename FnT>
429457
std::enable_if_t<
430458
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
431459
TargetMaterializationCallbackFn>
432460
wrapTargetMaterialization(FnT &&callback) const {
433461
return wrapTargetMaterialization<T>(
434462
[callback = std::forward<FnT>(callback)](
435-
OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
436-
Type originalType) -> Value {
437-
return callback(builder, resultType, inputs, loc);
463+
OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
464+
Type originalType) {
465+
return callback(builder, resultTypes, inputs, loc);
438466
});
439467
}
440468

mlir/include/mlir/Transforms/OneToNTypeConversion.h

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -33,49 +33,6 @@
3333

3434
namespace mlir {
3535

36-
/// Extends `TypeConverter` with 1:N target materializations. Such
37-
/// materializations have to provide the "reverse" of 1:N type conversions,
38-
/// i.e., they need to materialize N values with target types into one value
39-
/// with a source type (which isn't possible in the base class currently).
40-
class OneToNTypeConverter : public TypeConverter {
41-
public:
42-
/// Callback that expresses user-provided materialization logic from the given
43-
/// value to N values of the given types. This is useful for expressing target
44-
/// materializations for 1:N type conversions, which materialize one value in
45-
/// a source type as N values in target types.
46-
using OneToNMaterializationCallbackFn =
47-
std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
48-
Value, Location)>;
49-
50-
/// Creates the mapping of the given range of original types to target types
51-
/// of the conversion and stores that mapping in the given (signature)
52-
/// conversion. This function simply calls
53-
/// `TypeConverter::convertSignatureArgs` and exists here with a different
54-
/// name to reflect the broader semantic.
55-
LogicalResult computeTypeMapping(TypeRange types,
56-
SignatureConversion &result) const {
57-
return convertSignatureArgs(types, result);
58-
}
59-
60-
/// Applies one of the user-provided 1:N target materializations. If several
61-
/// exists, they are tried out in the reverse order in which they have been
62-
/// added until the first one succeeds. If none succeeds, the functions
63-
/// returns `std::nullopt`.
64-
std::optional<SmallVector<Value>>
65-
materializeTargetConversion(OpBuilder &builder, Location loc,
66-
TypeRange resultTypes, Value input) const;
67-
68-
/// Adds a 1:N target materialization to the converter. Such materializations
69-
/// build IR that converts N values with target types into 1 value of the
70-
/// source type.
71-
void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
72-
oneToNTargetMaterializations.emplace_back(std::move(callback));
73-
}
74-
75-
private:
76-
SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
77-
};
78-
7936
/// Stores a 1:N mapping of types and provides several useful accessors. This
8037
/// class extends `SignatureConversion`, which already supports 1:N type
8138
/// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
295252
/// not fail if some ops or types remain unconverted (i.e., the conversion is
296253
/// only "partial").
297254
LogicalResult
298-
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
255+
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
299256
const FrozenRewritePatternSet &patterns);
300257

301258
/// Add a pattern to the given pattern list to convert the signature of a

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
921921
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
922922
void runOnOperation() override {
923923
auto *context = &getContext();
924-
OneToNTypeConverter converter;
924+
TypeConverter converter;
925925
RewritePatternSet patterns(context);
926926
converter.addConversion([](Type type) { return type; });
927927
converter.addConversion(

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
28312831
Location loc, Type resultType,
28322832
ValueRange inputs,
28332833
Type originalType) const {
2834+
SmallVector<Value> result = materializeTargetConversion(
2835+
builder, loc, TypeRange(resultType), inputs, originalType);
2836+
if (result.empty())
2837+
return nullptr;
2838+
assert(result.size() == 1 && "requested 1:1 materialization, but callback "
2839+
"produced 1:N materialization");
2840+
return result.front();
2841+
}
2842+
2843+
SmallVector<Value> TypeConverter::materializeTargetConversion(
2844+
OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
2845+
Type originalType) const {
28342846
for (const TargetMaterializationCallbackFn &fn :
2835-
llvm::reverse(targetMaterializations))
2836-
if (Value result = fn(builder, resultType, inputs, loc, originalType))
2837-
return result;
2838-
return nullptr;
2847+
llvm::reverse(targetMaterializations)) {
2848+
SmallVector<Value> result =
2849+
fn(builder, resultTypes, inputs, loc, originalType);
2850+
if (result.empty())
2851+
continue;
2852+
return result;
2853+
}
2854+
return {};
28392855
}
28402856

28412857
std::optional<TypeConverter::SignatureConversion>

mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,6 @@
1717
using namespace llvm;
1818
using namespace mlir;
1919

20-
std::optional<SmallVector<Value>>
21-
OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
22-
Location loc,
23-
TypeRange resultTypes,
24-
Value input) const {
25-
for (const OneToNMaterializationCallbackFn &fn :
26-
llvm::reverse(oneToNTargetMaterializations)) {
27-
if (std::optional<SmallVector<Value>> result =
28-
fn(builder, resultTypes, input, loc))
29-
return *result;
30-
}
31-
return std::nullopt;
32-
}
33-
3420
TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
3521
TypeRange convertedTypes = getConvertedTypes();
3622
if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
268254
LogicalResult
269255
OneToNConversionPattern::matchAndRewrite(Operation *op,
270256
PatternRewriter &rewriter) const {
271-
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
257+
auto *typeConverter = getTypeConverter();
272258

273259
// Construct conversion mapping for results.
274260
Operation::result_type_range originalResultTypes = op->getResultTypes();
275261
OneToNTypeMapping resultMapping(originalResultTypes);
276-
if (failed(typeConverter->computeTypeMapping(originalResultTypes,
277-
resultMapping)))
262+
if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
263+
resultMapping)))
278264
return failure();
279265

280266
// Construct conversion mapping for operands.
281267
Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
282268
OneToNTypeMapping operandMapping(originalOperandTypes);
283-
if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
284-
operandMapping)))
269+
if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
270+
operandMapping)))
285271
return failure();
286272

287273
// Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
318304
// inserted by this pass are annotated with a string attribute that also
319305
// documents which kind of the cast (source, argument, or target).
320306
LogicalResult
321-
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
307+
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
322308
const FrozenRewritePatternSet &patterns) {
323309
#ifndef NDEBUG
324310
// Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
370356
// Target materialization.
371357
assert(!areOperandTypesLegal && areResultsTypesLegal &&
372358
operands.size() == 1 && "found unexpected target cast");
373-
std::optional<SmallVector<Value>> maybeResults =
374-
typeConverter.materializeTargetConversion(
375-
rewriter, castOp->getLoc(), resultTypes, operands.front());
376-
if (!maybeResults) {
359+
materializedResults = typeConverter.materializeTargetConversion(
360+
rewriter, castOp->getLoc(), resultTypes, operands.front());
361+
if (materializedResults.empty()) {
377362
emitError(castOp->getLoc())
378363
<< "failed to create target materialization";
379364
return failure();
380365
}
381-
materializedResults = maybeResults.value();
382366
} else {
383367
// Source and argument materializations.
384368
assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
427411
const OneToNTypeMapping &resultMapping,
428412
ValueRange convertedOperands) const override {
429413
auto funcOp = cast<FunctionOpInterface>(op);
430-
auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
414+
auto *typeConverter = getTypeConverter();
431415

432416
// Construct mapping for function arguments.
433417
OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
434-
if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
435-
argumentMapping)))
418+
if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
419+
argumentMapping)))
436420
return failure();
437421

438422
// Construct mapping for function results.
439423
OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
440-
if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
441-
funcResultMapping)))
424+
if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
425+
funcResultMapping)))
442426
return failure();
443427

444428
// Nothing to do if the op doesn't have any non-identity conversions for its

mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
147147
///
148148
/// This function has been copied (with small adaptions) from
149149
/// TestDecomposeCallGraphTypes.cpp.
150-
static std::optional<SmallVector<Value>>
151-
buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
152-
Location loc) {
150+
static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
151+
TypeRange resultTypes,
152+
ValueRange inputs,
153+
Location loc) {
154+
if (inputs.size() != 1)
155+
return {};
156+
Value input = inputs.front();
157+
153158
TupleType inputType = dyn_cast<TupleType>(input.getType());
154159
if (!inputType)
155160
return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
222227
auto *context = &getContext();
223228

224229
// Assemble type converter.
225-
OneToNTypeConverter typeConverter;
230+
TypeConverter typeConverter;
226231

227232
typeConverter.addConversion([](Type type) { return type; });
228233
typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
234239
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
235240
typeConverter.addSourceMaterialization(buildMakeTupleOp);
236241
typeConverter.addTargetMaterialization(buildGetTupleElementOps);
242+
// Test the other target materialization variant that takes the original type
243+
// as additional argument. This materialization function always fails.
244+
typeConverter.addTargetMaterialization(
245+
[](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
246+
Location loc, Type originalType) -> SmallVector<Value> { return {}; });
237247

238248
// Assemble patterns.
239249
RewritePatternSet patterns(context);

0 commit comments

Comments
 (0)