Skip to content

Commit df31fd8

Browse files
[mlir] Fix use-after-return in #117513 (#120968)
Fix a use-after-return in #117513. Free-standing lambdas should not be defined inside of the `LLVMTypeConverter` constructor because they go out of scope.
1 parent 11676da commit df31fd8

File tree

2 files changed

+123
-107
lines changed

2 files changed

+123
-107
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,41 @@ class LLVMTypeConverter : public TypeConverter {
161161
/// Check if a memref type can be converted to a bare pointer.
162162
static bool canConvertToBarePtr(BaseMemRefType type);
163163

164+
/// Convert a memref type into a list of LLVM IR types that will form the
165+
/// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
166+
/// arrays in the descriptors are unpacked to individual index-typed elements,
167+
/// else they are kept as rank-sized arrays of index type. In particular,
168+
/// the list will contain:
169+
/// - two pointers to the memref element type, followed by
170+
/// - an index-typed offset, followed by
171+
/// - (if unpackAggregates = true)
172+
/// - one index-typed size per dimension of the memref, followed by
173+
/// - one index-typed stride per dimension of the memref.
174+
/// - (if unpackArrregates = false)
175+
/// - one rank-sized array of index-type for the size of each dimension
176+
/// - one rank-sized array of index-type for the stride of each dimension
177+
///
178+
/// For example, memref<?x?xf32> is converted to the following list:
179+
/// - `!llvm<"float*">` (allocated pointer),
180+
/// - `!llvm<"float*">` (aligned pointer),
181+
/// - `i64` (offset),
182+
/// - `i64`, `i64` (sizes),
183+
/// - `i64`, `i64` (strides).
184+
/// These types can be recomposed to a memref descriptor struct.
185+
SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
186+
bool unpackAggregates) const;
187+
188+
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
189+
/// that will form the unranked memref descriptor. In particular, this list
190+
/// contains:
191+
/// - an integer rank, followed by
192+
/// - a pointer to the memref descriptor struct.
193+
/// For example, memref<*xf32> is converted to the following list:
194+
/// i64 (rank)
195+
/// !llvm<"i8*"> (type-erased pointer).
196+
/// These types can be recomposed to a unranked memref descriptor struct.
197+
SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
198+
164199
protected:
165200
/// Pointer to the LLVM dialect.
166201
LLVM::LLVMDialect *llvmDialect;
@@ -213,41 +248,6 @@ class LLVMTypeConverter : public TypeConverter {
213248
/// Convert a memref type into an LLVM type that captures the relevant data.
214249
Type convertMemRefType(MemRefType type) const;
215250

216-
/// Convert a memref type into a list of LLVM IR types that will form the
217-
/// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
218-
/// arrays in the descriptors are unpacked to individual index-typed elements,
219-
/// else they are kept as rank-sized arrays of index type. In particular,
220-
/// the list will contain:
221-
/// - two pointers to the memref element type, followed by
222-
/// - an index-typed offset, followed by
223-
/// - (if unpackAggregates = true)
224-
/// - one index-typed size per dimension of the memref, followed by
225-
/// - one index-typed stride per dimension of the memref.
226-
/// - (if unpackArrregates = false)
227-
/// - one rank-sized array of index-type for the size of each dimension
228-
/// - one rank-sized array of index-type for the stride of each dimension
229-
///
230-
/// For example, memref<?x?xf32> is converted to the following list:
231-
/// - `!llvm<"float*">` (allocated pointer),
232-
/// - `!llvm<"float*">` (aligned pointer),
233-
/// - `i64` (offset),
234-
/// - `i64`, `i64` (sizes),
235-
/// - `i64`, `i64` (strides).
236-
/// These types can be recomposed to a memref descriptor struct.
237-
SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
238-
bool unpackAggregates) const;
239-
240-
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
241-
/// that will form the unranked memref descriptor. In particular, this list
242-
/// contains:
243-
/// - an integer rank, followed by
244-
/// - a pointer to the memref descriptor struct.
245-
/// For example, memref<*xf32> is converted to the following list:
246-
/// i64 (rank)
247-
/// !llvm<"i8*"> (type-erased pointer).
248-
/// These types can be recomposed to a unranked memref descriptor struct.
249-
SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
250-
251251
/// Convert an unranked memref type to an LLVM type that captures the
252252
/// runtime rank and a pointer to the static ranked memref desc
253253
Type convertUnrankedMemRefType(UnrankedMemRefType type) const;

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 88 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
4444
const DataLayoutAnalysis *analysis)
4545
: LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
4646

47+
/// Helper function that checks if the given value range is a bare pointer.
48+
static bool isBarePointer(ValueRange values) {
49+
return values.size() == 1 &&
50+
isa<LLVM::LLVMPointerType>(values.front().getType());
51+
};
52+
53+
/// Pack SSA values into an unranked memref descriptor struct.
54+
static Value packUnrankedMemRefDesc(OpBuilder &builder,
55+
UnrankedMemRefType resultType,
56+
ValueRange inputs, Location loc,
57+
const LLVMTypeConverter &converter) {
58+
// Note: Bare pointers are not supported for unranked memrefs because a
59+
// memref descriptor cannot be built just from a bare pointer.
60+
if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
61+
return Value();
62+
return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
63+
inputs);
64+
}
65+
66+
/// Pack SSA values into a ranked memref descriptor struct.
67+
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
68+
ValueRange inputs, Location loc,
69+
const LLVMTypeConverter &converter) {
70+
assert(resultType && "expected non-null result type");
71+
if (isBarePointer(inputs))
72+
return MemRefDescriptor::fromStaticShape(builder, loc, converter,
73+
resultType, inputs[0]);
74+
if (TypeRange(inputs) ==
75+
converter.getMemRefDescriptorFields(resultType,
76+
/*unpackAggregates=*/true))
77+
return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
78+
// The inputs are neither a bare pointer nor an unpacked memref descriptor.
79+
// This materialization function cannot be used.
80+
return Value();
81+
}
82+
83+
/// MemRef descriptor elements -> UnrankedMemRefType
84+
static Value unrankedMemRefMaterialization(OpBuilder &builder,
85+
UnrankedMemRefType resultType,
86+
ValueRange inputs, Location loc,
87+
const LLVMTypeConverter &converter) {
88+
// An argument materialization must return a value of type
89+
// `resultType`, so insert a cast from the memref descriptor type
90+
// (!llvm.struct) to the original memref type.
91+
Value packed =
92+
packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
93+
if (!packed)
94+
return Value();
95+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
96+
.getResult(0);
97+
};
98+
99+
/// MemRef descriptor elements -> MemRefType
100+
static Value rankedMemRefMaterialization(OpBuilder &builder,
101+
MemRefType resultType,
102+
ValueRange inputs, Location loc,
103+
const LLVMTypeConverter &converter) {
104+
// An argument materialization must return a value of type `resultType`,
105+
// so insert a cast from the memref descriptor type (!llvm.struct) to the
106+
// original memref type.
107+
Value packed =
108+
packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
109+
if (!packed)
110+
return Value();
111+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
112+
.getResult(0);
113+
}
114+
47115
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
48116
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
49117
const LowerToLLVMOptions &options,
@@ -166,81 +234,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
166234
.getResult(0);
167235
});
168236

169-
// Helper function that checks if the given value range is a bare pointer.
170-
auto isBarePointer = [](ValueRange values) {
171-
return values.size() == 1 &&
172-
isa<LLVM::LLVMPointerType>(values.front().getType());
173-
};
174-
175-
// TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
176-
// must be passed explicitly.
177-
auto packUnrankedMemRefDesc =
178-
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
179-
Location loc, LLVMTypeConverter &converter) -> Value {
180-
// Note: Bare pointers are not supported for unranked memrefs because a
181-
// memref descriptor cannot be built just from a bare pointer.
182-
if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
183-
return Value();
184-
return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
185-
inputs);
186-
};
187-
188-
// MemRef descriptor elements -> UnrankedMemRefType
189-
auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
190-
UnrankedMemRefType resultType,
191-
ValueRange inputs, Location loc) {
192-
// An argument materialization must return a value of type
193-
// `resultType`, so insert a cast from the memref descriptor type
194-
// (!llvm.struct) to the original memref type.
195-
Value packed =
196-
packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
197-
if (!packed)
198-
return Value();
199-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
200-
.getResult(0);
201-
};
202-
203-
// TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
204-
// must be passed explicitly.
205-
auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
206-
ValueRange inputs, Location loc,
207-
LLVMTypeConverter &converter) -> Value {
208-
assert(resultType && "expected non-null result type");
209-
if (isBarePointer(inputs))
210-
return MemRefDescriptor::fromStaticShape(builder, loc, converter,
211-
resultType, inputs[0]);
212-
if (TypeRange(inputs) ==
213-
converter.getMemRefDescriptorFields(resultType,
214-
/*unpackAggregates=*/true))
215-
return MemRefDescriptor::pack(builder, loc, converter, resultType,
216-
inputs);
217-
// The inputs are neither a bare pointer nor an unpacked memref descriptor.
218-
// This materialization function cannot be used.
219-
return Value();
220-
};
221-
222-
// MemRef descriptor elements -> MemRefType
223-
auto rankedMemRefMaterialization = [&](OpBuilder &builder,
224-
MemRefType resultType,
225-
ValueRange inputs, Location loc) {
226-
// An argument materialization must return a value of type `resultType`,
227-
// so insert a cast from the memref descriptor type (!llvm.struct) to the
228-
// original memref type.
229-
Value packed =
230-
packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
231-
if (!packed)
232-
return Value();
233-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
234-
.getResult(0);
235-
};
236-
237237
// Argument materializations convert from the new block argument types
238238
// (multiple SSA values that make up a memref descriptor) back to the
239239
// original block argument type.
240-
addArgumentMaterialization(unrakedMemRefMaterialization);
241-
addArgumentMaterialization(rankedMemRefMaterialization);
242-
addSourceMaterialization(unrakedMemRefMaterialization);
243-
addSourceMaterialization(rankedMemRefMaterialization);
240+
addArgumentMaterialization([&](OpBuilder &builder,
241+
UnrankedMemRefType resultType,
242+
ValueRange inputs, Location loc) {
243+
return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
244+
*this);
245+
});
246+
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
247+
ValueRange inputs, Location loc) {
248+
return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
249+
});
250+
addSourceMaterialization([&](OpBuilder &builder,
251+
UnrankedMemRefType resultType, ValueRange inputs,
252+
Location loc) {
253+
return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
254+
*this);
255+
});
256+
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
257+
ValueRange inputs, Location loc) {
258+
return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
259+
});
244260

245261
// Bare pointer -> Packed MemRef descriptor
246262
addTargetMaterialization([&](OpBuilder &builder, Type resultType,

0 commit comments

Comments
 (0)