File tree 2 files changed +31
-2
lines changed
lib/Dialect/Vector/Transforms
2 files changed +31
-2
lines changed Original file line number Diff line number Diff line change @@ -912,6 +912,15 @@ struct ReorderElementwiseOpsOnBroadcast final
912
912
return failure ();
913
913
if (!OpTrait::hasElementwiseMappableTraits (op))
914
914
return failure ();
915
+ if (op->getNumOperands () == 0 ||
916
+ op->getResults ()[0 ].getType () != op->getOperand (0 ).getType ()) {
917
+ return failure ();
918
+ }
919
+ // Avoid operations that only accept vector types, since broadcast
920
+ // source might be scalar types.
921
+ if (isa<vector::FMAOp>(op)) {
922
+ return failure ();
923
+ }
915
924
916
925
// Get the type of the lhs operand
917
926
auto *lhsBcastOrSplat = op->getOperand (0 ).getDefiningOp ();
@@ -1447,8 +1456,8 @@ void mlir::vector::
1447
1456
1448
1457
void mlir::vector::populateSinkVectorBroadcastPatterns (
1449
1458
RewritePatternSet &patterns, PatternBenefit benefit) {
1450
- patterns.add <ReorderElementwiseOpsOnBroadcast>(patterns. getContext (),
1451
- benefit);
1459
+ patterns.add <ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
1460
+ patterns. getContext (), benefit);
1452
1461
}
1453
1462
1454
1463
// ===----------------------------------------------------------------------===//
Original file line number Diff line number Diff line change @@ -105,3 +105,23 @@ func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
105
105
106
106
return %mm1 : vector <2 x2 xf32 >
107
107
}
108
+
109
+ // CHECK-LABEL: func.func @dont_sink_cmp(
110
+ // CHECK: %[[BROADCAST:.+]] = vector.broadcast
111
+ // CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
112
+ // CHECK: return %[[RETURN]]
113
+ func.func @dont_sink_cmp (%arg0 : f32 , %arg1 : vector <1 xf32 >) -> vector <1 xi1 > {
114
+ %0 = vector.broadcast %arg0 : f32 to vector <1 xf32 >
115
+ %1 = arith.cmpf uno , %0 , %0 : vector <1 xf32 >
116
+ return %1 : vector <1 xi1 >
117
+ }
118
+
119
+ // CHECK-LABEL: func.func @dont_sink_fma(
120
+ // CHECK: %[[BROADCAST:.+]] = vector.broadcast
121
+ // CHECK: %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
122
+ // CHECK: return %[[RESULT]]
123
+ func.func @dont_sink_fma (%arg0 : f32 ) -> vector <1 xf32 > {
124
+ %0 = vector.broadcast %arg0 : f32 to vector <1 xf32 >
125
+ %1 = vector.fma %0 , %0 , %0 : vector <1 xf32 >
126
+ return %1 : vector <1 xf32 >
127
+ }
You can’t perform that action at this time.
0 commit comments