Skip to content

Commit af9dafe

Browse files
authored
[mlir][spirv] Fix coop matrix store (#65709)
- Fix operand/attribute order - Use ODS for parsing/printing - Allow for stride to be any integer type
1 parent cc2b09b commit af9dafe

File tree

3 files changed

+42
-63
lines changed

3 files changed

+42
-63
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,21 +182,32 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
182182
inactive.
183183

184184
``` {.ebnf}
185-
coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore `
186-
ssa-use `, ` ssa-use `, `
187-
ssa-use `, ` cooperative-matrix-layout `, `
188-
(`[` memory-operand `]`)? `:`
189-
pointer-type `,` coop-matrix-type
185+
coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore`
186+
ssa-use `,` ssa-use `,`
187+
ssa-use `,` `<` cooperative-matrix-layout `>`
188+
(`,` `<` memory-operand `>`)? `:`
189+
pointer-type `,` coop-matrix-type `,` stride-type
190190
```
191191

192+
TODO: In the SPIR-V spec, `stride` is an optional argument. We should also
193+
support this optionality in the SPIR-V dialect.
194+
192195
#### Example:
193196

194197
```
195-
spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride :
196-
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
198+
spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, <RowMajor> :
199+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
200+
201+
spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride, <ColumnMajor>, <Volatile> :
202+
!spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<8x8xf32, Subgroup, MatrixAcc>, i64
197203
```
198204
}];
199205

206+
let assemblyFormat = [{
207+
$pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
208+
type(operands)
209+
}];
210+
200211
let availability = [
201212
MinVersion<SPIRV_V_1_6>,
202213
MaxVersion<SPIRV_V_1_6>,
@@ -207,8 +218,8 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
207218
let arguments = (ins
208219
SPIRV_AnyPtr:$pointer,
209220
SPIRV_AnyCooperativeMatrix:$object,
210-
SPIRV_Integer:$stride,
211221
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
222+
SPIRV_Integer:$stride,
212223
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
213224
);
214225

mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -63,49 +63,6 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
6363
// spirv.KHR.CooperativeMatrixStore
6464
//===----------------------------------------------------------------------===//
6565

66-
ParseResult KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
67-
OperationState &result) {
68-
std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
69-
for (auto &op : operandInfo) {
70-
if (parser.parseOperand(op) || parser.parseComma())
71-
return failure();
72-
}
73-
74-
CooperativeMatrixLayoutKHR layout;
75-
if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
76-
layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
77-
return failure();
78-
}
79-
80-
if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
81-
return failure();
82-
83-
Type ptrType;
84-
Type objectType;
85-
if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
86-
parser.parseType(objectType)) {
87-
return failure();
88-
}
89-
90-
Type strideType = parser.getBuilder().getIntegerType(32);
91-
if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
92-
parser.getNameLoc(), result.operands)) {
93-
return failure();
94-
}
95-
96-
return success();
97-
}
98-
99-
void KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
100-
printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
101-
<< ", " << getMatrixLayout();
102-
103-
// Print optional memory operand attribute.
104-
if (auto memOperand = getMemoryOperand())
105-
printer << " [\"" << *memOperand << "\"]";
106-
printer << " : " << getPointer().getType() << ", " << getObject().getType();
107-
}
108-
10966
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
11067
return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
11168
getObject().getType());

mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,21 +69,32 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuf
6969
// CHECK-LABEL: @cooperative_matrix_store
7070
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
7171
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
72-
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor :
73-
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
74-
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor :
75-
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
72+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor> :
73+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
74+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
75+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
7676
spirv.Return
7777
}
7878

7979
// CHECK-LABEL: @cooperative_matrix_store_memoperand
8080
spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
8181
%m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
8282
%stride : i32) "None" {
83-
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
84-
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
85-
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] :
86-
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
83+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
84+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
85+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Volatile> :
86+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
87+
spirv.Return
88+
}
89+
90+
// CHECK-LABEL: @cooperative_matrix_store_stride_i16
91+
spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>,
92+
%m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
93+
%stride : i16) "None" {
94+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor> :
95+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
96+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor> :
97+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
8798
spirv.Return
8899
}
89100

@@ -137,9 +148,9 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, Storage
137148

138149
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
139150
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
140-
// expected-error @+1 {{expected valid keyword}}
151+
// expected-error @+1 {{expected '<'}}
141152
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
142-
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
153+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
143154
spirv.Return
144155
}
145156

@@ -148,8 +159,8 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, Storage
148159
spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
149160
%stride : i32) "None" {
150161
// expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
151-
spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor :
152-
!spirv.ptr<i32, StorageBuffer>, i32
162+
spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, <RowMajor> :
163+
!spirv.ptr<i32, StorageBuffer>, i32, i32
153164
spirv.Return
154165
}
155166

0 commit comments

Comments
 (0)