@@ -3882,17 +3882,48 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
3882
3882
return getGenericSpeculatabilityImpl (cast<LinalgOp>(getOperation ()));
3883
3883
}
3884
3884
3885
+ static FailureOr<SmallVector<SmallVector<int64_t >>>
3886
+ getAffineResultPositions (ArrayAttr maps) {
3887
+ SmallVector<SmallVector<int64_t >> positions;
3888
+ for (auto map : maps) {
3889
+ AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3890
+ if (!attr)
3891
+ return failure ();
3892
+ SmallVector<int64_t > pos;
3893
+ for (auto result : attr.getAffineMap ().getResults ()) {
3894
+ auto dim = dyn_cast<AffineDimExpr>(result);
3895
+ if (!dim)
3896
+ return failure ();
3897
+ pos.push_back (dim.getPosition ());
3898
+ }
3899
+ positions.push_back (pos);
3900
+ }
3901
+ return positions;
3902
+ }
3903
+
3885
3904
SmallVector<AffineMap> MatmulTransposeAOp::getAffineMaps (OpBuilder &builder) {
3886
3905
AffineExpr d0, d1, d2;
3887
- auto context = builder.getContext ();
3906
+ MLIRContext * context = builder.getContext ();
3888
3907
bindDims (context, d0, d1, d2);
3889
3908
AffineMap mapLHS = AffineMap::get (3 , 0 , {d2, d0}, context);
3890
3909
AffineMap mapRHS = AffineMap::get (3 , 0 , {d2, d1}, context);
3891
3910
AffineMap mapOut = AffineMap::get (3 , 0 , {d0, d1}, context);
3892
- SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3893
- return affineMaps;
3911
+ return {mapLHS, mapRHS, mapOut};
3894
3912
}
3895
3913
3914
+ bool MatmulTransposeAOp::isExpectedAffineMaps (Attribute attr) {
3915
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3916
+ if (!maps)
3917
+ return false ;
3918
+ if (maps.size () != 3 )
3919
+ return false ;
3920
+ auto positions = getAffineResultPositions (maps);
3921
+ if (failed (positions))
3922
+ return false ;
3923
+ return (*positions)[0 ] == SmallVector<int64_t >{2 , 0 } &&
3924
+ (*positions)[1 ] == SmallVector<int64_t >{2 , 1 } &&
3925
+ (*positions)[2 ] == SmallVector<int64_t >{0 , 1 };
3926
+ }
3896
3927
void linalg::MatmulTransposeAOp::build (OpBuilder &builder,
3897
3928
OperationState &result,
3898
3929
ValueRange inputs, ValueRange outputs,
@@ -3922,18 +3953,32 @@ void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
3922
3953
}
3923
3954
3924
3955
bool MatmulTransposeAOp::classof (Operation *op) {
3925
- return dyn_cast_or_null<linalg::MatmulOp>(op);
3956
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
3957
+ MatmulTransposeAOp::isExpectedAffineMaps (op->getAttr (" indexing_maps" ));
3926
3958
}
3927
3959
3928
3960
SmallVector<AffineMap> MatmulTransposeBOp::getAffineMaps (OpBuilder &builder) {
3929
3961
AffineExpr d0, d1, d2;
3930
- auto context = builder.getContext ();
3962
+ MLIRContext * context = builder.getContext ();
3931
3963
bindDims (context, d0, d1, d2);
3932
3964
AffineMap mapLHS = AffineMap::get (3 , 0 , {d0, d2}, context);
3933
3965
AffineMap mapRHS = AffineMap::get (3 , 0 , {d1, d2}, context);
3934
3966
AffineMap mapOut = AffineMap::get (3 , 0 , {d0, d1}, context);
3935
- SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3936
- return affineMaps;
3967
+ return {mapLHS, mapRHS, mapOut};
3968
+ }
3969
+
3970
+ bool MatmulTransposeBOp::isExpectedAffineMaps (Attribute attr) {
3971
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3972
+ if (!maps)
3973
+ return false ;
3974
+ if (maps.size () != 3 )
3975
+ return false ;
3976
+ auto positions = getAffineResultPositions (maps);
3977
+ if (failed (positions))
3978
+ return false ;
3979
+ return (*positions)[0 ] == SmallVector<int64_t >{0 , 2 } &&
3980
+ (*positions)[1 ] == SmallVector<int64_t >{1 , 2 } &&
3981
+ (*positions)[2 ] == SmallVector<int64_t >{0 , 1 };
3937
3982
}
3938
3983
3939
3984
void linalg::MatmulTransposeBOp::build (OpBuilder &builder,
@@ -3965,19 +4010,33 @@ void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
3965
4010
}
3966
4011
3967
4012
bool MatmulTransposeBOp::classof (Operation *op) {
3968
- return dyn_cast_or_null<linalg::MatmulOp>(op);
4013
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4014
+ MatmulTransposeBOp::isExpectedAffineMaps (op->getAttr (" indexing_maps" ));
3969
4015
}
3970
4016
3971
4017
SmallVector<AffineMap>
3972
4018
BatchMatmulTransposeAOp::getAffineMaps (OpBuilder &builder) {
3973
4019
AffineExpr d0, d1, d2, d3;
3974
- auto context = builder.getContext ();
4020
+ MLIRContext * context = builder.getContext ();
3975
4021
bindDims (context, d0, d1, d2, d3);
3976
4022
AffineMap mapLHS = AffineMap::get (4 , 0 , {d0, d3, d1}, context);
3977
4023
AffineMap mapRHS = AffineMap::get (4 , 0 , {d0, d3, d2}, context);
3978
4024
AffineMap mapOut = AffineMap::get (4 , 0 , {d0, d1, d2}, context);
3979
- SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3980
- return affineMaps;
4025
+ return {mapLHS, mapRHS, mapOut};
4026
+ }
4027
+
4028
+ bool BatchMatmulTransposeAOp::isExpectedAffineMaps (Attribute attr) {
4029
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4030
+ if (!maps)
4031
+ return false ;
4032
+ if (maps.size () != 3 )
4033
+ return false ;
4034
+ auto positions = getAffineResultPositions (maps);
4035
+ if (failed (positions))
4036
+ return false ;
4037
+ return (*positions)[0 ] == SmallVector<int64_t >{0 , 3 , 1 } &&
4038
+ (*positions)[1 ] == SmallVector<int64_t >{0 , 3 , 2 } &&
4039
+ (*positions)[2 ] == SmallVector<int64_t >{0 , 1 , 2 };
3981
4040
}
3982
4041
3983
4042
void linalg::BatchMatmulTransposeAOp::build (
@@ -4005,19 +4064,34 @@ void linalg::BatchMatmulTransposeAOp::build(
4005
4064
}
4006
4065
4007
4066
bool BatchMatmulTransposeAOp::classof (Operation *op) {
4008
- return dyn_cast_or_null<linalg::BatchMatmulOp>(op);
4067
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4068
+ BatchMatmulTransposeAOp::isExpectedAffineMaps (
4069
+ op->getAttr (" indexing_maps" ));
4009
4070
}
4010
4071
4011
4072
SmallVector<AffineMap>
4012
4073
BatchMatmulTransposeBOp::getAffineMaps (OpBuilder &builder) {
4013
4074
AffineExpr d0, d1, d2, d3;
4014
- auto context = builder.getContext ();
4075
+ MLIRContext * context = builder.getContext ();
4015
4076
bindDims (context, d0, d1, d2, d3);
4016
4077
AffineMap mapLHS = AffineMap::get (4 , 0 , {d0, d1, d3}, context);
4017
4078
AffineMap mapRHS = AffineMap::get (4 , 0 , {d0, d2, d3}, context);
4018
4079
AffineMap mapOut = AffineMap::get (4 , 0 , {d0, d1, d2}, context);
4019
- SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
4020
- return affineMaps;
4080
+ return {mapLHS, mapRHS, mapOut};
4081
+ }
4082
+
4083
+ bool BatchMatmulTransposeBOp::isExpectedAffineMaps (Attribute attr) {
4084
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4085
+ if (!maps)
4086
+ return false ;
4087
+ if (maps.size () != 3 )
4088
+ return false ;
4089
+ auto positions = getAffineResultPositions (maps);
4090
+ if (failed (positions))
4091
+ return false ;
4092
+ return (*positions)[0 ] == SmallVector<int64_t >{0 , 1 , 3 } &&
4093
+ (*positions)[1 ] == SmallVector<int64_t >{0 , 2 , 3 } &&
4094
+ (*positions)[2 ] == SmallVector<int64_t >{0 , 1 , 2 };
4021
4095
}
4022
4096
4023
4097
void linalg::BatchMatmulTransposeBOp::build (
@@ -4045,7 +4119,9 @@ void linalg::BatchMatmulTransposeBOp::build(
4045
4119
}
4046
4120
4047
4121
bool BatchMatmulTransposeBOp::classof (Operation *op) {
4048
- return dyn_cast_or_null<linalg::BatchMatmulOp>(op);
4122
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4123
+ BatchMatmulTransposeBOp::isExpectedAffineMaps (
4124
+ op->getAttr (" indexing_maps" ));
4049
4125
}
4050
4126
4051
4127
// ===----------------------------------------------------------------------===//
0 commit comments