diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..c4b9ff005919b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -40,7 +40,15 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
   let description = [{
     This is a generic pass to convert to SPIR-V.
   }];
-  let dependentDialects = ["spirv::SPIRVDialect"];
+  let dependentDialects = [
+    "spirv::SPIRVDialect",
+    "vector::VectorDialect",
+  ];
+  let options = [
+    Option<"runSignatureConversion", "run-signature-conversion", "bool",
+           /*default=*/"true",
+           "Run function signature conversion to convert vector types">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 09eecafc0c8a5..9ad3d5fc85dd3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -17,7 +17,9 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 #include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
@@ -134,6 +136,10 @@ class SPIRVConversionTarget : public ConversionTarget {
 void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns);
 
+void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
+
+void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
+
 namespace spirv {
 class AccessChainOp;
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index b5be4654bcb25..003a5feea9e9b 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -39,18 +39,31 @@ namespace {
 
 /// A pass to perform the SPIR-V conversion.
 struct ConvertToSPIRVPass final
     : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+  using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
+    if (runSignatureConversion) {
+      // Unroll vectors in function signatures to native vector size.
+      RewritePatternSet patterns(context);
+      populateFuncOpVectorRewritePatterns(patterns);
+      populateReturnOpVectorRewritePatterns(patterns);
+      GreedyRewriteConfig config;
+      config.strictMode = GreedyRewriteStrictness::ExistingOps;
+      if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+        return signalPassFailure();
+    }
+
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
     SPIRVTypeConverter typeConverter(targetAttr);
-
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
-    // Populate patterns.
+    // Populate patterns for each dialect.
     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
@@ -60,9 +73,6 @@ struct ConvertToSPIRVPass final
     populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
     ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
 
-    std::unique_ptr<ConversionTarget> target =
-        SPIRVConversionTarget::get(targetAttr);
-
     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
       return signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 821f82ebc0796..4de9b4729e720 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -16,9 +16,15 @@ add_mlir_dialect_library(MLIRSPIRVConversion
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRDialectUtils
   MLIRFuncDialect
+  MLIRIR
   MLIRSPIRVDialect
+  MLIRSupport
   MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorTransforms
   )
 
 add_mlir_dialect_library(MLIRSPIRVTransforms
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4072608dc8f87..e3a09ef1ff684 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -11,14 +11,24 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
@@ -34,6 +44,43 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+static int getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
+  if (vecType.isScalable()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--scalable vectors are not supported -> BAIL\n");
+    return std::nullopt;
+  }
+  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
+  std::optional<SmallVector<int64_t>> targetShape =
+      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+  if (!targetShape) {
+    LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
+    return std::nullopt;
+  }
+  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
+  if (!maybeShapeRatio) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "--could not compute integral shape ratio -> BAIL\n");
+    return std::nullopt;
+  }
+  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n"); + return std::nullopt; + } + LLVM_DEBUG(llvm::dbgs() + << "--found an integral shape ratio to unroll to -> SUCCESS\n"); + return targetShape; +} + /// Checks that `candidates` extension requirements are possible to be satisfied /// with the given `targetEnv`. /// @@ -813,6 +860,249 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, patterns.add(typeConverter, patterns.getContext()); } +//===----------------------------------------------------------------------===// +// func::FuncOp Conversion Patterns +//===----------------------------------------------------------------------===// + +namespace { +/// A pattern for rewriting function signature to convert vector arguments of +/// functions to be of valid types +struct FuncOpVectorUnroll final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::FuncOp funcOp, + PatternRewriter &rewriter) const override { + FunctionType fnType = funcOp.getFunctionType(); + + // TODO: Handle declarations. + if (funcOp.isDeclaration()) { + LLVM_DEBUG(llvm::dbgs() + << fnType << " illegal: declarations are unsupported\n"); + return failure(); + } + + // Create a new func op with the original type and copy the function body. + auto newFuncOp = rewriter.create(funcOp.getLoc(), + funcOp.getName(), fnType); + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + + Location loc = newFuncOp.getBody().getLoc(); + + Block &entryBlock = newFuncOp.getBlocks().front(); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&entryBlock); + + OneToNTypeMapping oneToNTypeMapping(fnType.getInputs()); + + // For arguments that are of illegal types and require unrolling. + // `unrolledInputNums` stores the indices of arguments that result from + // unrolling in the new function signature. `newInputNo` is a counter. + SmallVector unrolledInputNums; + size_t newInputNo = 0; + + // For arguments that are of legal types and do not require unrolling. + // `tmpOps` stores a mapping from temporary operations that serve as + // placeholders for new arguments that will be added later. These operations + // will be erased once the entry block's argument list is updated. + llvm::SmallDenseMap tmpOps; + + // This counts the number of new operations created. + size_t newOpCount = 0; + + // Enumerate through the arguments. + for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) { + // Check whether the argument is of vector type. + auto origVecType = dyn_cast(origType); + if (!origVecType) { + // We need a placeholder for the old argument that will be erased later. + Value result = rewriter.create( + loc, origType, rewriter.getZeroAttr(origType)); + rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result); + tmpOps.insert({result.getDefiningOp(), newInputNo}); + oneToNTypeMapping.addInputs(origInputNo, origType); + ++newInputNo; + ++newOpCount; + continue; + } + // Check whether the vector needs unrolling. + auto targetShape = getTargetShape(origVecType); + if (!targetShape) { + // We need a placeholder for the old argument that will be erased later. 
+        Value result = rewriter.create<arith::ConstantOp>(
+            loc, origType, rewriter.getZeroAttr(origType));
+        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+        tmpOps.insert({result.getDefiningOp(), newInputNo});
+        oneToNTypeMapping.addInputs(origInputNo, origType);
+        ++newInputNo;
+        ++newOpCount;
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+      auto originalShape =
+          llvm::to_vector_of<int64_t>(origVecType.getShape());
+
+      // Prepare the result vector.
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origVecType, rewriter.getZeroAttr(origVecType));
+      ++newOpCount;
+      // Prepare the placeholder for the new arguments that will be added later.
+      Value dummy = rewriter.create<arith::ConstantOp>(
+          loc, unrolledType, rewriter.getZeroAttr(unrolledType));
+      ++newOpCount;
+
+      // Create the `vector.insert_strided_slice` ops.
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<Type> newTypes;
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(originalShape, *targetShape)) {
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, dummy, result, offsets, strides);
+        newTypes.push_back(unrolledType);
+        unrolledInputNums.push_back(newInputNo);
+        ++newInputNo;
+        ++newOpCount;
+      }
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      oneToNTypeMapping.addInputs(origInputNo, newTypes);
+    }
+
+    // Change the function signature.
+    auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+    auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
+    rewriter.modifyOpInPlace(newFuncOp,
+                             [&] { newFuncOp.setFunctionType(newFnType); });
+
+    // Update the arguments in the entry block.
+    entryBlock.eraseArguments(0, fnType.getNumInputs());
+    SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+    entryBlock.addArguments(convertedTypes, locs);
+
+    // Replace the placeholder values with the new arguments. We assume there is
+    // only one block for now.
+    size_t unrolledInputIdx = 0;
+    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+      // We first look for operands that are placeholders for initially legal
+      // arguments.
+      Operation &curOp = op;
+      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+        Operation *operandOp = operandVal.getDefiningOp();
+        if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
+          size_t idx = operandIdx;
+          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
+            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+          });
+        }
+      }
+      // Since all newly created operations are in the beginning, reaching the
+      // end of them means that any later `vector.insert_strided_slice` should
+      // not be touched.
+      if (count >= newOpCount)
+        continue;
+      if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
+        size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
+        rewriter.modifyOpInPlace(&curOp, [&] {
+          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+        });
+        ++unrolledInputIdx;
+      }
+    }
+
+    // Erase the original funcOp. The `tmpOps` do not need to be erased since
+    // they have no uses and will be handled by dead-code elimination.
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting the function signature and the return op to convert
+/// vectors to be of valid types.
+struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+                                PatternRewriter &rewriter) const override {
+    // Check whether the parent funcOp is valid.
+    auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+    if (!funcOp)
+      return failure();
+
+    FunctionType fnType = funcOp.getFunctionType();
+    OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+    Location loc = returnOp.getLoc();
+
+    // For the new return op.
+    SmallVector<Value> newOperands;
+
+    // Enumerate through the results.
+    for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
+      // Check whether the result is of vector type.
+      auto origVecType = dyn_cast<VectorType>(origType);
+      if (!origVecType) {
+        oneToNTypeMapping.addInputs(origResultNo, origType);
+        newOperands.push_back(returnOp.getOperand(origResultNo));
+        continue;
+      }
+      // Check whether the vector needs unrolling.
+      auto targetShape = getTargetShape(origVecType);
+      if (!targetShape) {
+        // The original operand can be used.
+        oneToNTypeMapping.addInputs(origResultNo, origType);
+        newOperands.push_back(returnOp.getOperand(origResultNo));
+        continue;
+      }
+      VectorType unrolledType =
+          VectorType::get(*targetShape, origVecType.getElementType());
+
+      // Create `vector.extract_strided_slice` ops to form legal vectors from
+      // the original operand of illegal type.
+      auto originalShape =
+          llvm::to_vector_of<int64_t>(origVecType.getShape());
+      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<Type> newTypes;
+      Value returnValue = returnOp.getOperand(origResultNo);
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(originalShape, *targetShape)) {
+        Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, returnValue, offsets, *targetShape, strides);
+        newOperands.push_back(result);
+        newTypes.push_back(unrolledType);
+      }
+      oneToNTypeMapping.addInputs(origResultNo, newTypes);
+    }
+
+    // Change the function signature.
+    auto newFnType =
+        FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+                          TypeRange(oneToNTypeMapping.getConvertedTypes()));
+    rewriter.modifyOpInPlace(funcOp,
+                             [&] { funcOp.setFunctionType(newFnType); });
+
+    // Replace the return op using the new operands. This will automatically
+    // update the entry block as well.
+    rewriter.replaceOp(returnOp,
+                       rewriter.create<func::ReturnOp>(loc, newOperands));
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // Builtin Variables
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
index a2adc0ad9c7a5..1a844a7cd018b 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/arith.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" -split-input-file %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // arithmetic ops
diff --git a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
index 9e908465cb142..02b938be775a3 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/combined.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @combined
 // CHECK: %[[C0_F32:.*]] = spirv.Constant 0.000000e+00 : f32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
new file mode 100644
index 0000000000000..347d282f9ee0c
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-opt -test-spirv-func-signature-conversion -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @simple_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+func.func @simple_scalar(%arg0 : i32) -> i32 {
+  // CHECK: return %[[ARG0]] : i32
+  return %arg0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_4
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>)
+func.func @simple_vector_4(%arg0 : vector<4xi32>) -> vector<4xi32> {
+  // CHECK: return %[[ARG0]] : vector<4xi32>
+  return %arg0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>, %[[ARG2:.+]]: vector<1xi32>, %[[ARG3:.+]]: vector<1xi32>, %[[ARG4:.+]]: vector<1xi32>)
+func.func @simple_vector_5(%arg0 : vector<5xi32>) -> vector<5xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<5xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[INSERT4:.*]] = vector.insert_strided_slice %[[ARG4]], %[[INSERT3]] {offsets = [4], strides = [1]} : vector<1xi32> into vector<5xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32>
+  //
CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [1], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32> + // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [2], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32> + // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [3], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32> + // CHECK: %[[EXTRACT4:.*]] = vector.extract_strided_slice %[[INSERT4]] {offsets = [4], sizes = [1], strides = [1]} : vector<5xi32> to vector<1xi32> + // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]], %[[EXTRACT4]] : vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32>, vector<1xi32> + return %arg0 : vector<5xi32> +} + +// ----- + +// CHECK-LABEL: @simple_vector_6 +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>) +func.func @simple_vector_6(%arg0 : vector<6xi32>) -> vector<6xi32> { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<6xi32> + // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32> + // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32> + // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32> + // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<3xi32> + return %arg0 : vector<6xi32> +} + +// ----- + +// CHECK-LABEL: @simple_vector_8 +// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>) +func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32> + // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32> + // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32> + // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]] : vector<4xi32>, vector<4xi32> + return %arg0 : vector<8xi32> +} + +// ----- + +// CHECK-LABEL: @vector_6and8 +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>) +func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[CST0:.*]] = arith.constant dense<0> : vector<6xi32> + // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]] {offsets = [0], strides = [1]} : vector<3xi32> into vector<6xi32> + // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [3], strides = [1]} : vector<3xi32> into vector<6xi32> + // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : 
vector<4xi32> into vector<8xi32> + // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32> + // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xi32> to vector<3xi32> + // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32> + // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32> + // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi32>, vector<3xi32>, vector<4xi32>, vector<4xi32> + return %arg0, %arg1 : vector<6xi32>, vector<8xi32> +} + +// ----- + +// CHECK-LABEL: @vector_3and8 +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>) +func.func @vector_3and8(%arg0 : vector<3xi32>, %arg1 : vector<8xi32>) -> (vector<3xi32>, vector<8xi32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG1]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32> + // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32> + // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32> + // CHECK: return %[[ARG0]], %[[EXTRACT0]], %[[EXTRACT1]] : vector<3xi32>, vector<4xi32>, vector<4xi32> + return %arg0, %arg1 : vector<3xi32>, vector<8xi32> +} + +// ----- + +// CHECK-LABEL: @scalar_vector +// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: i32) +func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i32) -> (vector<8xi32>, vector<3xi32>, i32) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32> + // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32> + // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xi32> to vector<4xi32> + // CHECK: return %[[EXTRACT0]], %[[EXTRACT1]], %[[ARG2]], %[[ARG3]] : vector<4xi32>, vector<4xi32>, vector<3xi32>, i32 + return %arg0, %arg1, %arg2 : vector<8xi32>, vector<3xi32>, i32 +} + +// ----- + +// CHECK-LABEL: @reduction +// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32) +func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> + // 
CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[CST]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<8xi32>
+  // CHECK: %[[ADDI:.*]] = arith.addi %[[INSERT1]], %[[INSERT3]] : vector<8xi32>
+  // CHECK: %[[REDUCTION:.*]] = vector.reduction <add>, %[[ADDI]] : vector<8xi32> into i32
+  // CHECK: %[[RET:.*]] = arith.addi %[[REDUCTION]], %[[ARG4]] : i32
+  // CHECK: return %[[RET]] : i32
+  %0 = arith.addi %arg0, %arg1 : vector<8xi32>
+  %1 = vector.reduction <add>, %0 : vector<8xi32> into i32
+  %2 = arith.addi %1, %arg2 : i32
+  return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func private @unsupported_decl(vector<8xi32>)
+func.func private @unsupported_decl(vector<8xi32>)
+
+// -----
+
+// CHECK-LABEL: @unsupported_scalable
+// CHECK-SAME: (%[[ARG0:.+]]: vector<[8]xi32>)
+func.func @unsupported_scalable(%arg0 : vector<[8]xi32>) -> (vector<[8]xi32>) {
+  // CHECK: return %[[ARG0]] : vector<[8]xi32>
+  return %arg0 : vector<[8]xi32>
+}
+
diff --git a/mlir/test/Conversion/ConvertToSPIRV/index.mlir b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
index db747625bc7b3..e1cb18aac5d01 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/index.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/index.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-to-spirv | FileCheck %s
+// RUN: mlir-opt %s -convert-to-spirv="run-signature-conversion=false" | FileCheck %s
 
 // CHECK-LABEL: @basic
 func.func @basic(%a: index, %b: index) {
diff --git a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
index f619ca5771824..58ec6ac61f6ac 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/scf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @if_yield
 // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr
diff --git a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
index 20b2a42bc3975..c5e0e6603d94a 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @return_scalar
 // CHECK-SAME: %[[ARG0:.*]]: i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
index 66528b68f58cf..a83bfb6f405a0 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/ub.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s
 
 // CHECK-LABEL: @ub
 // CHECK: %[[UNDEF:.*]] = spirv.Undef : i32
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index 336f0fe10c27e..c63dd030f4747 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++
b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -convert-to-spirv %s | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-to-spirv="run-signature-conversion=false" %s | FileCheck %s // CHECK-LABEL: @extract // CHECK-SAME: %[[ARG:.+]]: vector<2xf32> diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt index 754c9866d18e4..19975f671b081 100644 --- a/mlir/test/lib/Conversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(ConvertToSPIRV) add_subdirectory(FuncToLLVM) add_subdirectory(MathToVCIX) add_subdirectory(OneToNTypeConversion) diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt new file mode 100644 index 0000000000000..69b5787f7e851 --- /dev/null +++ b/mlir/test/lib/Conversion/ConvertToSPIRV/CMakeLists.txt @@ -0,0 +1,16 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRTestConvertToSPIRV + TestSPIRVFuncSignatureConversion.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRFuncDialect + MLIRPass + MLIRSPIRVConversion + MLIRSPIRVDialect + MLIRTransformUtils + MLIRTransforms + MLIRVectorDialect + ) diff --git a/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp new file mode 100644 index 0000000000000..ec67f85f6f27b --- /dev/null +++ b/mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp @@ -0,0 +1,57 @@ +//===- TestSPIRVFuncSignatureConversion.cpp - Test signature conversion -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace {
+
+struct TestSPIRVFuncSignatureConversion final
+    : PassWrapper<TestSPIRVFuncSignatureConversion, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSPIRVFuncSignatureConversion)
+
+  StringRef getArgument() const final {
+    return "test-spirv-func-signature-conversion";
+  }
+
+  StringRef getDescription() const final {
+    return "Test patterns that convert vector inputs and results in function "
+           "signatures";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect,
+                    spirv::SPIRVDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateFuncOpVectorRewritePatterns(patterns);
+    populateReturnOpVectorRewritePatterns(patterns);
+    GreedyRewriteConfig config;
+    config.strictMode = GreedyRewriteStrictness::ExistingOps;
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                       config);
+  }
+};
+
+} // namespace
+
+namespace test {
+void registerTestSPIRVFuncSignatureConversion() {
+  PassRegistration<TestSPIRVFuncSignatureConversion>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index e8091bca3326c..8b79de58fa102 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -36,6 +36,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRSPIRVTestPasses
     MLIRTensorTestPasses
     MLIRTestAnalysis
+    MLIRTestConvertToSPIRV
    MLIRTestDialect
     MLIRTestDynDialect
     MLIRTestIR
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8cafb0afac9ae..149f9d59961b8 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -141,6 +141,7 @@ void registerTestSCFWhileOpBuilderPass();
 void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
+void registerTestSPIRVFuncSignatureConversion();
 void registerTestTensorCopyInsertionPass();
 void registerTestTensorTransforms();
 void registerTestTopologicalSortAnalysisPass();
@@ -273,6 +274,7 @@ void registerTestPasses() {
   mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
+  mlir::test::registerTestSPIRVFuncSignatureConversion();
   mlir::test::registerTestTensorCopyInsertionPass();
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTopologicalSortAnalysisPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ab3757342c76f..73b3217824bd9 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7205,10 +7205,15 @@ cc_library(
     hdrs = ["include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"],
     includes = ["include"],
     deps = [
+        ":ArithDialect",
+        ":DialectUtils",
         ":FuncDialect",
         ":IR",
         ":SPIRVDialect",
+        ":Support",
         ":TransformUtils",
+        ":VectorDialect",
+        ":VectorTransforms",
         "//llvm:Support",
     ],
 )
@@ -9586,6 +9591,7 @@ cc_binary(
         "//mlir/test:TestArmSME",
"//mlir/test:TestBufferization", "//mlir/test:TestControlFlow", + "//mlir/test:TestConvertToSPIRV", "//mlir/test:TestDLTI", "//mlir/test:TestDialect", "//mlir/test:TestFunc", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index 1d59370057d1c..a1d2b20a106e6 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -656,6 +656,21 @@ cc_library( ], ) +cc_library( + name = "TestConvertToSPIRV", + srcs = glob(["lib/Conversion/ConvertToSPIRV/*.cpp"]), + deps = [ + "//mlir:ArithDialect", + "//mlir:FuncDialect", + "//mlir:Pass", + "//mlir:SPIRVConversion", + "//mlir:SPIRVDialect", + "//mlir:TransformUtils", + "//mlir:Transforms", + "//mlir:VectorDialect", + ], +) + cc_library( name = "TestAffine", srcs = glob([