Skip to content

Commit 9c312f1

Browse files
committed
Remove redundant patterns and refactor code
1 parent 8915f1e commit 9c312f1

File tree

5 files changed

+85
-159
lines changed

5 files changed

+85
-159
lines changed

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2020
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2121
#include "mlir/Transforms/DialectConversion.h"
22+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2223
#include "mlir/Transforms/OneToNTypeConversion.h"
2324
#include "llvm/ADT/SmallSet.h"
25+
#include "llvm/Support/LogicalResult.h"
2426

2527
namespace mlir {
2628

@@ -202,6 +204,12 @@ SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
202204
// For general ops.
203205
std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
204206

207+
// Unroll vectors in function signatures to native size.
208+
LogicalResult unrollVectorsInSignatures(Operation *op);
209+
210+
// Unroll vectors in function bodies to native size.
211+
LogicalResult unrollVectorsInFuncBodies(Operation *op);
212+
205213
} // namespace spirv
206214
} // namespace mlir
207215

mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Lines changed: 7 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -44,90 +44,16 @@ struct ConvertToSPIRVPass final
4444
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
4545

4646
void runOnOperation() override {
47-
MLIRContext *context = &getContext();
4847
Operation *op = getOperation();
48+
MLIRContext *context = &getContext();
4949

50-
if (runSignatureConversion) {
51-
// Unroll vectors in function signatures to native vector size.
52-
RewritePatternSet patterns(context);
53-
populateFuncOpVectorRewritePatterns(patterns);
54-
populateReturnOpVectorRewritePatterns(patterns);
55-
GreedyRewriteConfig config;
56-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
57-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
58-
return signalPassFailure();
59-
}
60-
61-
if (runVectorUnrolling) {
62-
// Fold transpose ops if possible as we cannot unroll it later.
63-
{
64-
RewritePatternSet patterns(context);
65-
vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
66-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
67-
return signalPassFailure();
68-
}
69-
}
70-
71-
// Unroll vectors to native vector size.
72-
{
73-
RewritePatternSet patterns(context);
74-
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
75-
[=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
76-
populateVectorUnrollPatterns(patterns, options);
77-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
78-
return signalPassFailure();
79-
}
80-
81-
// Convert transpose ops into extract and insert pairs, in preparation
82-
// of further transformations to canonicalize/cancel.
83-
{
84-
RewritePatternSet patterns(context);
85-
auto options =
86-
vector::VectorTransformsOptions().setVectorTransposeLowering(
87-
vector::VectorTransposeLowering::EltWise);
88-
vector::populateVectorTransposeLoweringPatterns(patterns, options);
89-
vector::populateVectorShapeCastLoweringPatterns(patterns);
90-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
91-
return signalPassFailure();
92-
}
93-
}
94-
95-
// Run canonicalization to cast away leading size-1 dimensions.
96-
{
97-
RewritePatternSet patterns(context);
98-
99-
// Pull in casting way leading one dims to allow cancelling some
100-
// read/write ops.
101-
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
102-
vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
103-
104-
// Decompose different rank insert_strided_slice and n-D
105-
// extract_slided_slice.
106-
vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
107-
patterns);
108-
vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
109-
110-
// Trimming leading unit dims may generate broadcast/shape_cast ops.
111-
// Clean them up.
112-
vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
113-
vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
114-
115-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
116-
return signalPassFailure();
117-
}
50+
// Unroll vectors in function signatures to native size.
51+
if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
52+
return signalPassFailure();
11853

119-
// Run all sorts of canonicalization patterns to clean up again.
120-
{
121-
RewritePatternSet patterns(context);
122-
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
123-
vector::InsertOp::getCanonicalizationPatterns(patterns, context);
124-
vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
125-
vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
126-
vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
127-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
128-
return signalPassFailure();
129-
}
130-
}
54+
// Unroll vectors in function bodies to native size.
55+
if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
56+
return signalPassFailure();
13157

13258
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
13359
std::unique_ptr<ConversionTarget> target =

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,21 @@
2020
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2121
#include "mlir/Dialect/Utils/IndexingUtils.h"
2222
#include "mlir/Dialect/Vector/IR/VectorOps.h"
23+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2324
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2425
#include "mlir/IR/BuiltinTypes.h"
2526
#include "mlir/IR/Operation.h"
2627
#include "mlir/IR/PatternMatch.h"
28+
#include "mlir/Pass/Pass.h"
2729
#include "mlir/Support/LLVM.h"
2830
#include "mlir/Transforms/DialectConversion.h"
31+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2932
#include "mlir/Transforms/OneToNTypeConversion.h"
3033
#include "llvm/ADT/STLExtras.h"
3134
#include "llvm/ADT/SmallVector.h"
3235
#include "llvm/ADT/StringExtras.h"
3336
#include "llvm/Support/Debug.h"
37+
#include "llvm/Support/LogicalResult.h"
3438
#include "llvm/Support/MathExtras.h"
3539

3640
#include <functional>
@@ -1330,6 +1334,68 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
13301334
.Default([](Operation *) { return std::nullopt; });
13311335
}
13321336

1337+
LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
1338+
MLIRContext *context = op->getContext();
1339+
RewritePatternSet patterns(context);
1340+
populateFuncOpVectorRewritePatterns(patterns);
1341+
populateReturnOpVectorRewritePatterns(patterns);
1342+
GreedyRewriteConfig config;
1343+
config.strictMode = GreedyRewriteStrictness::ExistingOps;
1344+
return applyPatternsAndFoldGreedily(op, std::move(patterns), config);
1345+
}
1346+
1347+
LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
1348+
MLIRContext *context = op->getContext();
1349+
1350+
// Unroll vectors in function bodies to native vector size.
1351+
{
1352+
RewritePatternSet patterns(context);
1353+
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
1354+
[](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1355+
populateVectorUnrollPatterns(patterns, options);
1356+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
1357+
return llvm::failure();
1358+
}
1359+
1360+
// Convert transpose ops into extract and insert pairs, in preparation of
1361+
// further transformations to canonicalize/cancel.
1362+
{
1363+
RewritePatternSet patterns(context);
1364+
auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
1365+
vector::VectorTransposeLowering::EltWise);
1366+
vector::populateVectorTransposeLoweringPatterns(patterns, options);
1367+
vector::populateVectorShapeCastLoweringPatterns(patterns);
1368+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
1369+
return llvm::failure();
1370+
}
1371+
1372+
// Run canonicalization to cast away leading size-1 dimensions.
1373+
{
1374+
RewritePatternSet patterns(context);
1375+
1376+
// We need to pull in casting way leading one dims.
1377+
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1378+
vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1379+
vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1380+
1381+
// Decompose different rank insert_strided_slice and n-D
1382+
// extract_slided_slice.
1383+
vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1384+
patterns);
1385+
vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1386+
vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1387+
1388+
// Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1389+
// them up.
1390+
vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1391+
vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1392+
1393+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
1394+
return llvm::failure();
1395+
}
1396+
return llvm::success();
1397+
}
1398+
13331399
//===----------------------------------------------------------------------===//
13341400
// SPIR-V TypeConverter
13351401
//===----------------------------------------------------------------------===//

mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,8 @@ struct TestSPIRVFuncSignatureConversion final
3737
}
3838

3939
void runOnOperation() override {
40-
RewritePatternSet patterns(&getContext());
41-
populateFuncOpVectorRewritePatterns(patterns);
42-
populateReturnOpVectorRewritePatterns(patterns);
43-
GreedyRewriteConfig config;
44-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
45-
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
46-
config);
40+
Operation *op = getOperation();
41+
(void)spirv::unrollVectorsInSignatures(op);
4742
}
4843
};
4944

mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp

Lines changed: 2 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -34,78 +34,9 @@ struct TestSPIRVVectorUnrolling final
3434
}
3535

3636
void runOnOperation() override {
37-
MLIRContext *context = &getContext();
3837
Operation *op = getOperation();
39-
40-
// Unroll vectors in function signatures to native vector size.
41-
{
42-
RewritePatternSet patterns(context);
43-
populateFuncOpVectorRewritePatterns(patterns);
44-
populateReturnOpVectorRewritePatterns(patterns);
45-
GreedyRewriteConfig config;
46-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
47-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
48-
return signalPassFailure();
49-
}
50-
51-
// Unroll vectors to native vector size.
52-
{
53-
RewritePatternSet patterns(context);
54-
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
55-
[=](auto op) { return mlir::spirv::getNativeVectorShape(op); });
56-
populateVectorUnrollPatterns(patterns, options);
57-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
58-
return signalPassFailure();
59-
}
60-
61-
// Convert transpose ops into extract and insert pairs, in preparation of
62-
// further transformations to canonicalize/cancel.
63-
{
64-
RewritePatternSet patterns(context);
65-
auto options =
66-
vector::VectorTransformsOptions().setVectorTransposeLowering(
67-
vector::VectorTransposeLowering::EltWise);
68-
vector::populateVectorTransposeLoweringPatterns(patterns, options);
69-
vector::populateVectorShapeCastLoweringPatterns(patterns);
70-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
71-
return signalPassFailure();
72-
}
73-
}
74-
75-
// Run canonicalization to cast away leading size-1 dimensions.
76-
{
77-
RewritePatternSet patterns(context);
78-
79-
// We need to pull in casting way leading one dims.
80-
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
81-
vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
82-
83-
// Decompose different rank insert_strided_slice and n-D
84-
// extract_slided_slice.
85-
vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
86-
patterns);
87-
vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
88-
89-
// Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
90-
// them up.
91-
vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
92-
vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
93-
94-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
95-
return signalPassFailure();
96-
}
97-
98-
// Run all sorts of canonicalization patterns to clean up again.
99-
{
100-
RewritePatternSet patterns(context);
101-
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
102-
vector::InsertOp::getCanonicalizationPatterns(patterns, context);
103-
vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
104-
vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
105-
vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
106-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
107-
return signalPassFailure();
108-
}
38+
(void)spirv::unrollVectorsInSignatures(op);
39+
(void)spirv::unrollVectorsInFuncBodies(op);
10940
}
11041
};
11142

0 commit comments

Comments
 (0)