Skip to content

Commit bcd6f59

Browse files
committed
[MLIR][Linalg] Remove matmul_transpose variants
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replace it with `matmul affine_maps [...]` whenever appropriate. This is in line with the [plan](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863), and can be done since llvm#104783 merged. See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245 Issues investigated: * pad transform tests that could use `matmul` instead, so change to that. * ArmSME test using transpose actually needed it, so changed to `matmul` + affine maps. Arm tests validated by @banach-space (thanks!!). Signed-off-by: hanhanW <[email protected]>
1 parent 77914c9 commit bcd6f59

File tree

21 files changed

+333
-954
lines changed

21 files changed

+333
-954
lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,126 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
144144
#define GET_OP_CLASSES
145145
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
146146

147+
namespace mlir::linalg {
148+
149+
/// Specialization of `linalg.matmul` op that has a transpose map on A
150+
class MatmulTransposeAOp : public MatmulOp {
151+
/// Create an affine map for a transpose-A matmul. Used only in the builders.
152+
static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
153+
154+
public:
155+
using MatmulOp::MatmulOp;
156+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
157+
158+
/// Build a transpose A matmul.
159+
static void build(OpBuilder &builder, OperationState &result,
160+
ValueRange inputs, ValueRange outputs,
161+
ArrayRef<NamedAttribute> attributes = {});
162+
163+
/// Build a transpose A matmul with a specific result type.
164+
static void build(OpBuilder &builder, OperationState &result,
165+
TypeRange resultTensorTypes, ValueRange inputs,
166+
ValueRange outputs,
167+
ArrayRef<NamedAttribute> attributes = {});
168+
169+
/// Build a transpose A matmul with a specific result type and a cast type.
170+
static void build(OpBuilder &builder, OperationState &result,
171+
TypeRange resultTensorTypes, ValueRange inputs,
172+
ValueRange outputs, Attribute cast,
173+
ArrayRef<NamedAttribute> attributes = {});
174+
175+
static bool classof(Operation *op);
176+
};
177+
178+
/// Specialization of `linalg.matmul` op that has a transpose map on B
179+
class MatmulTransposeBOp : public MatmulOp {
180+
/// Create an affine map for a transpose-B matmul. Used only in the builders.
181+
static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
182+
183+
public:
184+
using MatmulOp::MatmulOp;
185+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
186+
187+
/// Build a transpose B matmul.
188+
static void build(OpBuilder &builder, OperationState &result,
189+
ValueRange inputs, ValueRange outputs,
190+
ArrayRef<NamedAttribute> attributes = {});
191+
192+
/// Build a transpose B matmul with a specific result type.
193+
static void build(OpBuilder &builder, OperationState &result,
194+
TypeRange resultTensorTypes, ValueRange inputs,
195+
ValueRange outputs,
196+
ArrayRef<NamedAttribute> attributes = {});
197+
198+
/// Build a transpose B matmul with a specific result type and a cast type.
199+
static void build(OpBuilder &builder, OperationState &result,
200+
TypeRange resultTensorTypes, ValueRange inputs,
201+
ValueRange outputs, Attribute cast,
202+
ArrayRef<NamedAttribute> attributes = {});
203+
204+
static bool classof(Operation *op);
205+
};
206+
207+
/// Specialization of `linalg.batch_matmul` op that has a transpose map on A
208+
class BatchMatmulTransposeAOp : public BatchMatmulOp {
209+
/// Create an affine map for a transpose-A batch_matmul. Used only in the
210+
/// builders.
211+
static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
212+
213+
public:
214+
using BatchMatmulOp::BatchMatmulOp;
215+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
216+
217+
/// Build a transpose A matmul.
218+
static void build(OpBuilder &builder, OperationState &result,
219+
ValueRange inputs, ValueRange outputs,
220+
ArrayRef<NamedAttribute> attributes = {});
221+
222+
/// Build a transpose A matmul with a specific result type.
223+
static void build(OpBuilder &builder, OperationState &result,
224+
TypeRange resultTensorTypes, ValueRange inputs,
225+
ValueRange outputs,
226+
ArrayRef<NamedAttribute> attributes = {});
227+
228+
/// Build a transpose A matmul with a specific result type and a cast type.
229+
static void build(OpBuilder &builder, OperationState &result,
230+
TypeRange resultTensorTypes, ValueRange inputs,
231+
ValueRange outputs, Attribute cast,
232+
ArrayRef<NamedAttribute> attributes = {});
233+
234+
static bool classof(Operation *op);
235+
};
236+
237+
/// Specialization of `linalg.batch_matmul` op that has a transpose map on B
238+
class BatchMatmulTransposeBOp : public BatchMatmulOp {
239+
/// Create an affine map for a transpose-B batch_matmul. Used only in the
240+
/// builders.
241+
static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
242+
243+
public:
244+
using BatchMatmulOp::BatchMatmulOp;
245+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
246+
247+
/// Build a transpose A matmul.
248+
static void build(OpBuilder &builder, OperationState &result,
249+
ValueRange inputs, ValueRange outputs,
250+
ArrayRef<NamedAttribute> attributes = {});
251+
252+
/// Build a transpose A matmul with a specific result type.
253+
static void build(OpBuilder &builder, OperationState &result,
254+
TypeRange resultTensorTypes, ValueRange inputs,
255+
ValueRange outputs,
256+
ArrayRef<NamedAttribute> attributes = {});
257+
258+
/// Build a transpose A matmul with a specific result type and a cast type.
259+
static void build(OpBuilder &builder, OperationState &result,
260+
TypeRange resultTensorTypes, ValueRange inputs,
261+
ValueRange outputs, Attribute cast,
262+
ArrayRef<NamedAttribute> attributes = {});
263+
264+
static bool classof(Operation *op);
265+
};
266+
267+
} // namespace mlir::linalg
268+
147269
#endif // MLIR_DIALECT_LINALG_IR_LINALG_H

0 commit comments

Comments
 (0)