Skip to content

Commit f18c3e4

Browse files
[mlir][Transforms] Dialect Conversion: Simplify materialization fn result type (#113031)
This commit simplifies the result type of materialization functions. Previously: `std::optional<Value>` Now: `Value` The previous implementation allowed 3 possible return values: - Non-null value: The materialization function produced a valid materialization. - `std::nullopt`: The materialization function failed, but another materialization can be attempted. - `Value()`: The materialization failed and so should the dialect conversion. (Previously: Dialect conversion can roll back.) This commit removes the last variant. It is not particularly useful because the dialect conversion will fail anyway if all other materialization functions produced `std::nullopt`. Furthermore, in contrast to type conversions, at least one materialization callback is expected to succeed. In case of a failing type conversion, the current dialect conversion can roll back and try a different pattern. This also used to be the case for materializations, but that functionality was removed with #107109: failed materializations can no longer trigger a rollback. (They can just make the entire dialect conversion fail without rollback.) With this in mind, it is even less useful to have an additional error state for materialization functions. This commit is in preparation of merging the 1:1 and 1:N type converters. Target materializations will have to return multiple values instead of a single one. With this commit, we can keep the API simple: `SmallVector<Value>` instead of `std::optional<SmallVector<Value>>`. Note for LLVM integration: All 1:1 materializations should return `Value` instead of `std::optional<Value>`. Instead of `std::nullopt` return `Value()`.
1 parent 8a9921f commit f18c3e4

File tree

16 files changed

+87
-99
lines changed

16 files changed

+87
-99
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,14 @@ class TypeConverter {
163163

164164
/// All of the following materializations require function objects that are
165165
/// convertible to the following form:
166-
/// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
166+
/// `Value(OpBuilder &, T, ValueRange, Location)`,
167167
/// where `T` is any subclass of `Type`. This function is responsible for
168168
/// creating an operation, using the OpBuilder and Location provided, that
169169
/// "casts" a range of values into a single value of the given type `T`. It
170-
/// must return a Value of the type `T` on success, an `std::nullopt` if
171-
/// it failed but other materialization can be attempted, and `nullptr` on
172-
/// unrecoverable failure. Materialization functions must be provided when a
173-
/// type conversion may persist after the conversion has finished.
170+
/// must return a Value of the type `T` on success and `nullptr` if
171+
/// it failed but other materialization should be attempted. Materialization
172+
/// functions must be provided when a type conversion may persist after the
173+
/// conversion has finished.
174174
///
175175
/// Note: Target materializations may optionally accept an additional Type
176176
/// parameter, which is the original type of the SSA value.
@@ -335,14 +335,14 @@ class TypeConverter {
335335
/// conversion.
336336
///
337337
/// Arguments: builder, result type, inputs, location
338-
using MaterializationCallbackFn = std::function<std::optional<Value>(
339-
OpBuilder &, Type, ValueRange, Location)>;
338+
using MaterializationCallbackFn =
339+
std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
340340

341341
/// The signature of the callback used to materialize a target conversion.
342342
///
343343
/// Arguments: builder, result type, inputs, location, original type
344-
using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
345-
OpBuilder &, Type, ValueRange, Location, Type)>;
344+
using TargetMaterializationCallbackFn =
345+
std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
346346

347347
/// The signature of the callback used to convert a type attribute.
348348
using TypeAttributeConversionCallbackFn =
@@ -396,10 +396,10 @@ class TypeConverter {
396396
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
397397
return [callback = std::forward<FnT>(callback)](
398398
OpBuilder &builder, Type resultType, ValueRange inputs,
399-
Location loc) -> std::optional<Value> {
399+
Location loc) -> Value {
400400
if (T derivedType = dyn_cast<T>(resultType))
401401
return callback(builder, derivedType, inputs, loc);
402-
return std::nullopt;
402+
return Value();
403403
};
404404
}
405405

@@ -417,10 +417,10 @@ class TypeConverter {
417417
wrapTargetMaterialization(FnT &&callback) const {
418418
return [callback = std::forward<FnT>(callback)](
419419
OpBuilder &builder, Type resultType, ValueRange inputs,
420-
Location loc, Type originalType) -> std::optional<Value> {
420+
Location loc, Type originalType) -> Value {
421421
if (T derivedType = dyn_cast<T>(resultType))
422422
return callback(builder, derivedType, inputs, loc, originalType);
423-
return std::nullopt;
423+
return Value();
424424
};
425425
}
426426
/// With callback of form:
@@ -433,7 +433,7 @@ class TypeConverter {
433433
return wrapTargetMaterialization<T>(
434434
[callback = std::forward<FnT>(callback)](
435435
OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
436-
Type originalType) -> std::optional<Value> {
436+
Type originalType) -> Value {
437437
return callback(builder, resultType, inputs, loc);
438438
});
439439
}

mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,9 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
282282
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
283283
// in patterns for other dialects.
284284
auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285-
ValueRange inputs, Location loc) {
285+
ValueRange inputs, Location loc) -> Value {
286286
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
287-
return std::optional<Value>(cast.getResult(0));
287+
return cast.getResult(0);
288288
};
289289

290290
addSourceMaterialization(addUnrealizedCast);

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -158,36 +158,35 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
158158
// original block argument type. The dialect conversion framework will then
159159
// insert a target materialization from the original block argument type to
160160
// a legal type.
161-
addArgumentMaterialization(
162-
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
163-
Location loc) -> std::optional<Value> {
164-
if (inputs.size() == 1) {
165-
// Bare pointers are not supported for unranked memrefs because a
166-
// memref descriptor cannot be built just from a bare pointer.
167-
return std::nullopt;
168-
}
169-
Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
170-
resultType, inputs);
171-
// An argument materialization must return a value of type
172-
// `resultType`, so insert a cast from the memref descriptor type
173-
// (!llvm.struct) to the original memref type.
174-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175-
.getResult(0);
176-
});
161+
addArgumentMaterialization([&](OpBuilder &builder,
162+
UnrankedMemRefType resultType,
163+
ValueRange inputs, Location loc) {
164+
if (inputs.size() == 1) {
165+
// Bare pointers are not supported for unranked memrefs because a
166+
// memref descriptor cannot be built just from a bare pointer.
167+
return Value();
168+
}
169+
Value desc =
170+
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
171+
// An argument materialization must return a value of type
172+
// `resultType`, so insert a cast from the memref descriptor type
173+
// (!llvm.struct) to the original memref type.
174+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175+
.getResult(0);
176+
});
177177
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178-
ValueRange inputs,
179-
Location loc) -> std::optional<Value> {
178+
ValueRange inputs, Location loc) {
180179
Value desc;
181180
if (inputs.size() == 1) {
182181
// This is a bare pointer. We allow bare pointers only for function entry
183182
// blocks.
184183
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
185184
if (!barePtr)
186-
return std::nullopt;
185+
return Value();
187186
Block *block = barePtr.getOwner();
188187
if (!block->isEntryBlock() ||
189188
!isa<FunctionOpInterface>(block->getParentOp()))
190-
return std::nullopt;
189+
return Value();
191190
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
192191
inputs[0]);
193192
} else {
@@ -202,19 +201,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
202201
// Add generic source and target materializations to handle cases where
203202
// non-LLVM types persist after an LLVM conversion.
204203
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
205-
ValueRange inputs,
206-
Location loc) -> std::optional<Value> {
204+
ValueRange inputs, Location loc) {
207205
if (inputs.size() != 1)
208-
return std::nullopt;
206+
return Value();
209207

210208
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
211209
.getResult(0);
212210
});
213211
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214-
ValueRange inputs,
215-
Location loc) -> std::optional<Value> {
212+
ValueRange inputs, Location loc) {
216213
if (inputs.size() != 1)
217-
return std::nullopt;
214+
return Value();
218215

219216
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
220217
.getResult(0);

mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@ using namespace mlir;
1616

1717
namespace {
1818

19-
std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
20-
Type resultType,
21-
ValueRange inputs,
22-
Location loc) {
19+
Value materializeAsUnrealizedCast(OpBuilder &builder, Type resultType,
20+
ValueRange inputs, Location loc) {
2321
if (inputs.size() != 1)
24-
return std::nullopt;
22+
return Value();
2523

2624
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
2725
.getResult(0);

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -659,9 +659,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
659659
/// This function is meant to handle the **compute** side; so it does not
660660
/// involve storage classes in its logic. The storage side is expected to be
661661
/// handled by MemRef conversion logic.
662-
static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
663-
OpBuilder &builder, Type type,
664-
ValueRange inputs, Location loc) {
662+
static Value castToSourceType(const spirv::TargetEnv &targetEnv,
663+
OpBuilder &builder, Type type, ValueRange inputs,
664+
Location loc) {
665665
// We can only cast one value in SPIR-V.
666666
if (inputs.size() != 1) {
667667
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
@@ -1459,7 +1459,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
14591459
addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
14601460
Location loc) {
14611461
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
1462-
return std::optional<Value>(cast.getResult(0));
1462+
return cast.getResult(0);
14631463
});
14641464
}
14651465

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,7 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
425425
addConversion(convertIterSpaceType);
426426

427427
addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
428-
ValueRange inputs,
429-
Location loc) -> std::optional<Value> {
428+
ValueRange inputs, Location loc) -> Value {
430429
return builder
431430
.create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
432431
.getResult(0);

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,10 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
6060

6161
// Required by scf.for 1:N type conversion.
6262
addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
63-
ValueRange inputs,
64-
Location loc) -> std::optional<Value> {
63+
ValueRange inputs, Location loc) -> Value {
6564
if (!getSparseTensorEncoding(tp))
6665
// Not a sparse tensor.
67-
return std::nullopt;
66+
return Value();
6867
// Sparsifier knows how to cancel out these casts.
6968
return genTuple(builder, loc, tp, inputs);
7069
});

mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,29 +153,29 @@ void transform::TypeConversionCastShapeDynamicDimsOp::
153153
converter.addSourceMaterialization([ignoreDynamicInfo](
154154
OpBuilder &builder, Type resultType,
155155
ValueRange inputs,
156-
Location loc) -> std::optional<Value> {
156+
Location loc) -> Value {
157157
if (inputs.size() != 1) {
158-
return std::nullopt;
158+
return Value();
159159
}
160160
Value input = inputs[0];
161161
if (!ignoreDynamicInfo &&
162162
!tensor::preservesStaticInformation(resultType, input.getType())) {
163-
return std::nullopt;
163+
return Value();
164164
}
165165
if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
166-
return std::nullopt;
166+
return Value();
167167
}
168168
return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
169169
});
170170
converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
171171
ValueRange inputs,
172-
Location loc) -> std::optional<Value> {
172+
Location loc) -> Value {
173173
if (inputs.size() != 1) {
174-
return std::nullopt;
174+
return Value();
175175
}
176176
Value input = inputs[0];
177177
if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
178-
return std::nullopt;
178+
return Value();
179179
}
180180
return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
181181
});

mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,18 @@ void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) {
3333
});
3434
converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
3535
ValueRange inputs,
36-
Location loc) -> std::optional<Value> {
36+
Location loc) -> Value {
3737
if (inputs.size() != 1)
38-
return std::nullopt;
38+
return Value();
3939

4040
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
4141
.getResult(0);
4242
});
4343
converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
4444
ValueRange inputs,
45-
Location loc) -> std::optional<Value> {
45+
Location loc) -> Value {
4646
if (inputs.size() != 1)
47-
return std::nullopt;
47+
return Value();
4848

4949
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
5050
.getResult(0);

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2812,8 +2812,8 @@ Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
28122812
ValueRange inputs) const {
28132813
for (const MaterializationCallbackFn &fn :
28142814
llvm::reverse(argumentMaterializations))
2815-
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2816-
return *result;
2815+
if (Value result = fn(builder, resultType, inputs, loc))
2816+
return result;
28172817
return nullptr;
28182818
}
28192819

@@ -2822,8 +2822,8 @@ Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
28222822
ValueRange inputs) const {
28232823
for (const MaterializationCallbackFn &fn :
28242824
llvm::reverse(sourceMaterializations))
2825-
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2826-
return *result;
2825+
if (Value result = fn(builder, resultType, inputs, loc))
2826+
return result;
28272827
return nullptr;
28282828
}
28292829

@@ -2833,9 +2833,8 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
28332833
Type originalType) const {
28342834
for (const TargetMaterializationCallbackFn &fn :
28352835
llvm::reverse(targetMaterializations))
2836-
if (std::optional<Value> result =
2837-
fn(builder, resultType, inputs, loc, originalType))
2838-
return *result;
2836+
if (Value result = fn(builder, resultType, inputs, loc, originalType))
2837+
return result;
28392838
return nullptr;
28402839
}
28412840

mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,8 @@ buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
180180
///
181181
/// This function has been copied (with small adaptions) from
182182
/// TestDecomposeCallGraphTypes.cpp.
183-
static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
184-
TupleType resultType,
185-
ValueRange inputs, Location loc) {
183+
static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
184+
ValueRange inputs, Location loc) {
186185
// Build one value for each element at this nesting level.
187186
SmallVector<Value> elements;
188187
elements.reserve(resultType.getTypes().size());
@@ -201,13 +200,13 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
201200
inputIt += numNestedFlattenedTypes;
202201

203202
// Recurse on the values for the nested TupleType.
204-
std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
205-
nestedFlattenedelements, loc);
206-
if (!res.has_value())
207-
return {};
203+
Value res = buildMakeTupleOp(builder, nestedTupleType,
204+
nestedFlattenedelements, loc);
205+
if (!res)
206+
return Value();
208207

209208
// The tuple constructed by the conversion is the element value.
210-
elements.push_back(res.value());
209+
elements.push_back(res);
211210
} else {
212211
// Base case: take one input as is.
213212
elements.push_back(*inputIt++);

mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct TestEmulateWideIntPass
5959
// TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
6060
// casts (and vice versa) and using it insted of `llvm.bitcast`.
6161
auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
62-
Location loc) -> std::optional<Value> {
62+
Location loc) -> Value {
6363
auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
6464
return cast->getResult(0);
6565
};

mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
4343
/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
4444
/// type `resultType`. If that type is nested, each nested tuple is built
4545
/// recursively with another `test.make_tuple` op.
46-
static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
47-
TupleType resultType,
48-
ValueRange inputs, Location loc) {
46+
static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
47+
ValueRange inputs, Location loc) {
4948
// Build one value for each element at this nesting level.
5049
SmallVector<Value> elements;
5150
elements.reserve(resultType.getTypes().size());
@@ -64,13 +63,13 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
6463
inputIt += numNestedFlattenedTypes;
6564

6665
// Recurse on the values for the nested TupleType.
67-
std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
68-
nestedFlattenedelements, loc);
69-
if (!res.has_value())
70-
return {};
66+
Value res = buildMakeTupleOp(builder, nestedTupleType,
67+
nestedFlattenedelements, loc);
68+
if (!res)
69+
return Value();
7170

7271
// The tuple constructed by the conversion is the element value.
73-
elements.push_back(res.value());
72+
elements.push_back(res);
7473
} else {
7574
// Base case: take one input as is.
7675
elements.push_back(*inputIt++);

0 commit comments

Comments
 (0)