Skip to content

[mlir][func] Replace ValueDecomposer with target materialization #114192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,70 +23,10 @@

namespace mlir {

/// This class provides a hook that expands one Value into multiple Value's,
/// with a TypeConverter-inspired callback registration mechanism.
///
/// For folks that are familiar with the dialect conversion framework /
/// TypeConverter, this is effectively the inverse of a source/argument
/// materialization. A target materialization is not what we want here because
/// it always produces a single Value, but in this case the whole point is to
/// decompose a Value into multiple Value's.
///
/// The reason we need this inverse is easily understood by looking at what we
/// need to do for decomposing types for a return op. When converting a return
/// op, the dialect conversion framework will give the list of converted
/// operands, and will ensure that each converted operand, even if it expanded
/// into multiple types, is materialized as a single result. We then need to
/// undo that materialization to a single result, which we do with the
/// decomposeValue hooks registered on this object.
///
/// TODO: Eventually, the type conversion infra should have this hook built-in.
/// See
/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2
class ValueDecomposer {
public:
/// This method tries to decompose a value of a certain type using provided
/// decompose callback functions. If it is unable to do so, the original value
/// is returned.
void decomposeValue(OpBuilder &, Location, Type, Value,
SmallVectorImpl<Value> &);

/// This method registers a callback function that will be called to decompose
/// a value of a certain type into 0, 1, or multiple values.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<2>>
void addDecomposeValueConversion(FnT &&callback) {
decomposeValueConversions.emplace_back(
wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
}

private:
using DecomposeValueConversionCallFn =
std::function<std::optional<LogicalResult>(
OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;

/// Generate a wrapper for the given decompose value conversion callback.
template <typename T, typename FnT>
DecomposeValueConversionCallFn
wrapDecomposeValueConversionCallback(FnT &&callback) {
return
[callback = std::forward<FnT>(callback)](
OpBuilder &builder, Location loc, Type type, Value value,
SmallVectorImpl<Value> &newValues) -> std::optional<LogicalResult> {
if (T derivedType = dyn_cast<T>(type))
return callback(builder, loc, derivedType, value, newValues);
return std::nullopt;
};
}

SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
};

/// Populates the patterns needed to drive the conversion process for
/// decomposing call graph types with the given `ValueDecomposer`.
/// decomposing call graph types with the given `TypeConverter`.
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
const TypeConverter &typeConverter,
ValueDecomposer &decomposer,
RewritePatternSet &patterns);

} // namespace mlir
Expand Down
111 changes: 55 additions & 56 deletions mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,48 @@ using namespace mlir;
using namespace mlir::func;

//===----------------------------------------------------------------------===//
// ValueDecomposer
// Helper functions
//===----------------------------------------------------------------------===//

void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
Type type, Value value,
SmallVectorImpl<Value> &results) {
for (auto &conversion : decomposeValueConversions)
if (conversion(builder, loc, type, value, results))
return;
results.push_back(value);
/// If the given value can be decomposed with the type converter, decompose it.
/// Otherwise, return the given value.
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
// This function will disappear when the 1:1 and 1:N drivers are merged.
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
Value value,
const TypeConverter *converter) {
// Try to convert the given value's type. If that fails, just return the
// given value.
SmallVector<Type> convertedTypes;
if (failed(converter->convertType(value.getType(), convertedTypes)))
return {value};
if (convertedTypes.empty())
return {};

// If the given value's type is already legal, just return the given value.
TypeRange convertedTypeRange(convertedTypes);
if (convertedTypeRange == TypeRange(value.getType()))
return {value};

// Try to materialize a target conversion. If the materialization did not
// produce values of the requested type, the materialization failed. Just
// return the given value in that case.
SmallVector<Value> result = converter->materializeTargetConversion(
builder, loc, convertedTypeRange, value);
if (result.empty())
return {value};
return result;
}

//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesOpConversionPattern
//===----------------------------------------------------------------------===//

namespace {
/// Base OpConversionPattern class to make a ValueDecomposer available to
/// inherited patterns.
template <typename SourceOp>
class DecomposeCallGraphTypesOpConversionPattern
: public OpConversionPattern<SourceOp> {
public:
DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context,
ValueDecomposer &decomposer,
PatternBenefit benefit = 1)
: OpConversionPattern<SourceOp>(typeConverter, context, benefit),
decomposer(decomposer) {}

protected:
ValueDecomposer &decomposer;
};
} // namespace

//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//

namespace {
/// Expand function arguments according to the provided TypeConverter and
/// ValueDecomposer.
/// Expand function arguments according to the provided TypeConverter.
struct DecomposeCallGraphTypesForFuncArgs
: public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
using DecomposeCallGraphTypesOpConversionPattern::
DecomposeCallGraphTypesOpConversionPattern;
: public OpConversionPattern<func::FuncOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -100,19 +96,22 @@ struct DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//

namespace {
/// Expand return operands according to the provided TypeConverter and
/// ValueDecomposer.
/// Expand return operands according to the provided TypeConverter.
struct DecomposeCallGraphTypesForReturnOp
: public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
using DecomposeCallGraphTypesOpConversionPattern::
DecomposeCallGraphTypesOpConversionPattern;
: public OpConversionPattern<ReturnOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
SmallVector<Value, 2> newOperands;
for (Value operand : adaptor.getOperands())
decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
operand, newOperands);
for (Value operand : adaptor.getOperands()) {
// TODO: We can directly take the values from the adaptor once this is a
// 1:N conversion pattern.
llvm::append_range(newOperands,
decomposeValue(rewriter, operand.getLoc(), operand,
getTypeConverter()));
}
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
}
Expand All @@ -124,22 +123,23 @@ struct DecomposeCallGraphTypesForReturnOp
//===----------------------------------------------------------------------===//

namespace {
/// Expand call op operands and results according to the provided TypeConverter
/// and ValueDecomposer.
struct DecomposeCallGraphTypesForCallOp
: public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
using DecomposeCallGraphTypesOpConversionPattern::
DecomposeCallGraphTypesOpConversionPattern;
/// Expand call op operands and results according to the provided TypeConverter.
struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {

// Create the operands list of the new `CallOp`.
SmallVector<Value, 2> newOperands;
for (Value operand : adaptor.getOperands())
decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
operand, newOperands);
for (Value operand : adaptor.getOperands()) {
// TODO: We can directly take the values from the adaptor once this is a
// 1:N conversion pattern.
llvm::append_range(newOperands,
decomposeValue(rewriter, operand.getLoc(), operand,
getTypeConverter()));
}

// Create the new result types for the new `CallOp` and track the indices in
// the new call op's results that correspond to the old call op's results.
Expand Down Expand Up @@ -189,9 +189,8 @@ struct DecomposeCallGraphTypesForCallOp

void mlir::populateDecomposeCallGraphTypesPatterns(
MLIRContext *context, const TypeConverter &typeConverter,
ValueDecomposer &decomposer, RewritePatternSet &patterns) {
RewritePatternSet &patterns) {
patterns
.add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
decomposer);
DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
}
60 changes: 37 additions & 23 deletions mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,40 @@ namespace {
/// given tuple value. If some tuple elements are, in turn, tuples, the elements
/// of those are extracted recursively such that the returned values have the
/// same types as `resultTypes.getFlattenedTypes()`.
static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
TupleType resultType, Value value,
SmallVectorImpl<Value> &values) {
for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
Type elementType = resultType.getType(i);
Value element = builder.create<test::GetTupleElementOp>(
loc, elementType, value, builder.getI32IntegerAttr(i));
if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
// Recurse if the current element is also a tuple.
if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
values)))
return failure();
} else {
values.push_back(element);
static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder,
TypeRange resultTypes,
ValueRange inputs, Location loc) {
// Skip materialization if the single input value is not a tuple.
if (inputs.size() != 1)
return {};
Value tuple = inputs.front();
auto tupleType = dyn_cast<TupleType>(tuple.getType());
if (!tupleType)
return {};
// Skip materialization if the flattened types do not match the requested
// result types.
SmallVector<Type> flattenedTypes;
tupleType.getFlattenedTypes(flattenedTypes);
if (TypeRange(resultTypes) != TypeRange(flattenedTypes))
return {};
// Recursively decompose the tuple.
SmallVector<Value> result;
std::function<void(Value)> decompose = [&](Value tuple) {
auto tupleType = dyn_cast<TupleType>(tuple.getType());
if (!tupleType) {
// This is not a tuple.
result.push_back(tuple);
return;
}
}
return success();
for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
Type elementType = tupleType.getType(i);
Value element = builder.create<test::GetTupleElementOp>(
loc, elementType, tuple, builder.getI32IntegerAttr(i));
decompose(element);
}
};
decompose(tuple);
return result;
}

/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
Expand Down Expand Up @@ -82,8 +99,8 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,

/// A pass for testing call graph type decomposition.
///
/// This instantiates the patterns with a TypeConverter and ValueDecomposer
/// that splits tuple types into their respective element types.
/// This instantiates the patterns with a TypeConverter that splits tuple types
/// into their respective element types.
/// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
struct TestDecomposeCallGraphTypes
: public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
Expand Down Expand Up @@ -123,12 +140,9 @@ struct TestDecomposeCallGraphTypes
return success();
});
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildDecomposeTuple);

ValueDecomposer decomposer;
decomposer.addDecomposeValueConversion(buildDecomposeTuple);

populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
patterns);
populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);

if (failed(applyPartialConversion(module, target, std::move(patterns))))
return signalPassFailure();
Expand Down
Loading