Skip to content

Commit cecc53c

Browse files
author
Longsheng Du
authored
[Dialect] [Linalgx] Add linalgx ops: 3 vnni matmuls and multi_batch_matmul (#89)
1 parent 47a5771 commit cecc53c

File tree

4 files changed

+887
-0
lines changed

4 files changed

+887
-0
lines changed

include/gc/Dialect/Linalgx/LinalgxStructuredOps.td

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,212 @@ def Linalgx_SigmoidOp : LinalgxStructuredBase_Op<"sigmoid",
104104
}];
105105
}
106106

107+
def Linalgx_Mm2DVnniOp
108+
: LinalgxStructuredBase_Op<"mm2d_vnni", [AttrSizedOperandSegments]> {
109+
let summary = "Transposed matmul with 2d input and vnni packed weights";
110+
let description = [{
111+
Supported format: A[M, K] * B[N0, K0, k, n, v] -> C[M, N], with:
112+
N = N0 * n
113+
K = K0 * k * v; v = (2, 4)
114+
}];
115+
let arguments = (ins
116+
Variadic<TensorOrMemref>:$inputs,
117+
Variadic<TensorOrMemref>:$outputs);
118+
let results = (outs Variadic<TensorOrMemref>:$results);
119+
let regions = (region AnyRegion:$region);
120+
121+
let skipDefaultBuilders = 1;
122+
let builders = [
123+
OpBuilder<
124+
(ins
125+
"TypeRange":$resultTensorTypes,
126+
"ValueRange":$inputs,
127+
"ValueRange":$outputs,
128+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
129+
[{
130+
buildStructuredOp($_builder, $_state, resultTensorTypes,
131+
inputs, outputs, attributes, Mm2DVnniOp::getRegionBuilder());
132+
}]>
133+
];
134+
135+
let hasCustomAssemblyFormat = 1;
136+
let hasFolder = 1;
137+
let hasVerifier = 1;
138+
139+
let extraClassDeclaration = structuredOpsBaseDecls # [{
140+
// Declare functions necessary for LinalgStructuredInterface.
141+
SmallVector<utils::IteratorType> getIteratorTypesArray();
142+
ArrayAttr getIndexingMaps();
143+
static unsigned getNumRegionArgs() { return 3; }
144+
std::string getLibraryCallName() {
145+
return "op_has_no_registered_library_name";
146+
}
147+
148+
// Implement functions necessary for DestinationStyleOpInterface.
149+
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
150+
151+
static void regionBuilder(ImplicitLocOpBuilder &b,
152+
Block &block, ArrayRef<NamedAttribute> attrs);
153+
static std::function<void(ImplicitLocOpBuilder &,
154+
Block &, ArrayRef<NamedAttribute>)>
155+
getRegionBuilder() {
156+
return regionBuilder;
157+
}
158+
}];
159+
}
160+
161+
def Linalgx_Mm4DVnniOp
162+
: LinalgxStructuredBase_Op<"mm4d_vnni", [AttrSizedOperandSegments]> {
163+
let summary = "Transposed matmul with 4d blocking input and vnni packed weights";
164+
let description = [{
165+
Supported format: A[M, K, m, k] * B[N, K, k0, n, v] -> C[M, N, m, n], with:
166+
k = k0 * v; v = (2, 4)
167+
}];
168+
let arguments = (ins
169+
Variadic<TensorOrMemref>:$inputs,
170+
Variadic<TensorOrMemref>:$outputs);
171+
let results = (outs Variadic<TensorOrMemref>:$results);
172+
let regions = (region AnyRegion:$region);
173+
174+
let skipDefaultBuilders = 1;
175+
let builders = [
176+
OpBuilder<
177+
(ins
178+
"TypeRange":$resultTensorTypes,
179+
"ValueRange":$inputs,
180+
"ValueRange":$outputs,
181+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
182+
[{
183+
buildStructuredOp($_builder, $_state, resultTensorTypes,
184+
inputs, outputs, attributes, Mm4DVnniOp::getRegionBuilder());
185+
}]>
186+
];
187+
188+
let hasCustomAssemblyFormat = 1;
189+
let hasFolder = 1;
190+
let hasVerifier = 1;
191+
192+
let extraClassDeclaration = structuredOpsBaseDecls # [{
193+
// Declare functions necessary for LinalgStructuredInterface.
194+
SmallVector<utils::IteratorType> getIteratorTypesArray();
195+
ArrayAttr getIndexingMaps();
196+
static unsigned getNumRegionArgs() { return 3; }
197+
std::string getLibraryCallName() {
198+
return "op_has_no_registered_library_name";
199+
}
200+
201+
// Implement functions necessary for DestinationStyleOpInterface.
202+
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
203+
204+
static void regionBuilder(ImplicitLocOpBuilder &b,
205+
Block &block, ArrayRef<NamedAttribute> attrs);
206+
static std::function<void(ImplicitLocOpBuilder &,
207+
Block &, ArrayRef<NamedAttribute>)>
208+
getRegionBuilder() {
209+
return regionBuilder;
210+
}
211+
}];
212+
}
213+
214+
def Linalgx_BatchReduceMatmulVnniOp
215+
: LinalgxStructuredBase_Op<"batch_reduce_matmul_vnni", [AttrSizedOperandSegments]> {
216+
let summary = "Batch reduced matmul with 3d batch input and vnni packed weights";
217+
let description = [{
218+
Supported format: A[B, M, K] * B[B, k, N, v] -> C[M, N], with:
219+
K = k * v; v = (2, 4)
220+
}];
221+
let arguments = (ins
222+
Variadic<TensorOrMemref>:$inputs,
223+
Variadic<TensorOrMemref>:$outputs);
224+
let results = (outs Variadic<TensorOrMemref>:$results);
225+
let regions = (region AnyRegion:$region);
226+
227+
let skipDefaultBuilders = 1;
228+
let builders = [
229+
OpBuilder<
230+
(ins
231+
"TypeRange":$resultTensorTypes,
232+
"ValueRange":$inputs,
233+
"ValueRange":$outputs,
234+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
235+
[{
236+
buildStructuredOp($_builder, $_state, resultTensorTypes,
237+
inputs, outputs, attributes, BatchReduceMatmulVnniOp::getRegionBuilder());
238+
}]>
239+
];
240+
241+
let hasCustomAssemblyFormat = 1;
242+
let hasFolder = 1;
243+
let hasVerifier = 1;
244+
245+
let extraClassDeclaration = structuredOpsBaseDecls # [{
246+
// Declare functions necessary for LinalgStructuredInterface.
247+
SmallVector<utils::IteratorType> getIteratorTypesArray();
248+
ArrayAttr getIndexingMaps();
249+
static unsigned getNumRegionArgs() { return 3; }
250+
std::string getLibraryCallName() {
251+
return "op_has_no_registered_library_name";
252+
}
253+
254+
// Implement functions necessary for DestinationStyleOpInterface.
255+
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
256+
257+
static void regionBuilder(ImplicitLocOpBuilder &b,
258+
Block &block, ArrayRef<NamedAttribute> attrs);
259+
static std::function<void(ImplicitLocOpBuilder &,
260+
Block &, ArrayRef<NamedAttribute>)>
261+
getRegionBuilder() {
262+
return regionBuilder;
263+
}
264+
}];
265+
}
266+
267+
def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
268+
[AttrSizedOperandSegments, LinalgContractionOpInterface]> {
269+
let summary = "Batch matmul with variable batch dims";
270+
let arguments = (ins
271+
Variadic<TensorOrMemref>:$inputs,
272+
Variadic<TensorOrMemref>:$outputs);
273+
let results = (outs Variadic<TensorOrMemref>:$results);
274+
let regions = (region AnyRegion:$region);
275+
276+
let skipDefaultBuilders = 1;
277+
let builders = [
278+
OpBuilder<
279+
(ins
280+
"TypeRange":$resultTensorTypes,
281+
"ValueRange":$inputs,
282+
"ValueRange":$outputs,
283+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
284+
[{
285+
buildStructuredOp($_builder, $_state, resultTensorTypes,
286+
inputs, outputs, attributes, MultiBatchMatmulOp::getRegionBuilder());
287+
}]>
288+
];
289+
290+
let hasCustomAssemblyFormat = 1;
291+
let hasFolder = 1;
292+
293+
let extraClassDeclaration = structuredOpsBaseDecls # [{
294+
// Declare functions necessary for LinalgStructuredInterface.
295+
SmallVector<utils::IteratorType> getIteratorTypesArray();
296+
ArrayAttr getIndexingMaps();
297+
static unsigned getNumRegionArgs() { return 3; }
298+
std::string getLibraryCallName() {
299+
return "op_has_no_registered_library_name";
300+
}
301+
302+
// Implement functions necessary for DestinationStyleOpInterface.
303+
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
304+
305+
static void regionBuilder(ImplicitLocOpBuilder &b,
306+
Block &block, ArrayRef<NamedAttribute> attrs);
307+
static std::function<void(ImplicitLocOpBuilder &,
308+
Block &, ArrayRef<NamedAttribute>)>
309+
getRegionBuilder() {
310+
return regionBuilder;
311+
}
312+
}];
313+
}
314+
107315
#endif // LINALGX_STRUCTURED_OPS

0 commit comments

Comments
 (0)