@@ -44,6 +44,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
44
44
const DataLayoutAnalysis *analysis)
45
45
: LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46
46
47
+ // / Helper function that checks if the given value range is a bare pointer.
48
+ static bool isBarePointer (ValueRange values) {
49
+ return values.size () == 1 &&
50
+ isa<LLVM::LLVMPointerType>(values.front ().getType ());
51
+ };
52
+
53
+ // / Pack SSA values into an unranked memref descriptor struct.
54
+ static Value packUnrankedMemRefDesc (OpBuilder &builder,
55
+ UnrankedMemRefType resultType,
56
+ ValueRange inputs, Location loc,
57
+ const LLVMTypeConverter &converter) {
58
+ // Note: Bare pointers are not supported for unranked memrefs because a
59
+ // memref descriptor cannot be built just from a bare pointer.
60
+ if (TypeRange (inputs) != converter.getUnrankedMemRefDescriptorFields ())
61
+ return Value ();
62
+ return UnrankedMemRefDescriptor::pack (builder, loc, converter, resultType,
63
+ inputs);
64
+ }
65
+
66
+ // / Pack SSA values into a ranked memref descriptor struct.
67
+ static Value packRankedMemRefDesc (OpBuilder &builder, MemRefType resultType,
68
+ ValueRange inputs, Location loc,
69
+ const LLVMTypeConverter &converter) {
70
+ assert (resultType && " expected non-null result type" );
71
+ if (isBarePointer (inputs))
72
+ return MemRefDescriptor::fromStaticShape (builder, loc, converter,
73
+ resultType, inputs[0 ]);
74
+ if (TypeRange (inputs) ==
75
+ converter.getMemRefDescriptorFields (resultType,
76
+ /* unpackAggregates=*/ true ))
77
+ return MemRefDescriptor::pack (builder, loc, converter, resultType, inputs);
78
+ // The inputs are neither a bare pointer nor an unpacked memref descriptor.
79
+ // This materialization function cannot be used.
80
+ return Value ();
81
+ }
82
+
83
+ // / MemRef descriptor elements -> UnrankedMemRefType
84
+ static Value unrankedMemRefMaterialization (OpBuilder &builder,
85
+ UnrankedMemRefType resultType,
86
+ ValueRange inputs, Location loc,
87
+ const LLVMTypeConverter &converter) {
88
+ // An argument materialization must return a value of type
89
+ // `resultType`, so insert a cast from the memref descriptor type
90
+ // (!llvm.struct) to the original memref type.
91
+ Value packed =
92
+ packUnrankedMemRefDesc (builder, resultType, inputs, loc, converter);
93
+ if (!packed)
94
+ return Value ();
95
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
96
+ .getResult (0 );
97
+ };
98
+
99
+ // / MemRef descriptor elements -> MemRefType
100
+ static Value rankedMemRefMaterialization (OpBuilder &builder,
101
+ MemRefType resultType,
102
+ ValueRange inputs, Location loc,
103
+ const LLVMTypeConverter &converter) {
104
+ // An argument materialization must return a value of type `resultType`,
105
+ // so insert a cast from the memref descriptor type (!llvm.struct) to the
106
+ // original memref type.
107
+ Value packed =
108
+ packRankedMemRefDesc (builder, resultType, inputs, loc, converter);
109
+ if (!packed)
110
+ return Value ();
111
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
112
+ .getResult (0 );
113
+ }
114
+
47
115
// / Create an LLVMTypeConverter using custom LowerToLLVMOptions.
48
116
LLVMTypeConverter::LLVMTypeConverter (MLIRContext *ctx,
49
117
const LowerToLLVMOptions &options,
@@ -166,81 +234,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
166
234
.getResult (0 );
167
235
});
168
236
169
- // Helper function that checks if the given value range is a bare pointer.
170
- auto isBarePointer = [](ValueRange values) {
171
- return values.size () == 1 &&
172
- isa<LLVM::LLVMPointerType>(values.front ().getType ());
173
- };
174
-
175
- // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
176
- // must be passed explicitly.
177
- auto packUnrankedMemRefDesc =
178
- [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
179
- Location loc, LLVMTypeConverter &converter) -> Value {
180
- // Note: Bare pointers are not supported for unranked memrefs because a
181
- // memref descriptor cannot be built just from a bare pointer.
182
- if (TypeRange (inputs) != converter.getUnrankedMemRefDescriptorFields ())
183
- return Value ();
184
- return UnrankedMemRefDescriptor::pack (builder, loc, converter, resultType,
185
- inputs);
186
- };
187
-
188
- // MemRef descriptor elements -> UnrankedMemRefType
189
- auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
190
- UnrankedMemRefType resultType,
191
- ValueRange inputs, Location loc) {
192
- // An argument materialization must return a value of type
193
- // `resultType`, so insert a cast from the memref descriptor type
194
- // (!llvm.struct) to the original memref type.
195
- Value packed =
196
- packUnrankedMemRefDesc (builder, resultType, inputs, loc, *this );
197
- if (!packed)
198
- return Value ();
199
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
200
- .getResult (0 );
201
- };
202
-
203
- // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
204
- // must be passed explicitly.
205
- auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
206
- ValueRange inputs, Location loc,
207
- LLVMTypeConverter &converter) -> Value {
208
- assert (resultType && " expected non-null result type" );
209
- if (isBarePointer (inputs))
210
- return MemRefDescriptor::fromStaticShape (builder, loc, converter,
211
- resultType, inputs[0 ]);
212
- if (TypeRange (inputs) ==
213
- converter.getMemRefDescriptorFields (resultType,
214
- /* unpackAggregates=*/ true ))
215
- return MemRefDescriptor::pack (builder, loc, converter, resultType,
216
- inputs);
217
- // The inputs are neither a bare pointer nor an unpacked memref descriptor.
218
- // This materialization function cannot be used.
219
- return Value ();
220
- };
221
-
222
- // MemRef descriptor elements -> MemRefType
223
- auto rankedMemRefMaterialization = [&](OpBuilder &builder,
224
- MemRefType resultType,
225
- ValueRange inputs, Location loc) {
226
- // An argument materialization must return a value of type `resultType`,
227
- // so insert a cast from the memref descriptor type (!llvm.struct) to the
228
- // original memref type.
229
- Value packed =
230
- packRankedMemRefDesc (builder, resultType, inputs, loc, *this );
231
- if (!packed)
232
- return Value ();
233
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
234
- .getResult (0 );
235
- };
236
-
237
237
// Argument materializations convert from the new block argument types
238
238
// (multiple SSA values that make up a memref descriptor) back to the
239
239
// original block argument type.
240
- addArgumentMaterialization (unrakedMemRefMaterialization);
241
- addArgumentMaterialization (rankedMemRefMaterialization);
242
- addSourceMaterialization (unrakedMemRefMaterialization);
243
- addSourceMaterialization (rankedMemRefMaterialization);
240
+ addArgumentMaterialization ([&](OpBuilder &builder,
241
+ UnrankedMemRefType resultType,
242
+ ValueRange inputs, Location loc) {
243
+ return unrankedMemRefMaterialization (builder, resultType, inputs, loc,
244
+ *this );
245
+ });
246
+ addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
247
+ ValueRange inputs, Location loc) {
248
+ return rankedMemRefMaterialization (builder, resultType, inputs, loc, *this );
249
+ });
250
+ addSourceMaterialization ([&](OpBuilder &builder,
251
+ UnrankedMemRefType resultType, ValueRange inputs,
252
+ Location loc) {
253
+ return unrankedMemRefMaterialization (builder, resultType, inputs, loc,
254
+ *this );
255
+ });
256
+ addSourceMaterialization ([&](OpBuilder &builder, MemRefType resultType,
257
+ ValueRange inputs, Location loc) {
258
+ return rankedMemRefMaterialization (builder, resultType, inputs, loc, *this );
259
+ });
244
260
245
261
// Bare pointer -> Packed MemRef descriptor
246
262
addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
0 commit comments