Skip to content

Commit 1d3bc1f

Browse files
committed
Check affine map on classof
Checking that the affine map is the correct one for the transpose variant to make sure classof works as expected. Also exposing a method to check the affine map in case one needs a check for `matmul && isTransposeA` to replace the old checks for `matmul_transpose_a` and others.
1 parent 080d77f commit 1d3bc1f

File tree

2 files changed

+104
-16
lines changed

2 files changed

+104
-16
lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ class MatmulTransposeAOp : public MatmulOp {
172172
ValueRange outputs, Attribute cast,
173173
ArrayRef<NamedAttribute> attributes = {});
174174

175+
/// Checks if the affine map is the expected one for this operation
176+
static bool isExpectedAffineMaps(Attribute attr);
177+
175178
static bool classof(Operation *op);
176179
};
177180

@@ -201,6 +204,9 @@ class MatmulTransposeBOp : public MatmulOp {
201204
ValueRange outputs, Attribute cast,
202205
ArrayRef<NamedAttribute> attributes = {});
203206

207+
/// Checks if the affine map is the expected one for this operation
208+
static bool isExpectedAffineMaps(Attribute attr);
209+
204210
static bool classof(Operation *op);
205211
};
206212

@@ -231,6 +237,9 @@ class BatchMatmulTransposeAOp : public BatchMatmulOp {
231237
ValueRange outputs, Attribute cast,
232238
ArrayRef<NamedAttribute> attributes = {});
233239

240+
/// Checks if the affine map is the expected one for this operation
241+
static bool isExpectedAffineMaps(Attribute attr);
242+
234243
static bool classof(Operation *op);
235244
};
236245

@@ -261,6 +270,9 @@ class BatchMatmulTransposeBOp : public BatchMatmulOp {
261270
ValueRange outputs, Attribute cast,
262271
ArrayRef<NamedAttribute> attributes = {});
263272

273+
/// Checks if the affine map is the expected one for this operation
274+
static bool isExpectedAffineMaps(Attribute attr);
275+
264276
static bool classof(Operation *op);
265277
};
266278

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 92 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3882,17 +3882,48 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
38823882
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
38833883
}
38843884

3885+
static FailureOr<SmallVector<SmallVector<int64_t>>>
3886+
getAffineResultPositions(ArrayAttr maps) {
3887+
SmallVector<SmallVector<int64_t>> positions;
3888+
for (auto map : maps) {
3889+
AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3890+
if (!attr)
3891+
return failure();
3892+
SmallVector<int64_t> pos;
3893+
for (auto result : attr.getAffineMap().getResults()) {
3894+
auto dim = dyn_cast<AffineDimExpr>(result);
3895+
if (!dim)
3896+
return failure();
3897+
pos.push_back(dim.getPosition());
3898+
}
3899+
positions.push_back(pos);
3900+
}
3901+
return positions;
3902+
}
3903+
38853904
SmallVector<AffineMap> MatmulTransposeAOp::getAffineMaps(OpBuilder &builder) {
38863905
AffineExpr d0, d1, d2;
3887-
auto context = builder.getContext();
3906+
MLIRContext *context = builder.getContext();
38883907
bindDims(context, d0, d1, d2);
38893908
AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
38903909
AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
38913910
AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3892-
SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3893-
return affineMaps;
3911+
return {mapLHS, mapRHS, mapOut};
38943912
}
38953913

3914+
bool MatmulTransposeAOp::isExpectedAffineMaps(Attribute attr) {
3915+
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3916+
if (!maps)
3917+
return false;
3918+
if (maps.size() != 3)
3919+
return false;
3920+
auto positions = getAffineResultPositions(maps);
3921+
if (failed(positions))
3922+
return false;
3923+
return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
3924+
(*positions)[1] == SmallVector<int64_t>{2, 1} &&
3925+
(*positions)[2] == SmallVector<int64_t>{0, 1};
3926+
}
38963927
void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
38973928
OperationState &result,
38983929
ValueRange inputs, ValueRange outputs,
@@ -3922,18 +3953,32 @@ void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
39223953
}
39233954

39243955
bool MatmulTransposeAOp::classof(Operation *op) {
3925-
return dyn_cast_or_null<linalg::MatmulOp>(op);
3956+
return dyn_cast_or_null<linalg::MatmulOp>(op) &&
3957+
MatmulTransposeAOp::isExpectedAffineMaps(op->getAttr("indexing_maps"));
39263958
}
39273959

39283960
SmallVector<AffineMap> MatmulTransposeBOp::getAffineMaps(OpBuilder &builder) {
39293961
AffineExpr d0, d1, d2;
3930-
auto context = builder.getContext();
3962+
MLIRContext *context = builder.getContext();
39313963
bindDims(context, d0, d1, d2);
39323964
AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
39333965
AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
39343966
AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3935-
SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3936-
return affineMaps;
3967+
return {mapLHS, mapRHS, mapOut};
3968+
}
3969+
3970+
bool MatmulTransposeBOp::isExpectedAffineMaps(Attribute attr) {
3971+
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3972+
if (!maps)
3973+
return false;
3974+
if (maps.size() != 3)
3975+
return false;
3976+
auto positions = getAffineResultPositions(maps);
3977+
if (failed(positions))
3978+
return false;
3979+
return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3980+
(*positions)[1] == SmallVector<int64_t>{1, 2} &&
3981+
(*positions)[2] == SmallVector<int64_t>{0, 1};
39373982
}
39383983

39393984
void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
@@ -3965,19 +4010,33 @@ void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
39654010
}
39664011

39674012
bool MatmulTransposeBOp::classof(Operation *op) {
3968-
return dyn_cast_or_null<linalg::MatmulOp>(op);
4013+
return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4014+
MatmulTransposeBOp::isExpectedAffineMaps(op->getAttr("indexing_maps"));
39694015
}
39704016

39714017
SmallVector<AffineMap>
39724018
BatchMatmulTransposeAOp::getAffineMaps(OpBuilder &builder) {
39734019
AffineExpr d0, d1, d2, d3;
3974-
auto context = builder.getContext();
4020+
MLIRContext *context = builder.getContext();
39754021
bindDims(context, d0, d1, d2, d3);
39764022
AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
39774023
AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
39784024
AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
3979-
SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3980-
return affineMaps;
4025+
return {mapLHS, mapRHS, mapOut};
4026+
}
4027+
4028+
bool BatchMatmulTransposeAOp::isExpectedAffineMaps(Attribute attr) {
4029+
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4030+
if (!maps)
4031+
return false;
4032+
if (maps.size() != 3)
4033+
return false;
4034+
auto positions = getAffineResultPositions(maps);
4035+
if (failed(positions))
4036+
return false;
4037+
return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4038+
(*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4039+
(*positions)[2] == SmallVector<int64_t>{0, 1, 2};
39814040
}
39824041

39834042
void linalg::BatchMatmulTransposeAOp::build(
@@ -4005,19 +4064,34 @@ void linalg::BatchMatmulTransposeAOp::build(
40054064
}
40064065

40074066
bool BatchMatmulTransposeAOp::classof(Operation *op) {
4008-
return dyn_cast_or_null<linalg::BatchMatmulOp>(op);
4067+
return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4068+
BatchMatmulTransposeAOp::isExpectedAffineMaps(
4069+
op->getAttr("indexing_maps"));
40094070
}
40104071

40114072
SmallVector<AffineMap>
40124073
BatchMatmulTransposeBOp::getAffineMaps(OpBuilder &builder) {
40134074
AffineExpr d0, d1, d2, d3;
4014-
auto context = builder.getContext();
4075+
MLIRContext *context = builder.getContext();
40154076
bindDims(context, d0, d1, d2, d3);
40164077
AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
40174078
AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
40184079
AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4019-
SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
4020-
return affineMaps;
4080+
return {mapLHS, mapRHS, mapOut};
4081+
}
4082+
4083+
bool BatchMatmulTransposeBOp::isExpectedAffineMaps(Attribute attr) {
4084+
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4085+
if (!maps)
4086+
return false;
4087+
if (maps.size() != 3)
4088+
return false;
4089+
auto positions = getAffineResultPositions(maps);
4090+
if (failed(positions))
4091+
return false;
4092+
return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4093+
(*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4094+
(*positions)[2] == SmallVector<int64_t>{0, 1, 2};
40214095
}
40224096

40234097
void linalg::BatchMatmulTransposeBOp::build(
@@ -4045,7 +4119,9 @@ void linalg::BatchMatmulTransposeBOp::build(
40454119
}
40464120

40474121
bool BatchMatmulTransposeBOp::classof(Operation *op) {
4048-
return dyn_cast_or_null<linalg::BatchMatmulOp>(op);
4122+
return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4123+
BatchMatmulTransposeBOp::isExpectedAffineMaps(
4124+
op->getAttr("indexing_maps"));
40494125
}
40504126

40514127
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)