@@ -104,4 +104,212 @@ def Linalgx_SigmoidOp : LinalgxStructuredBase_Op<"sigmoid",
104
104
}];
105
105
}
106
106
107
+ def Linalgx_Mm2DVnniOp
108
+ : LinalgxStructuredBase_Op<"mm2d_vnni", [AttrSizedOperandSegments]> {
109
+ let summary = "Transposed matmul with 2d input and vnni packed weights";
110
+ let description = [{
111
+ Supported format: A[M, K] * B[N0, K0, k, n, v] -> C[M, N], with:
112
+ N = N0 * n
113
+ K = K0 * k * v; v = (2, 4)
114
+ }];
115
+ let arguments = (ins
116
+ Variadic<TensorOrMemref>:$inputs,
117
+ Variadic<TensorOrMemref>:$outputs);
118
+ let results = (outs Variadic<TensorOrMemref>:$results);
119
+ let regions = (region AnyRegion:$region);
120
+
121
+ let skipDefaultBuilders = 1;
122
+ let builders = [
123
+ OpBuilder<
124
+ (ins
125
+ "TypeRange":$resultTensorTypes,
126
+ "ValueRange":$inputs,
127
+ "ValueRange":$outputs,
128
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
129
+ [{
130
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
131
+ inputs, outputs, attributes, Mm2DVnniOp::getRegionBuilder());
132
+ }]>
133
+ ];
134
+
135
+ let hasCustomAssemblyFormat = 1;
136
+ let hasFolder = 1;
137
+ let hasVerifier = 1;
138
+
139
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
140
+ // Declare functions necessary for LinalgStructuredInterface.
141
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
142
+ ArrayAttr getIndexingMaps();
143
+ static unsigned getNumRegionArgs() { return 3; }
144
+ std::string getLibraryCallName() {
145
+ return "op_has_no_registered_library_name";
146
+ }
147
+
148
+ // Implement functions necessary for DestinationStyleOpInterface.
149
+ MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
150
+
151
+ static void regionBuilder(ImplicitLocOpBuilder &b,
152
+ Block &block, ArrayRef<NamedAttribute> attrs);
153
+ static std::function<void(ImplicitLocOpBuilder &,
154
+ Block &, ArrayRef<NamedAttribute>)>
155
+ getRegionBuilder() {
156
+ return regionBuilder;
157
+ }
158
+ }];
159
+ }
160
+
161
+ def Linalgx_Mm4DVnniOp
162
+ : LinalgxStructuredBase_Op<"mm4d_vnni", [AttrSizedOperandSegments]> {
163
+ let summary = "Transposed matmul with 4d blocking input and vnni packed weights";
164
+ let description = [{
165
+ Supported format: A[M, K, m, k] * B[N, K, k0, n, v] -> C[M, N, m, n], with:
166
+ k = k0 * v; v = (2, 4)
167
+ }];
168
+ let arguments = (ins
169
+ Variadic<TensorOrMemref>:$inputs,
170
+ Variadic<TensorOrMemref>:$outputs);
171
+ let results = (outs Variadic<TensorOrMemref>:$results);
172
+ let regions = (region AnyRegion:$region);
173
+
174
+ let skipDefaultBuilders = 1;
175
+ let builders = [
176
+ OpBuilder<
177
+ (ins
178
+ "TypeRange":$resultTensorTypes,
179
+ "ValueRange":$inputs,
180
+ "ValueRange":$outputs,
181
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
182
+ [{
183
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
184
+ inputs, outputs, attributes, Mm4DVnniOp::getRegionBuilder());
185
+ }]>
186
+ ];
187
+
188
+ let hasCustomAssemblyFormat = 1;
189
+ let hasFolder = 1;
190
+ let hasVerifier = 1;
191
+
192
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
193
+ // Declare functions necessary for LinalgStructuredInterface.
194
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
195
+ ArrayAttr getIndexingMaps();
196
+ static unsigned getNumRegionArgs() { return 3; }
197
+ std::string getLibraryCallName() {
198
+ return "op_has_no_registered_library_name";
199
+ }
200
+
201
+ // Implement functions necessary for DestinationStyleOpInterface.
202
+ MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
203
+
204
+ static void regionBuilder(ImplicitLocOpBuilder &b,
205
+ Block &block, ArrayRef<NamedAttribute> attrs);
206
+ static std::function<void(ImplicitLocOpBuilder &,
207
+ Block &, ArrayRef<NamedAttribute>)>
208
+ getRegionBuilder() {
209
+ return regionBuilder;
210
+ }
211
+ }];
212
+ }
213
+
214
+ def Linalgx_BatchReduceMatmulVnniOp
215
+ : LinalgxStructuredBase_Op<"batch_reduce_matmul_vnni", [AttrSizedOperandSegments]> {
216
+ let summary = "Batch reduced matmul with 3d batch input and vnni packed weights";
217
+ let description = [{
218
+ Supported format: A[B, M, K] * B[B, k, N, v] -> C[M, N], with:
219
+ K = k * v; v = (2, 4)
220
+ }];
221
+ let arguments = (ins
222
+ Variadic<TensorOrMemref>:$inputs,
223
+ Variadic<TensorOrMemref>:$outputs);
224
+ let results = (outs Variadic<TensorOrMemref>:$results);
225
+ let regions = (region AnyRegion:$region);
226
+
227
+ let skipDefaultBuilders = 1;
228
+ let builders = [
229
+ OpBuilder<
230
+ (ins
231
+ "TypeRange":$resultTensorTypes,
232
+ "ValueRange":$inputs,
233
+ "ValueRange":$outputs,
234
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
235
+ [{
236
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
237
+ inputs, outputs, attributes, BatchReduceMatmulVnniOp::getRegionBuilder());
238
+ }]>
239
+ ];
240
+
241
+ let hasCustomAssemblyFormat = 1;
242
+ let hasFolder = 1;
243
+ let hasVerifier = 1;
244
+
245
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
246
+ // Declare functions necessary for LinalgStructuredInterface.
247
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
248
+ ArrayAttr getIndexingMaps();
249
+ static unsigned getNumRegionArgs() { return 3; }
250
+ std::string getLibraryCallName() {
251
+ return "op_has_no_registered_library_name";
252
+ }
253
+
254
+ // Implement functions necessary for DestinationStyleOpInterface.
255
+ MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
256
+
257
+ static void regionBuilder(ImplicitLocOpBuilder &b,
258
+ Block &block, ArrayRef<NamedAttribute> attrs);
259
+ static std::function<void(ImplicitLocOpBuilder &,
260
+ Block &, ArrayRef<NamedAttribute>)>
261
+ getRegionBuilder() {
262
+ return regionBuilder;
263
+ }
264
+ }];
265
+ }
266
+
267
+ def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
268
+ [AttrSizedOperandSegments, LinalgContractionOpInterface]> {
269
+ let summary = "Batch matmul with variable batch dims";
270
+ let arguments = (ins
271
+ Variadic<TensorOrMemref>:$inputs,
272
+ Variadic<TensorOrMemref>:$outputs);
273
+ let results = (outs Variadic<TensorOrMemref>:$results);
274
+ let regions = (region AnyRegion:$region);
275
+
276
+ let skipDefaultBuilders = 1;
277
+ let builders = [
278
+ OpBuilder<
279
+ (ins
280
+ "TypeRange":$resultTensorTypes,
281
+ "ValueRange":$inputs,
282
+ "ValueRange":$outputs,
283
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
284
+ [{
285
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
286
+ inputs, outputs, attributes, MultiBatchMatmulOp::getRegionBuilder());
287
+ }]>
288
+ ];
289
+
290
+ let hasCustomAssemblyFormat = 1;
291
+ let hasFolder = 1;
292
+
293
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
294
+ // Declare functions necessary for LinalgStructuredInterface.
295
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
296
+ ArrayAttr getIndexingMaps();
297
+ static unsigned getNumRegionArgs() { return 3; }
298
+ std::string getLibraryCallName() {
299
+ return "op_has_no_registered_library_name";
300
+ }
301
+
302
+ // Implement functions necessary for DestinationStyleOpInterface.
303
+ MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
304
+
305
+ static void regionBuilder(ImplicitLocOpBuilder &b,
306
+ Block &block, ArrayRef<NamedAttribute> attrs);
307
+ static std::function<void(ImplicitLocOpBuilder &,
308
+ Block &, ArrayRef<NamedAttribute>)>
309
+ getRegionBuilder() {
310
+ return regionBuilder;
311
+ }
312
+ }];
313
+ }
314
+
107
315
#endif // LINALGX_STRUCTURED_OPS
0 commit comments