diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 748646e605827..b5bb2f42f2961 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -47,7 +47,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
   let options = [
     Option<"runSignatureConversion", "run-signature-conversion", "bool",
            /*default=*/"true",
-           "Run function signature conversion to convert vector types">
+           "Run function signature conversion to convert vector types">,
+    Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
+           /*default=*/"true",
+           "Run vector unrolling to convert vector types in function bodies">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 9ad3d5fc85dd3..f54c93a09e727 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -19,8 +19,10 @@
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 
@@ -189,6 +191,25 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
                           MemRefType baseType, Value basePtr,
                           ValueRange indices, Location loc, OpBuilder &builder);
 
+// Find the largest factor of size among {2,3,4} for the lowest dimension of
+// the target shape.
+int getComputeVectorSize(int64_t size);
+
+// GetNativeVectorShape implementation for reduction ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
+
+// GetNativeVectorShape implementation for transpose ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::TransposeOp op);
+
+// For general ops.
+std::optional<SmallVector<int64_t>> getNativeVectorShape(Operation *op);
+
+// Unroll vectors in function signatures to native size.
+LogicalResult unrollVectorsInSignatures(Operation *op);
+
+// Unroll vectors in function bodies to native size.
+LogicalResult unrollVectorsInFuncBodies(Operation *op);
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 003a5feea9e9b..4694a147e1e94 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -17,6 +17,8 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -42,19 +44,16 @@ struct ConvertToSPIRVPass final
   using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
-    MLIRContext *context = &getContext();
     Operation *op = getOperation();
+    MLIRContext *context = &getContext();
 
-    if (runSignatureConversion) {
-      // Unroll vectors in function signatures to native vector size.
-      RewritePatternSet patterns(context);
-      populateFuncOpVectorRewritePatterns(patterns);
-      populateReturnOpVectorRewritePatterns(patterns);
-      GreedyRewriteConfig config;
-      config.strictMode = GreedyRewriteStrictness::ExistingOps;
-      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-        return signalPassFailure();
-    }
+    // Unroll vectors in function signatures to native size.
+    if (runSignatureConversion && failed(spirv::unrollVectorsInSignatures(op)))
+      return signalPassFailure();
+
+    // Unroll vectors in function bodies to native size.
+    if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
+      return signalPassFailure();
 
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
     std::unique_ptr<ConversionTarget> target =
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..d833ec9309baa 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -20,17 +20,21 @@
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 
 #include <functional>
@@ -46,14 +50,6 @@ namespace {
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-static int getComputeVectorSize(int64_t size) {
-  for (int i : {4, 3, 2}) {
-    if (size % i == 0)
-      return i;
-  }
-  return 1;
-}
-
 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
   LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
   if (vecType.isScalable()) {
@@ -62,8 +58,8 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
     return std::nullopt;
   }
   SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
-  std::optional<SmallVector<int64_t>> targetShape =
-      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
+      1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
   if (!targetShape) {
     LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
     return std::nullopt;
   }
@@ -1098,13 +1094,20 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
     // the original operand of illegal type.
     auto originalShape = llvm::to_vector_of<int64_t>(origVecType.getShape());
-    SmallVector<int64_t> strides(targetShape->size(), 1);
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+    SmallVector<int64_t> extractShape(originalShape.size(), 1);
+    extractShape.back() = targetShape->back();
     SmallVector<Type> newTypes;
     Value returnValue = returnOp.getOperand(origResultNo);
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(originalShape, *targetShape)) {
       Value result = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, returnValue, offsets, *targetShape, strides);
+          loc, returnValue, offsets, extractShape, strides);
+      if (originalShape.size() > 1) {
+        SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+        result =
+            rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
+      }
       newOperands.push_back(result);
       newTypes.push_back(unrolledType);
     }
@@ -1285,6 +1288,118 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
                              builder);
 }
 
+//===----------------------------------------------------------------------===//
+// Public functions for vector unrolling
+//===----------------------------------------------------------------------===//
+
+int mlir::spirv::getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
+  VectorType srcVectorType = op.getSourceVectorType();
+  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
+  int64_t vectorSize =
+      mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
+  return {vectorSize};
+}
+
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
+  VectorType vectorType = op.getResultVectorType();
+  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
+  nativeSize.back() =
+      mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
+  return nativeSize;
+}
+
+std::optional<SmallVector<int64_t>>
+mlir::spirv::getNativeVectorShape(Operation *op) {
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
+      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
+      nativeSize.back() =
+          mlir::spirv::getComputeVectorSize(vecType.getShape().back());
+      return nativeSize;
+    }
+  }
+
+  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
+      .Case<vector::ReductionOp, vector::TransposeOp>(
+          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
+      .Default([](Operation *) { return std::nullopt; });
+}
+
+LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
+  MLIRContext *context = op->getContext();
+  RewritePatternSet patterns(context);
+  populateFuncOpVectorRewritePatterns(patterns);
+  populateReturnOpVectorRewritePatterns(patterns);
+  // We only want to apply signature conversion once to the existing func ops.
+  // Without specifying strictMode, the greedy pattern rewriter will keep
+  // looking for newly created func ops.
+  GreedyRewriteConfig config;
+  config.strictMode = GreedyRewriteStrictness::ExistingOps;
+  return applyPatternsAndFoldGreedily(op, std::move(patterns), config);
+}
+
+LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
+  MLIRContext *context = op->getContext();
+
+  // Unroll vectors in function bodies to native vector size.
+  {
+    RewritePatternSet patterns(context);
+    auto options = vector::UnrollVectorOptions().setNativeShapeFn(
+        [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+    populateVectorUnrollPatterns(patterns, options);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return failure();
+  }
+
+  // Convert transpose ops into extract and insert pairs, in preparation for
+  // further transformations to canonicalize/cancel.
+  {
+    RewritePatternSet patterns(context);
+    auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
+        vector::VectorTransposeLowering::EltWise);
+    vector::populateVectorTransposeLoweringPatterns(patterns, options);
+    vector::populateVectorShapeCastLoweringPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return failure();
+  }
+
+  // Run canonicalization to cast away leading size-1 dimensions.
+  {
+    RewritePatternSet patterns(context);
+
+    // We need to pull in patterns for casting away leading size-1 dims.
+    vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+    vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+    vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
+
+    // Decompose different-rank insert_strided_slice and n-D
+    // extract_strided_slice.
+    vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+        patterns);
+    vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+    vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+
+    // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
+    // them up.
+    vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
+    vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return failure();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V TypeConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
index 1a844a7cd018b..6418e931f7460 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // arithmetic ops
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
index 02b938be775a3..311174bef15ed 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @combined
 // CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
 
 // -----
 
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+  return %arg0 : vector<4x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @vector_6and8
 // CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
 func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
@@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i
 
 // -----
 
+// CHECK-LABEL: @vector_2dand1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
+func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
+  return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduction
 // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
 func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
index e1cb18aac5d01..f4b116849fa93 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @basic
 func.func @basic(%a: index, %b: index) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index 58ec6ac61f6ac..246464928b81c 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @if_yield
 // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr
diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
index c5e0e6603d94a..00556140c3018 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @return_scalar
 // CHECK-SAME: %[[ARG0:.*]]: i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
index a83bfb6f405a0..f34ca01c94f00 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @ub
 // CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
new file mode 100644
index 0000000000000..043f9422d8790
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt -test-spirv-vector-unrolling -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @vaddi
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>)
+func.func @vaddi(%arg0 : vector<6xi32>, %arg1 : vector<6xi32>) -> (vector<6xi32>) {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<3xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<3xi32>
+  // CHECK: return %[[ADD0]], %[[ADD1]] : vector<3xi32>, vector<3xi32>
+  %0 = arith.addi %arg0, %arg1 : vector<6xi32>
+  return %0 : vector<6xi32>
+}
+
+// CHECK-LABEL: @vaddi_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<2xi32>, %[[ARG3:.+]]: vector<2xi32>)
+func.func @vaddi_2d(%arg0 : vector<2x2xi32>, %arg1 : vector<2x2xi32>) -> (vector<2x2xi32>) {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<2xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<2xi32>
+  // CHECK: return %[[ADD0]], %[[ADD1]] : vector<2xi32>, vector<2xi32>
+  %0 = arith.addi %arg0, %arg1 : vector<2x2xi32>
+  return %0 : vector<2x2xi32>
+}
+
+// CHECK-LABEL: @vaddi_2d_8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: vector<4xi32>, %[[ARG5:.+]]: vector<4xi32>, %[[ARG6:.+]]: vector<4xi32>, %[[ARG7:.+]]: vector<4xi32>)
+func.func @vaddi_2d_8(%arg0 : vector<2x8xi32>, %arg1 : vector<2x8xi32>) -> (vector<2x8xi32>) {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG4]] : vector<4xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG5]] : vector<4xi32>
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[ARG2]], %[[ARG6]] : vector<4xi32>
+  // CHECK: %[[ADD3:.*]] = arith.addi %[[ARG3]], %[[ARG7]] : vector<4xi32>
+  // CHECK: return %[[ADD0]], %[[ADD1]], %[[ADD2]], %[[ADD3]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+  %0 = arith.addi %arg0, %arg1 : vector<2x8xi32>
+  return %0 : vector<2x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduction_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>)
+func.func @reduction_5(%arg0 : vector<5xi32>) -> (i32) {
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<1xi32>
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[EXTRACT0]], %[[EXTRACT1]] : i32
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG2]][0] : i32 from vector<1xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[EXTRACT2]] : i32
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG3]][0] : i32 from vector<1xi32>
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[EXTRACT3]] : i32
+  // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG4]][0] : i32 from vector<1xi32>
+  // CHECK: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[EXTRACT4]] : i32
+  // CHECK: return %[[ADD3]] : i32
+  %0 = vector.reduction <add>, %arg0 : vector<5xi32> into i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @reduction_8
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>)
+func.func @reduction_8(%arg0 : vector<8xi32>) -> (i32) {
+  // CHECK: %[[REDUCTION0:.*]] = vector.reduction <add>, %[[ARG0]] : vector<4xi32> into i32
+  // CHECK: %[[REDUCTION1:.*]] = vector.reduction <add>, %[[ARG1]] : vector<4xi32> into i32
+  // CHECK: %[[ADD:.*]] = arith.addi %[[REDUCTION0]], %[[REDUCTION1]] : i32
+  // CHECK: return %[[ADD]] : i32
+  %0 = vector.reduction <add>, %arg0 : vector<8xi32> into i32
+  return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @vaddi_reduction
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32) {
+  // CHECK: %[[ADD0:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : vector<4xi32>
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG3]] : vector<4xi32>
+  // CHECK: %[[REDUCTION0:.*]] = vector.reduction <add>, %[[ADD0]] : vector<4xi32> into i32
+  // CHECK: %[[REDUCTION1:.*]] = vector.reduction <add>, %[[ADD1]] : vector<4xi32> into i32
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[REDUCTION0]], %[[REDUCTION1]] : i32
+  // CHECK: return %[[ADD2]] : i32
+  %0 = arith.addi %arg0, %arg1 : vector<8xi32>
+  %1 = vector.reduction <add>, %0 : vector<8xi32> into i32
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @transpose
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
+func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert %[[EXTRACT0]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
+  // CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
+  // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %0 : vector<3x2xi32>
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index c63dd030f4747..e369eadca5730 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @extract
 // CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index 69b5787f7e851..aeade52c7ade5 100644
--- a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestConvertToSPIRV
   TestSPIRVFuncSignatureConversion.cpp
+  TestSPIRVVectorUnrolling.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -13,4 +14,5 @@ add_mlir_library(MLIRTestConvertToSPIRV
   MLIRTransformUtils
   MLIRTransforms
   MLIRVectorDialect
+  MLIRVectorTransforms
 )
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
index ec67f85f6f27b..4a792336caba4 100644
--- a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
@@ -37,13 +37,8 @@ struct TestSPIRVFuncSignatureConversion final
   }
 
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateFuncOpVectorRewritePatterns(patterns);
-    populateReturnOpVectorRewritePatterns(patterns);
-    GreedyRewriteConfig config;
-    config.strictMode = GreedyRewriteStrictness::ExistingOps;
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
-                                       config);
+    Operation *op = getOperation();
+    (void)spirv::unrollVectorsInSignatures(op);
   }
 };
 
diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp
new file mode 100644
index 0000000000000..0bad43d5214b1
--- /dev/null
+++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVVectorUnrolling.cpp
@@ -0,0 +1,50 @@
+//===- TestSPIRVVectorUnrolling.cpp - Test vector unrolling --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace {
+
+struct TestSPIRVVectorUnrolling final
+    : PassWrapper<TestSPIRVVectorUnrolling, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVVectorUnrolling)
+
+  StringRef getArgument() const final { return "test-spirv-vector-unrolling"; }
+
+  StringRef getDescription() const final {
+    return "Test patterns that unroll vectors to types supported by SPIR-V";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect,
+                    spirv::SPIRVDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    (void)spirv::unrollVectorsInSignatures(op);
+    (void)spirv::unrollVectorsInFuncBodies(op);
+  }
+};
+
+} // namespace
+
+namespace test {
+void registerTestSPIRVVectorUnrolling() {
+  PassRegistration<TestSPIRVVectorUnrolling>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 149f9d59961b8..0f29963da39bb 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -142,6 +142,7 @@ void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestSPIRVFuncSignatureConversion();
+void registerTestSPIRVVectorUnrolling();
 void registerTestTensorCopyInsertionPass();
 void registerTestTensorTransforms();
 void registerTestTopologicalSortAnalysisPass();
@@ -275,6 +276,7 @@ void registerTestPasses() {
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestSPIRVFuncSignatureConversion();
+  mlir::test::registerTestSPIRVVectorUnrolling();
   mlir::test::registerTestTensorCopyInsertionPass();
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTopologicalSortAnalysisPass();
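
Usage sketch (editor's note, not part of the patch): besides the -convert-to-spirv options above, the two entry points declared in SPIRVConversion.h can be reused by any downstream pass that wants SPIR-V-sized vectors before running its own conversion. The pass below is hypothetical and only illustrates the call order; the spirv::unrollVectorsInSignatures and spirv::unrollVectorsInFuncBodies signatures come from this patch, everything else is an assumed wrapper modeled on TestSPIRVVectorUnrolling.cpp.

#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Pass/Pass.h"

namespace {
// Hypothetical pre-conversion pass: first legalize func.func signatures,
// then unroll vector ops in function bodies to the same native sizes.
struct ExampleSPIRVVectorLegalization final
    : mlir::PassWrapper<ExampleSPIRVVectorLegalization, mlir::OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExampleSPIRVVectorLegalization)

  llvm::StringRef getArgument() const final {
    return "example-spirv-vector-legalization"; // illustrative name only
  }

  void runOnOperation() override {
    mlir::Operation *op = getOperation();
    // Split wide vectors in function signatures into native-size arguments.
    if (mlir::failed(mlir::spirv::unrollVectorsInSignatures(op)))
      return signalPassFailure();
    // Unroll remaining vector ops in bodies; this also runs the transpose
    // lowering and leading-unit-dim cleanup described in the patch.
    if (mlir::failed(mlir::spirv::unrollVectorsInFuncBodies(op)))
      return signalPassFailure();
  }
};
} // namespace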