From 8e39c56b6f39cc03002ba9c5e6662fa29d478016 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Tue, 15 Apr 2025 22:34:54 +0000
Subject: [PATCH 1/7] Add more patterns to Vector Linearize Pass

---
 .../Vector/Transforms/VectorLinearize.cpp     | 407 +++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 335 ++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |   3 +-
 3 files changed, 741 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..6de5d0c5a101e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -27,6 +28,10 @@ using namespace mlir;
 
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+  // For BW-0, all operations are legal.
+  if (targetBitWidth == 0) {
+    return false;
+  }
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
@@ -273,6 +278,77 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern linearizes the InsertStridedSliceOp by extracting rows from
+/// the source vector using ExtractStridedSliceOp and inserting them into the
+/// destination vector using InsertStridedSliceOp.
+/// Following,
+///   vector.insert_strided_slice %s, %d {offsets=[0, 0]}
+///     : vector<2x4xf32> into vector<4x4xf32>
+/// is converted to:
+///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
+///     : vector<4xf32> from vector<8xf32>
+///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
+///     : vector<4xf32> into vector<16xf32>
+///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]}
+///     : vector<4xf32> from vector<8xf32>
+///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]}
+///     : vector<4xf32> into vector<16xf32>
+struct LinearizeVectorInsertStridedSlice final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern<
+      vector::InsertStridedSliceOp>::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto srcTy = op.getSourceVectorType();
+    auto dstTy = op.getDestVectorType();
+
+    if (op.hasNonUnitStrides()) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports unit strides.");
+    }
+
+    if (srcTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports 2D source.");
+    }
+
+    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports static shapes.");
+    }
+
+    auto dstShape = dstTy.getShape();
+    auto dstStrides = dstShape.drop_front().vec();
+    dstStrides.push_back(1);
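+    // Illustrative note (shapes assumed for exposition, not taken from this
+    // patch): for a destination of type vector<4x4xf32>, dstStrides is
+    // [4, 1], so static offsets [1, 2] linearize to 1 * 4 + 2 * 1 = 6 in the
+    // flattened vector<16xf32>.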
+    int64_t linearizedOffset = 0;
+    for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
+      linearizedOffset += getConstantIntValue(off).value() * stride;
+    }
+
+    // Extract a row from the source and insert it into the destination.
+    auto srcShape = srcTy.getShape();
+    Value dstValue = adaptor.getDest();
+    for (auto i = 0; i < srcShape[0]; i++) {
+      auto srcOffset = i * srcShape[1];
+      auto value = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, adaptor.getValueToStore(), srcOffset, srcShape[1], 1);
+
+      auto dstOffset = linearizedOffset + i * dstShape.back();
+      dstValue = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, value, dstValue, dstOffset, 1);
+    }
+
+    rewriter.replaceOp(op, dstValue);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
 /// vectors to a ShuffleOp that works on linearized vectors.
 /// Following,
@@ -369,6 +445,11 @@ struct LinearizeVectorExtract final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Skip if the result is not a vector type.
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
+
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(extractOp,
@@ -531,12 +612,312 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to:
+///   %result = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
+/// them into the result vector. The pattern currently supports only 2D
+/// vectors.
+struct LinearizeVectorLoad final
+    : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+  LinearizeVectorLoad(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = loadOp->getLoc();
+    auto vecType = loadOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2) {
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    }
+    auto unrollCount = shape[0];
+    auto vecSize = shape[1];
+    auto newVecType =
+        VectorType::get({vecSize}, vecType.getElementType());
+
+    llvm::SmallVector<Value> indices = llvm::to_vector(adaptor.getIndices());
+    Value xBaseIndex = indices[0];
+
+    // Construct the 2D vector.
+    Value resultVec = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(vecType));
+    // Emit unrolled loads for each 1D vector slice.
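+    // Sketch of the emitted IR for a vector<4x4xf32> load at [%x, %y]
+    // (value names are illustrative, not produced verbatim):
+    //   %s0 = vector.load %base[%x, %y] : memref<?x?xf32>, vector<4xf32>
+    //   %r0 = vector.insert %s0, %cst [0] : vector<4xf32> into vector<4x4xf32>
+    //   %x1 = arith.addi %x, %c1 : index
+    //   %s1 = vector.load %base[%x1, %y] : memref<?x?xf32>, vector<4xf32>
+    //   %r1 = vector.insert %s1, %r0 [1] : vector<4xf32> into vector<4x4xf32>
+    //   ...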
+    for (auto i = 0; i < unrollCount; i++) {
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex =
+            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      auto vec = rewriter.create<vector::LoadOp>(
+          loc, newVecType, adaptor.getBase(), indices);
+      resultVec =
+          rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to:
+///   %slice_0 = vector.extract %source[0] : vector<4xf32>
+///   vector.store %slice_0, %base[%indices] : vector<4xf32>
+///   %slice_1 = vector.extract %source[1] : vector<4xf32>
+///   vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+///   ...
+/// This unrolls the 2D vector store into multiple 1D vector stores by
+/// extracting slices from the source vector and storing them into the
+/// destination. The pattern currently supports only 2D vectors.
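+/// As a concrete sketch (operand names assumed for exposition), storing a
+/// vector<2x4xf32> at %base[%i, %j] becomes roughly:
+///   %v  = vector.shape_cast %flat : vector<8xf32> to vector<2x4xf32>
+///   %r0 = vector.extract %v[0] : vector<4xf32> from vector<2x4xf32>
+///   vector.store %r0, %base[%i, %j] : memref<4x4xf32>, vector<4xf32>
+///   %i1 = arith.addi %i, %c1 : index
+///   %r1 = vector.extract %v[1] : vector<4xf32> from vector<2x4xf32>
+///   vector.store %r1, %base[%i1, %j] : memref<4x4xf32>, vector<4xf32>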
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+  LinearizeVectorStore(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = storeOp->getLoc();
+    auto vecType = storeOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2) {
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    }
+
+    auto unrollCount = shape[0];
+    llvm::SmallVector<Value> indices = llvm::to_vector(adaptor.getIndices());
+    Value xBaseIndex = indices[0];
+
+    auto vec = rewriter.create<vector::ShapeCastOp>(
+        loc, vecType, adaptor.getValueToStore());
+
+    for (auto i = 0; i < unrollCount; i++) {
+      auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex =
+            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+                                       indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// Following,
+///   vector.splat %value : vector<4x4xf32>
+/// is converted to:
+///   %out_1d = vector.splat %value : vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// It ensures that the operation is compatible with the target vector
+/// bit width and replaces the original operation with a new SplatOp
+/// that operates on the converted type.
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
+
+  LinearizeVectorSplat(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(
+        splatOp, adaptor.getInput(), dstTy);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. It ensures that the operation is compatible with the
+/// target vector bit width and replaces the original operation with a new
+/// CreateMaskOp that operates on the converted type. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+///   vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+///   %out_1d = vector.create_mask %dims : vector<4xi1>
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern<vector::CreateMaskOp>::OpConversionPattern;
+
+  LinearizeVectorCreateMask(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        createMaskOp, dstTy, adaptor.getOperands().back());
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to ensure compatibility with linearized vector types. It updates the
+/// operands, result types, and region types (block arguments and yields) to
+/// match the converted types. Additionally, it processes yields within each
+/// region to ensure that the types of yielded values are compatible with the
+/// target vector bit width. If the result types of the operation are updated,
+/// shape cast operations are inserted to maintain compatibility with the
+/// original types. This pattern ensures that operations with regions are
+/// properly linearized and remain valid after type conversion.
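+/// For example (a sketch with assumed types), an scf.for that carries a
+/// vector<2x4xf16> iter_arg is rewritten to carry vector<8xf16>, with
+/// vector.shape_cast ops bridging the yielded value and the results:
+///   %r = scf.for ... iter_args(%acc = %init) -> (vector<8xf16>) {
+///     %next = ...
+///     %flat = vector.shape_cast %next : vector<2x4xf16> to vector<8xf16>
+///     scf.yield %flat : vector<8xf16>
+///   }
+///   %out = vector.shape_cast %r : vector<8xf16> to vector<2x4xf16>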
+struct LinearizeRegionBranchOp final
+    : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
+  using OpInterfaceConversionPattern<
+      RegionBranchOpInterface>::OpInterfaceConversionPattern;
+
+  LinearizeRegionBranchOp(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpInterfaceConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(RegionBranchOpInterface op,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto converter = getTypeConverter();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.startOpModification(op);
+
+    llvm::SmallVector<Type> convertedTypes;
+    for (Type ty : op->getResultTypes()) {
+      convertedTypes.push_back(converter->convertType(ty));
+    }
+
+    if (convertedTypes == op->getResultTypes() &&
+        op->getOperands() == operands) {
+      return failure();
+    }
+
+    op->setOperands(operands);
+
+    // Convert region types (block arguments and yields).
+    for (Region &region : op->getRegions()) {
+      if (failed(rewriter.convertRegionTypes(&region, *converter))) {
+        return failure();
+      }
+
+      // Process yields within each region.
+      for (Block &block : region) {
+        if (auto *terminator = block.getTerminator()) {
+          for (OpOperand &yieldOperand : terminator->getOpOperands()) {
+            Value value = yieldOperand.get();
+            Type type = value.getType();
+            if (!converter->isLegal(type)) {
+              Type newTy = converter->convertType(type);
+              rewriter.setInsertionPoint(terminator);
+              Value newValue =
+                  rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
+              yieldOperand.set(newValue);
+            }
+          }
+        }
+      }
+    }
+
+    // Update result types.
+    rewriter.setInsertionPointAfter(op);
+    llvm::SmallVector<Value> newResults;
+    for (Value result : op->getResults()) {
+      Type oldTy = result.getType();
+      if (!converter->isLegal(oldTy)) {
+        Type newTy = converter->convertType(oldTy);
+        result.setType(newTy);
+        Operation *castOp =
+            rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
+        result.replaceAllUsesExcept(castOp->getResult(0), castOp);
+        newResults.push_back(castOp->getResult(0));
+      } else {
+        newResults.push_back(result);
+      }
+    }
+
+    rewriter.finalizeOpModification(op);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth) {
 
+  typeConverter.addConversion([](Type type) -> Type { return type; });
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
       return type;
@@ -555,9 +936,12 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+  target.addLegalOp<vector::ShapeCastOp>();
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
+        if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
+                 vector::SplatOp, vector::CreateMaskOp>(op) ||
             op->hasTrait<OpTrait::ConstantLike>() ||
             op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,7 +952,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
+               LinearizeVectorBitCast, LinearizeVectorLoad,
+               LinearizeVectorStore, LinearizeVectorSplat,
+               LinearizeVectorCreateMask, LinearizeRegionBranchOp
+               >(typeConverter, patterns.getContext(), targetBitWidth);
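+  // Usage sketch (illustrative, not part of this change): a pass driving
+  // these patterns would typically do
+  //   TypeConverter typeConverter;
+  //   RewritePatternSet patterns(&getContext());
+  //   ConversionTarget target(getContext());
+  //   vector::populateVectorLinearizeTypeConversionsAndLegality(
+  //       typeConverter, patterns, target, targetVectorBitwidth);
+  //   if (failed(applyPartialConversion(getOperation(), target,
+  //                                     std::move(patterns))))
+  //     return signalPassFailure();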
 }
 
@@ -583,7 +970,21 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
                     .getRank() == 1)
              : true;
       });
+
+  target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+      [=](vector::InsertStridedSliceOp op) -> bool {
+        if(isLessThanTargetBitWidth(op, targetBitWidth)) {
+          auto srcTy = op.getSourceVectorType();
+          auto dstTy = op.getDestVectorType();
+          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+              srcTy.hasStaticShape() && dstTy.hasStaticShape())
+            return false;
+        }
+        return true;
+      });
+
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+               LinearizeVectorInsertStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..e47e7c4a84d68 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,338 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: test_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: return %[[CAST]] : vector<4x4xf16>
+  // BW-128: return %[[CAST]] : vector<4x4xf16>
+
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return %[[LOAD]] : vector<4x4xf16>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = vector.load %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_store
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>, %[[ARG_1:.*]]: vector<4x4xf16>) {
+func.func @test_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: return
+  // BW-128: return
+
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: vector.store %[[ARG_1]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg1, %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
+
+// -----
+// ALL-LABEL: test_create_mask
+func.func @test_create_mask() -> vector<1x16xi1> {
+  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
+  // BW-128: %[[C20:.*]] = arith.constant 20 : index
+  // DEFAULT: %[[MASK:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // BW-128: %[[MASK:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16xi1>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16xi1>
+
+  // BW-0: %[[C0:.*]] = arith.constant 0 : index
+  // BW-0: %[[C20:.*]] = arith.constant 20 : index
+  // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1>
+  %c0 = arith.constant 0 : index
+  %c20 = arith.constant 20 : index
+  %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
+
+// -----
+// ALL-LABEL: test_loop
+func.func @test_loop() -> vector<2x4xf16> {
+  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C4:.*]] = arith.constant 4 : index
+  // BW-128: %[[C4:.*]] = arith.constant 4 : index
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf16>
+  // BW-128: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf16>
+  // DEFAULT: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<8xf16>) {
+  // BW-128: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<8xf16>) {
+  // DEFAULT: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[CST]] : vector<8xf16>
+  // BW-128: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[CST]] : vector<8xf16>
+  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ADD]] : vector<8xf16> to vector<2x4xf16>
+  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ADD]] : vector<8xf16> to vector<2x4xf16>
+  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<2x4xf16> to vector<8xf16>
+  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<2x4xf16> to vector<8xf16>
+  // DEFAULT: scf.yield %[[CAST1]] : vector<8xf16>
+  // BW-128: scf.yield %[[CAST1]] : vector<8xf16>
+  // DEFAULT: }
+  // BW-128: }
+  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[FOR]] : vector<8xf16> to vector<2x4xf16>
+  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[FOR]] : vector<8xf16> to vector<2x4xf16>
+  // DEFAULT: return %[[CAST2]] : vector<2x4xf16>
+  // BW-128: return %[[CAST2]] : vector<2x4xf16>
+
+  // BW-0: %[[C0:.*]] = arith.constant 0 : index
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C4:.*]] = arith.constant 4 : index
+  // BW-0: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf16>
+  // BW-0: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<2x4xf16>) {
+  // BW-0: %[[ADD:.*]] = arith.addf %[[CST]], %[[ARG1]] : vector<2x4xf16>
+  // BW-0: scf.yield %[[ADD]] : vector<2x4xf16>
+  // BW-0: }
+  // BW-0: return %[[FOR]] : vector<2x4xf16>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %1 = arith.constant dense<1.0> : vector<2x4xf16>
+  %r = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %1) -> (vector<2x4xf16>) {
+    %2 = arith.addf %1, %arg1 : vector<2x4xf16>
+    scf.yield %2 : vector<2x4xf16>
+  }
+  return %r : vector<2x4xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_insert_2d_idx
+// ALL-SAME: (%[[ARG:.*]]: vector<4x8xf16>) -> vector<8x16xf16>
+func.func @test_vector_insert_2d_idx(%arg0: vector<4x8xf16>) -> vector<8x16xf16> {
+  // DEFAULT: %[[V0:.*]] = vector.shape_cast %[[ARG]] : vector<4x8xf16> to vector<32xf16>
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+  // DEFAULT: %[[V1:.*]] = vector.shuffle %[[V0]], %[[V0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[CST]] {offsets = [0], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V3:.*]] = vector.shuffle %[[V0]], %[[V0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V4:.*]] = vector.insert_strided_slice %[[V3]], %[[V2]] {offsets = [16], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V5:.*]] = vector.shuffle %[[V0]], %[[V0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V6:.*]] = vector.insert_strided_slice %[[V5]], %[[V4]] {offsets = [32], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V7:.*]] = vector.shuffle %[[V0]], %[[V0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V8:.*]] = vector.insert_strided_slice %[[V7]], %[[V6]] {offsets = [48], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V9:.*]] = vector.shape_cast %[[V8]] : vector<128xf16> to vector<8x16xf16>
+  // DEFAULT: return %[[V9]] : vector<8x16xf16>
+
+  // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf16>
+  // BW-128: %[[V0:.*]] = vector.insert_strided_slice %[[ARG]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  // BW-128: return %[[V0]] : vector<8x16xf16>
+
+  // BW-0: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf16>
+  // BW-0: %[[V0:.*]] = vector.insert_strided_slice %[[ARG]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  // BW-0: return %[[V0]] : vector<8x16xf16>
+  %cst = arith.constant dense<0.0> : vector<8x16xf16>
+  %0 = vector.insert_strided_slice %arg0, %cst {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  return %0 : vector<8x16xf16>
+}
+
+// -----
+// ALL-LABEL: test_if_single_vector
+func.func @test_if_single_vector() -> vector<16x1xi32> {
+  // DEFAULT: %[[COND:.*]] = arith.constant false
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<3> : vector<16xi32>
+  // DEFAULT: %[[V0:.*]] = scf.if %[[COND]] -> (vector<16xi32>) {
+  // DEFAULT: %[[CST_THEN:.*]] = arith.constant dense<6> : vector<16xi32>
+  // DEFAULT: %[[V2:.*]] = vector.shape_cast %[[CST_THEN]] : vector<16xi32> to vector<16x1xi32>
+  // DEFAULT: %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<16x1xi32> to vector<16xi32>
+  // DEFAULT: scf.yield %[[V3]] : vector<16xi32>
+  // DEFAULT: } else {
+  // DEFAULT: %[[CST_ELSE:.*]] = arith.constant dense<0> : vector<16xi32>
+  // DEFAULT: %[[V4:.*]] = vector.shape_cast %[[CST_ELSE]] : vector<16xi32> to vector<16x1xi32>
+  // DEFAULT: %[[V5:.*]] = vector.shape_cast %[[V4]] : vector<16x1xi32> to vector<16xi32>
+  // DEFAULT: scf.yield %[[V5]] : vector<16xi32>
+  // DEFAULT: }
+  // DEFAULT: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<16xi32> to vector<16x1xi32>
+  // DEFAULT: return %[[V1]] : vector<16x1xi32>
+
+  // BW-128: %[[COND:.*]] = arith.constant false
+  // BW-128: %[[CST:.*]] = arith.constant dense<3> : vector<16xi32>
+  // BW-128: %[[V0:.*]] = scf.if %[[COND]] -> (vector<16xi32>) {
+  // BW-128: %[[CST_THEN:.*]] = arith.constant dense<6> : vector<16xi32>
+  // BW-128: %[[V2:.*]] = vector.shape_cast %[[CST_THEN]] : vector<16xi32> to vector<16x1xi32>
+  // BW-128: %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<16x1xi32> to vector<16xi32>
+  // BW-128: scf.yield %[[V3]] : vector<16xi32>
+  // BW-128: } else {
+  // BW-128: %[[CST_ELSE:.*]] = arith.constant dense<0> : vector<16xi32>
+  // BW-128: %[[V4:.*]] = vector.shape_cast %[[CST_ELSE]] : vector<16xi32> to vector<16x1xi32>
+  // BW-128: %[[V5:.*]] = vector.shape_cast %[[V4]] : vector<16x1xi32> to vector<16xi32>
+  // BW-128: scf.yield %[[V5]] : vector<16xi32>
+  // BW-128: }
+  // BW-128: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<16xi32> to vector<16x1xi32>
+  // BW-128: return %[[V1]] : vector<16x1xi32>
+
+  // BW-0: %[[COND:.*]] = arith.constant false
+  // BW-0: %[[V:.*]] = arith.constant dense<3> : vector<16x1xi32>
+  // BW-0: %[[R:.*]] = scf.if %[[COND]] -> (vector<16x1xi32>) {
+  // BW-0: %[[ADD:.*]] = arith.addi %[[V]], %[[V]] : vector<16x1xi32>
+  // BW-0: scf.yield %[[ADD]] : vector<16x1xi32>
+  // BW-0: } else {
+  // BW-0: %[[SUB:.*]] = arith.subi %[[V]], %[[V]] : vector<16x1xi32>
+  // BW-0: scf.yield %[[SUB]] : vector<16x1xi32>
+  // BW-0: }
+  %cond = arith.constant 0 : i1
+  %v = arith.constant dense<3> : vector<16x1xi32>
+  %r = scf.if %cond -> (vector<16x1xi32>) {
+    %add = arith.addi %v, %v : vector<16x1xi32>
+    scf.yield %add : vector<16x1xi32>
+  } else {
+    %sub = arith.subi %v, %v : vector<16x1xi32>
+    scf.yield %sub : vector<16x1xi32>
+  }
+  return %r : vector<16x1xi32>
+}
+
+// -----
+// ALL-LABEL: test_while
+func.func @test_while() -> vector<2x4xf32> {
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32>
+  // DEFAULT: %[[V0:.*]] = scf.while (%[[ARG0:.*]] = %[[CST]]) : (vector<8xf32>) -> vector<8xf32> {
+  // DEFAULT: %[[V2:.*]] = vector.shape_cast %[[ARG0]] : vector<8xf32> to vector<2x4xf32>
+  // DEFAULT: %[[C0:.*]] = arith.constant 0 : i32
+  // DEFAULT: %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
+  // DEFAULT: %[[V4:.*]] = vector.shape_cast %[[V2]] : vector<2x4xf32> to vector<8xf32>
+  // DEFAULT: scf.condition(%[[COND]]) %[[V4]] : vector<8xf32>
+  // DEFAULT: } do {
+  // DEFAULT: ^bb0(%[[ARG1:.*]]: vector<8xf32>):
+  // DEFAULT: %[[V2:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<8xf32>
+  // DEFAULT: %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<8xf32> to vector<2x4xf32>
+  // DEFAULT: %[[V4:.*]] = vector.shape_cast %[[V3]] : vector<2x4xf32> to vector<8xf32>
+  // DEFAULT: scf.yield %[[V4]] : vector<8xf32>
+  // DEFAULT: }
+  // DEFAULT: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<8xf32> to vector<2x4xf32>
+  // DEFAULT: return %[[V1]] : vector<2x4xf32>
+
+  // BW-128: %[[V:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf32>
+  // BW-128: %[[RESULT:.*]] = scf.while (%[[ARG0:.*]] = %[[V]]) : (vector<2x4xf32>) -> vector<2x4xf32> {
+  // BW-128: %[[C0:.*]] = arith.constant 0 : i32
+  // BW-128: %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
+  // BW-128: scf.condition(%[[COND]]) %[[ARG0]] : vector<2x4xf32>
+  // BW-128: } do {
+  // BW-128: ^bb0(%[[ARG1:.*]]: vector<2x4xf32>):
+  // BW-128: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<2x4xf32>
+  // BW-128: scf.yield %[[ADD]] : vector<2x4xf32>
+  // BW-128: }
+  // BW-128: return %[[RESULT]] : vector<2x4xf32>
+
+  // BW-0: %[[V:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf32>
+  // BW-0: %[[RESULT:.*]] = scf.while (%[[ARG0:.*]] = %[[V]]) : (vector<2x4xf32>) -> vector<2x4xf32> {
+  // BW-0: %[[C0:.*]] = arith.constant 0 : i32
+  // BW-0: %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
+  // BW-0: scf.condition(%[[COND]]) %[[ARG0]] : vector<2x4xf32>
+  // BW-0: } do {
+  // BW-0: ^bb0(%[[ARG1:.*]]: vector<2x4xf32>):
+  // BW-0: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<2x4xf32>
+  // BW-0: scf.yield %[[ADD]] : vector<2x4xf32>
+  // BW-0: }
+  // BW-0: return %[[RESULT]] : vector<2x4xf32>
+  %v = arith.constant dense<1.0> : vector<2x4xf32>
+  %result = scf.while (%arg0 = %v) : (vector<2x4xf32>) -> vector<2x4xf32> {
+    %c0 = arith.constant 0 : i32
+    %cond = arith.cmpi slt, %c0, %c0 : i32
+    scf.condition(%cond) %arg0 : vector<2x4xf32>
+  } do {
+  ^bb0(%arg1: vector<2x4xf32>):
+    %add = arith.addf %arg1, %arg1 : vector<2x4xf32>
+    scf.yield %add : vector<2x4xf32>
+  }
+  return %result : vector<2x4xf32>
+}
+
+// -----
+// ALL-LABEL: test_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @test_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // BW-128: return %[[CAST]] : vector<4x2xi32>
+
+  // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+  // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+  %0 = vector.splat %arg0 : vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..40b0a2321a2b2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -851,7 +851,8 @@ struct TestVectorLinearize final
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
+    registry.insert<vector::VectorDialect, memref::MemRefDialect,
+                    scf::SCFDialect>();
   }
 
   Option<unsigned> targetVectorBitwidth{

From a76f02d7beb790cb30df34d42f5c0f0047be7a10 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Thu, 17 Apr 2025 20:39:51 +0000
Subject: [PATCH 2/7] Run Clang-format
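Formatting-only change. (Illustrative note: a cleanup like this is typically
produced with something along the lines of `git clang-format HEAD~1`, which
reformats only the lines touched by the previous commit; the exact command
used is not recorded in the series.)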
---
 .../Vector/Transforms/VectorLinearize.cpp     | 214 +++++++++---------
 1 file changed, 108 insertions(+), 106 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 6de5d0c5a101e..d97eed7aea008 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -282,22 +282,24 @@ struct LinearizeVectorExtractStridedSlice final
 /// source vector using ExtractStridedSliceOp and inserting them into the
 /// destination vector using InsertStridedSliceOp.
 /// Following,
-/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
+/// vector<4x4xf32>
 /// is converted to :
-/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
-/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
-/// %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
-/// %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
+/// : vector<4xf32> from vector<8xf32> %1 = vector.insert_strided_slice %0, %d
+/// {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32> %2 =
+/// vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} :
+/// vector<4xf32> from vector<8xf32> %3 = vector.insert_strided_slice %2, %1
+/// {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
 struct LinearizeVectorInsertStridedSlice final
     : public OpConversionPattern<vector::InsertStridedSliceOp> {
-  using OpConversionPattern<
-      vector::InsertStridedSliceOp>::OpConversionPattern;
-  LinearizeVectorInsertStridedSlice(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
@@ -345,8 +347,9 @@ struct LinearizeVectorInsertStridedSlice final
     rewriter.replaceOp(op, dstValue);
     return success();
   }
-  private:
-  unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -619,22 +622,22 @@ struct LinearizeVectorBitCast final
 /// is converted to :
 /// %result = arith.constant dense<0.0> : vector<4x4xf32>
 /// %slice_0 = vector.load %base[%indices] : vector<4xf32>
-/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
-/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
-/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into
+/// vector<4x4xf32> %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into
+/// vector<4x4xf32>
 /// ...
 /// This unrolls the 2D vector load into multiple 1D vector loads and inserts
 /// them into the result vector. The pattern currently supports only 2D vectors
-struct LinearizeVectorLoad final
-    : public OpConversionPattern<vector::LoadOp> {
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
 
   LinearizeVectorLoad(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
@@ -648,35 +651,33 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
     if (shape.size() != 2) {
       return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
     }
     auto unrollCount = shape[0];
     auto vecSize = shape[1];
-    auto newVecType =
-        VectorType::get({vecSize}, vecType.getElementType());
+    auto newVecType = VectorType::get({vecSize}, vecType.getElementType());
 
     llvm::SmallVector<Value> indices = llvm::to_vector(adaptor.getIndices());
     Value xBaseIndex = indices[0];
 
     // Construct the 2D vector.
-    Value resultVec = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(vecType));
+    Value resultVec =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
     // Emit unrolled loads for each 1D vector slice.
     for (auto i = 0; i < unrollCount; i++) {
       Value xIndex = xBaseIndex;
       if (i) {
         auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
-        xIndex =
-            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
       }
       indices[0] = xIndex;
-      auto vec = rewriter.create<vector::LoadOp>(
-          loc, newVecType, adaptor.getBase(), indices);
-      resultVec =
-          rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+      auto vec = rewriter.create<vector::LoadOp>(loc, newVecType,
+                                                 adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
     }
 
     rewriter.replaceOp(loadOp, resultVec);
     return success();
   }
-  private:
-  unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
 /// that works on a linearized vector.
 /// Following,
 /// vector.store %source, %base[%indices] : vector<4x4xf32>
 /// is converted to :
 /// %slice_0 = vector.extract %source[0] : vector<4xf32>
 /// vector.store %slice_0, %base[%indices] : vector<4xf32>
 /// %slice_1 = vector.extract %source[1] : vector<4xf32>
 /// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
 /// ...
-/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
-/// slices from the source vector and storing them into the destination.
-/// The pattern currently supports only 2D vectors
+/// This unrolls the 2D vector store into multiple 1D vector stores by
+/// extracting slices from the source vector and storing them into the
+/// destination. The pattern currently supports only 2D vectors
 struct LinearizeVectorStore final
     : public OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
 
   LinearizeVectorStore(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
@@ -718,26 +719,26 @@ struct LinearizeVectorStore final
     auto unrollCount = shape[0];
     llvm::SmallVector<Value> indices = llvm::to_vector(adaptor.getIndices());
     Value xBaseIndex = indices[0];
 
-    auto vec = rewriter.create<vector::ShapeCastOp>(
-        loc, vecType, adaptor.getValueToStore());
+    auto vec = rewriter.create<vector::ShapeCastOp>(loc, vecType,
+                                                    adaptor.getValueToStore());
 
     for (auto i = 0; i < unrollCount; i++) {
       auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
       Value xIndex = xBaseIndex;
       if (i) {
         auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
-        xIndex =
-            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
       }
       indices[0] = xIndex;
       rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
-          indices);
+                                       indices);
     }
     rewriter.eraseOp(storeOp);
     return success();
   }
-  private:
-  unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the SplatOp to work on a linearized vector.
@@ -754,11 +755,11 @@ struct LinearizeVectorSplat final
   using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
 
   LinearizeVectorSplat(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
@@ -766,12 +767,13 @@ struct LinearizeVectorSplat final
     auto dstTy = getTypeConverter()->convertType(splatOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
-    rewriter.replaceOpWithNewOp<vector::SplatOp>(
-        splatOp, adaptor.getInput(), dstTy);
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+                                                 dstTy);
     return success();
   }
-  private:
-  unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the CreateMaskOp to work on a
@@ -789,11 +791,11 @@ struct LinearizeVectorCreateMask final
   using OpConversionPattern<vector::CreateMaskOp>::OpConversionPattern;
 
   LinearizeVectorCreateMask(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
@@ -816,8 +818,9 @@ struct LinearizeVectorCreateMask final
     rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
         createMaskOp, dstTy, adaptor.getOperands().back());
     return success();
   }
-  private:
-  unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts operations implementing the RegionBranchOpInterface
@@ -835,15 +838,14 @@ struct LinearizeRegionBranchOp final
       RegionBranchOpInterface>::OpInterfaceConversionPattern;
 
   LinearizeRegionBranchOp(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpInterfaceConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpInterfaceConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
-  matchAndRewrite(RegionBranchOpInterface op,
-                  ArrayRef<Value> operands,
+  matchAndRewrite(RegionBranchOpInterface op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto converter = getTypeConverter();
@@ -907,8 +909,9 @@ struct LinearizeRegionBranchOp final
     rewriter.finalizeOpModification(op);
     return success();
   }
-  private:
-  unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 } // namespace
@@ -937,26 +940,25 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.addLegalOp<vector::ShapeCastOp>();
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
-                 vector::SplatOp, vector::CreateMaskOp>(op) ||
-             op->hasTrait<OpTrait::ConstantLike>() ||
-             op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
-                      ? typeConverter.isLegal(op)
-                      : true);
-        }
-        return std::nullopt;
-      });
+  target.markUnknownOpDynamicallyLegal([=](Operation *op)
+                                           -> std::optional<bool> {
+    if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
+             vector::SplatOp, vector::CreateMaskOp>(
+            op) ||
+         op->hasTrait<OpTrait::ConstantLike>() ||
+         op->hasTrait<OpTrait::Vectorizable>())) {
+      return (isLessThanTargetBitWidth(op, targetBitWidth)
+                  ? typeConverter.isLegal(op)
+                  : true);
+    }
+    return std::nullopt;
+  });
 
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast, LinearizeVectorLoad,
-               LinearizeVectorStore, LinearizeVectorSplat,
-               LinearizeVectorCreateMask, LinearizeRegionBranchOp
-               >(typeConverter, patterns.getContext(), targetBitWidth);
+  patterns
+      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+           LinearizeVectorLoad, LinearizeVectorStore, LinearizeVectorSplat,
+           LinearizeVectorCreateMask, LinearizeRegionBranchOp>(
+          typeConverter, patterns.getContext(), targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
@@ -972,16 +974,16 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
   target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
-      [=](vector::InsertStridedSliceOp op) -> bool {
-        if(isLessThanTargetBitWidth(op, targetBitWidth)) {
-          auto srcTy = op.getSourceVectorType();
-          auto dstTy = op.getDestVectorType();
-          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
-              srcTy.hasStaticShape() && dstTy.hasStaticShape())
-            return false;
-        }
-        return true;
-      });
+      [=](vector::InsertStridedSliceOp op) -> bool {
+        if (isLessThanTargetBitWidth(op, targetBitWidth)) {
+          auto srcTy = op.getSourceVectorType();
+          auto dstTy = op.getDestVectorType();
+          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+              srcTy.hasStaticShape() && dstTy.hasStaticShape())
+            return false;
+        }
+        return true;
+      });
 
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
                LinearizeVectorInsertStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }

From: nbpatel
Date: Thu, 17 Apr 2025 21:28:49 +0000
Subject: [PATCH 3/7] Remove RegionBranchOp pattern and address comments

---
 .../Vector/Transforms/VectorLinearize.cpp     | 172 +++++-------------
 mlir/test/Dialect/Vector/linearize.mlir       | 160 ----------------
 .../Dialect/Vector/TestVectorTransforms.cpp   |   2 +-
 3 files changed, 42 insertions(+), 292 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index d97eed7aea008..06ba40da3b0b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,7 +11,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -28,9 +28,9 @@ using namespace mlir;
 
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
   // For BW-0, all operations are legal
-  if (targetBitWidth == 0) {
+  if (targetBitWidth == 0)
     return false;
-  }
+
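+  // Note (illustrative): targetBitWidth == 0 corresponds to the BW-0 test
+  // configuration; returning false here keeps every op legal, so no
+  // linearization pattern fires in that mode.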
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
@@ -302,32 +301,37 @@ struct LinearizeVectorInsertStridedSlice final
         targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
-  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto srcTy = op.getSourceVectorType();
-    auto dstTy = op.getDestVectorType();
+    auto loc = insertOp.getLoc();
+    auto srcTy = insertOp.getSourceVectorType();
+    auto dstTy = insertOp.getDestVectorType();
 
-    if (op.hasNonUnitStrides()) {
+    if (insertOp.hasNonUnitStrides())
       return rewriter.notifyMatchFailure(
-          op, "InsertStridedSliceOp linearization only supports unit strides.");
-    }
+          insertOp,
+          "InsertStridedSliceOp linearization only supports unit strides.");
 
-    if (srcTy.getRank() != 2) {
+    if (srcTy.getRank() != 2)
       return rewriter.notifyMatchFailure(
-          op, "InsertStridedSliceOp linearization only supports 2D source.");
-    }
+          insertOp,
+          "InsertStridedSliceOp linearization only supports 2D source.");
 
-    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
-          op, "InsertStridedSliceOp linearization only supports static shapes.");
-    }
+          insertOp,
+          "InsertStridedSliceOp linearization only supports static shapes.");
+
+    if (srcTy.isScalable() || dstTy.isScalable())
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "scalable vectors are not supported.");
 
     auto dstShape = dstTy.getShape();
     auto dstStrides = dstShape.drop_front().vec();
     dstStrides.push_back(1);
     int64_t linearizedOffset = 0;
-    for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
+    for (auto [off, stride] :
+         llvm::zip_equal(insertOp.getOffsets(), dstStrides)) {
       linearizedOffset += getConstantIntValue(off).value() * stride;
     }
@@ -344,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final
           loc, value, dstValue, dstOffset, 1);
     }
 
-    rewriter.replaceOp(op, dstValue);
+    rewriter.replaceOp(insertOp, dstValue);
     return success();
   }
 
@@ -535,12 +539,11 @@ struct LinearizeVectorInsert final
     auto srcTy = insertOp.getValueToStoreType();
     auto srcAsVec = dyn_cast<VectorType>(srcTy);
     uint64_t srcSize = 0;
-    if (srcAsVec) {
+    if (srcAsVec)
       srcSize = srcAsVec.getNumElements();
-    } else {
+    else
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalars are not supported.");
-    }
 
     auto dstShape = insertOp.getDestVectorType().getShape();
     const auto dstSize = insertOp.getDestVectorType().getNumElements();
@@ -646,9 +649,9 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
     auto vecType = loadOp.getVectorType();
     auto shape = vecType.getShape();
 
-    if (shape.size() != 2) {
+    if (shape.size() != 2)
       return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
-    }
+
     auto unrollCount = shape[0];
     auto vecSize = shape[1];
     auto newVecType = VectorType::get({vecSize}, vecType.getElementType());
@@ -711,9 +714,8 @@ struct LinearizeVectorStore final
     auto vecType = storeOp.getVectorType();
     auto shape = vecType.getShape();
 
-    if (shape.size() != 2) {
+    if (shape.size() != 2)
       return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
-    }
 
     auto unrollCount = shape[0];
     llvm::SmallVector<Value> indices = llvm::to_vector(adaptor.getIndices());
     Value xBaseIndex = indices[0];
@@ -823,97 +825,6 @@ struct LinearizeVectorCreateMask final
   unsigned targetVectorBitWidth;
 };
 
-/// This pattern converts operations implementing the RegionBranchOpInterface
-/// to ensure compatibility with linearized vector types. It updates the
-/// operands, result types, and region types (block arguments and yields) to
-/// match the converted types. Additionally, it processes yields within each
-/// region to ensure that the types of yielded values are compatible with the
-/// target vector bit width. If the result types of the operation are updated,
-/// shape cast operations are inserted to maintain compatibility with the
-/// original types. This pattern ensures that operations with regions are
-/// properly linearized and remain valid after type conversion.
-struct LinearizeRegionBranchOp final
-    : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
-  using OpInterfaceConversionPattern<
-      RegionBranchOpInterface>::OpInterfaceConversionPattern;
-
-  LinearizeRegionBranchOp(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpInterfaceConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
-
-  LogicalResult
-  matchAndRewrite(RegionBranchOpInterface op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto converter = getTypeConverter();
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.startOpModification(op);
-
-    llvm::SmallVector<Type> convertedTypes;
-    for (Type ty : op->getResultTypes()) {
-      convertedTypes.push_back(converter->convertType(ty));
-    }
-
-    if (convertedTypes == op->getResultTypes() &&
-        op->getOperands() == operands) {
-      return failure();
-    }
-
-    op->setOperands(operands);
-
-    // Convert region types (block arguments and yields)
-    for (Region &region : op->getRegions()) {
-      if (failed(rewriter.convertRegionTypes(&region, *converter))) {
-        return failure();
-      }
-
-      // Process yields within each region
-      for (Block &block : region) {
-        if (auto *terminator = block.getTerminator()) {
-          for (OpOperand &yieldOperand : terminator->getOpOperands()) {
-            Value value = yieldOperand.get();
-            Type type = value.getType();
-            if (!converter->isLegal(type)) {
-              Type newTy = converter->convertType(type);
-              rewriter.setInsertionPoint(terminator);
-              Value newValue =
-                  rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
-              yieldOperand.set(newValue);
-            }
-          }
-        }
-      }
-    }
-
-    // Update result types
-    rewriter.setInsertionPointAfter(op);
-    llvm::SmallVector<Value> newResults;
-    for (Value result : op->getResults()) {
-      Type oldTy = result.getType();
-      if (!converter->isLegal(oldTy)) {
-        Type newTy = converter->convertType(oldTy);
-        result.setType(newTy);
-        Operation *castOp =
-            rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
-        result.replaceAllUsesExcept(castOp->getResult(0), castOp);
-        newResults.push_back(castOp->getResult(0));
-      } else {
-        newResults.push_back(result);
-      }
-    }
-
-    rewriter.finalizeOpModification(op);
-    return success();
-  }
-
-private:
-  unsigned targetVectorBitWidth;
-};
-
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -940,25 +851,24 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.addLegalOp();
-  target.markUnknownOpDynamicallyLegal([=](Operation *op)
-                                           -> std::optional<bool> {
-    if ((isa(
-            op) ||
-         op->hasTrait<OpTrait::ConstantLike>() ||
-         op->hasTrait<OpTrait::Vectorizable>())) {
-      return (isLessThanTargetBitWidth(op, targetBitWidth)
-                  ? typeConverter.isLegal(op)
-                  : true);
-    }
-    return std::nullopt;
-  });
+  target.markUnknownOpDynamicallyLegal(
+      [=](Operation *op) -> std::optional<bool> {
+        if ((isa(op) ||
+             op->hasTrait<OpTrait::ConstantLike>() ||
+             op->hasTrait<OpTrait::Vectorizable>())) {
+          return (isLessThanTargetBitWidth(op, targetBitWidth)
+                      ? typeConverter.isLegal(op)
+                      : true);
+        }
+        return std::nullopt;
+      });
 
   patterns
       .add(
-          typeConverter, patterns.getContext(), targetBitWidth);
+          LinearizeVectorCreateMask>(typeConverter, patterns.getContext(),
+                                     targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index e47e7c4a84d68..2ea4751393ebf 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -529,54 +529,6 @@ func.func @test_create_mask() -> vector<1x16xi1> {
   return %0 : vector<1x16xi1>
 }
 
-// -----
-// ALL-LABEL: test_loop
-func.func @test_loop() -> vector<2x4xf16> {
-  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
-  // BW-128: %[[C0:.*]] = arith.constant 0 : index
-  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
-  // BW-128: %[[C1:.*]] = arith.constant 1 : index
-  // DEFAULT: %[[C4:.*]] = arith.constant 4 : index
-  // BW-128: %[[C4:.*]] = arith.constant 4 : index
-  // DEFAULT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf16>
-  // BW-128: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf16>
-  // DEFAULT: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<8xf16>) {
-  // BW-128: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<8xf16>) {
-  // DEFAULT: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[CST]] : vector<8xf16>
-  // BW-128: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[CST]] : vector<8xf16>
-  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ADD]] : vector<8xf16> to vector<2x4xf16>
-  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ADD]] : vector<8xf16> to vector<2x4xf16>
-  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<2x4xf16> to vector<8xf16>
-  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<2x4xf16> to vector<8xf16>
-  // DEFAULT: scf.yield %[[CAST1]] : vector<8xf16>
-  // BW-128: scf.yield %[[CAST1]] : vector<8xf16>
-  // DEFAULT: }
-  // BW-128: }
-  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[FOR]] : vector<8xf16> to vector<2x4xf16>
-  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[FOR]] : vector<8xf16> to vector<2x4xf16>
-  // DEFAULT: return %[[CAST2]] : vector<2x4xf16>
-  // BW-128: return %[[CAST2]] : vector<2x4xf16>
-
-  // BW-0: %[[C0:.*]] = arith.constant 0 : index
-  // BW-0: %[[C1:.*]] = arith.constant 1 : index
-  // BW-0: %[[C4:.*]] = arith.constant 4 : index
-  // BW-0: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf16>
-  // BW-0: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<2x4xf16>) {
-  // BW-0: %[[ADD:.*]] = arith.addf %[[CST]], %[[ARG1]] : vector<2x4xf16>
-  // BW-0: scf.yield %[[ADD]] : vector<2x4xf16>
-  // BW-0: }
-  // BW-0: return %[[FOR]] : vector<2x4xf16>
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4 = arith.constant 4 : index
-  %1 = arith.constant dense<1.0> : vector<2x4xf16>
-  %r = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %1) -> (vector<2x4xf16>) {
-    %2 = arith.addf %1, %arg1 : vector<2x4xf16>
-    scf.yield %2 : vector<2x4xf16>
-  }
-  return %r : vector<2x4xf16>
-}
-
 // -----
 // ALL-LABEL: test_vector_insert_2d_idx
 // ALL-SAME: (%[[ARG:.*]]: vector<4x8xf16>) -> vector<8x16xf16>
@@ -606,118 +558,6 @@ func.func @test_vector_insert_2d_idx(%arg0: vector<4x8xf16>) -> vector<8x16xf16>
   return %0 : vector<8x16xf16>
 }
 
-// -----
-// ALL-LABEL: test_if_single_vector
-func.func @test_if_single_vector() -> vector<16x1xi32> {
-  // DEFAULT: %[[COND:.*]] = arith.constant false
-  // DEFAULT: %[[CST:.*]] = arith.constant dense<3> : vector<16xi32>
-  // DEFAULT: %[[V0:.*]] = scf.if %[[COND]] -> (vector<16xi32>) {
-  // DEFAULT: %[[CST_THEN:.*]] = arith.constant dense<6> : vector<16xi32>
-  // DEFAULT: %[[V2:.*]] = vector.shape_cast %[[CST_THEN]] : vector<16xi32> to vector<16x1xi32>
-  // DEFAULT: %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<16x1xi32> to vector<16xi32>
-  // DEFAULT: scf.yield %[[V3]] : vector<16xi32>
-  // DEFAULT: } else {
-  // DEFAULT: %[[CST_ELSE:.*]] = arith.constant dense<0> : vector<16xi32>
-  // DEFAULT: %[[V4:.*]] = vector.shape_cast %[[CST_ELSE]] : vector<16xi32> to vector<16x1xi32>
-  // DEFAULT: %[[V5:.*]] = vector.shape_cast %[[V4]] : vector<16x1xi32> to vector<16xi32>
-  // DEFAULT: scf.yield %[[V5]] : vector<16xi32>
-  // DEFAULT: }
-  // DEFAULT: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<16xi32> to vector<16x1xi32>
-  // DEFAULT: return %[[V1]] : vector<16x1xi32>
-
-  // BW-128: %[[COND:.*]] = arith.constant false
-  // BW-128: %[[CST:.*]] = arith.constant dense<3> : vector<16xi32>
-  // BW-128: %[[V0:.*]] = scf.if %[[COND]] -> (vector<16xi32>) {
-  // BW-128: %[[CST_THEN:.*]] = arith.constant dense<6> : vector<16xi32>
-  // BW-128: %[[V2:.*]] = vector.shape_cast %[[CST_THEN]] : vector<16xi32> to vector<16x1xi32>
-  // BW-128: %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<16x1xi32> to vector<16xi32>
-  // BW-128: scf.yield %[[V3]] : vector<16xi32>
-  // BW-128: } else {
-  // BW-128: %[[CST_ELSE:.*]] = arith.constant dense<0> : vector<16xi32>
-  // BW-128: %[[V4:.*]] = vector.shape_cast %[[CST_ELSE]] : vector<16xi32> to vector<16x1xi32>
-  // BW-128: %[[V5:.*]] = vector.shape_cast %[[V4]] : vector<16x1xi32> to vector<16xi32>
-  // BW-128: scf.yield %[[V5]] : vector<16xi32>
-  // BW-128: }
-  // BW-128: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<16xi32> to vector<16x1xi32>
-  // BW-128: return %[[V1]] : vector<16x1xi32>
-
-  // BW-0: %[[COND:.*]] = arith.constant false
-  // BW-0: %[[V:.*]] = arith.constant dense<3> : vector<16x1xi32>
-  // BW-0: %[[R:.*]] = scf.if %[[COND]] -> (vector<16x1xi32>) {
-  // BW-0: %[[ADD:.*]] = arith.addi %[[V]], %[[V]] : vector<16x1xi32>
-  // BW-0: scf.yield %[[ADD]] : vector<16x1xi32>
-  // BW-0: } else {
-  // BW-0: %[[SUB:.*]] = arith.subi %[[V]], %[[V]] : vector<16x1xi32>
-  // BW-0: scf.yield %[[SUB]] : vector<16x1xi32>
-  // BW-0: }
-  %cond = arith.constant 0 : i1
-  %v = arith.constant dense<3> : vector<16x1xi32>
-  %r = scf.if %cond -> (vector<16x1xi32>) {
-    %add = arith.addi %v, %v : vector<16x1xi32>
-    scf.yield %add : vector<16x1xi32>
-  } else {
-    %sub = arith.subi %v, %v : vector<16x1xi32>
-    scf.yield %sub : vector<16x1xi32>
-  }
-  return %r : vector<16x1xi32>
-}
-
-// -----
-// ALL-LABEL: test_while
-func.func @test_while() -> vector<2x4xf32> {
-  // DEFAULT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32>
-  // DEFAULT: %[[V0:.*]] = scf.while (%[[ARG0:.*]] = %[[CST]]) : (vector<8xf32>) -> vector<8xf32> {
-  // DEFAULT: %[[V2:.*]] = vector.shape_cast %[[ARG0]] : vector<8xf32> to vector<2x4xf32>
-  // DEFAULT: %[[C0:.*]] = arith.constant 0 : i32
-  // DEFAULT: %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
-  // DEFAULT: %[[V4:.*]] = vector.shape_cast %[[V2]] : vector<2x4xf32> to vector<8xf32>
-  // DEFAULT: scf.condition(%[[COND]]) %[[V4]] : vector<8xf32>
-  // DEFAULT: } do {
-  // DEFAULT: ^bb0(%[[ARG1:.*]]: vector<8xf32>):
-  // DEFAULT: %[[V2:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<8xf32>
-  // DEFAULT: %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<8xf32> to vector<2x4xf32>
-  // DEFAULT: %[[V4:.*]] = vector.shape_cast %[[V3]] : vector<2x4xf32> to vector<8xf32>
-  // DEFAULT: scf.yield %[[V4]] : vector<8xf32>
-  // DEFAULT: }
-  // DEFAULT: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<8xf32> to vector<2x4xf32>
-  // DEFAULT: return %[[V1]] : vector<2x4xf32>
-
-  // BW-128: %[[V:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf32>
-  // BW-128: %[[RESULT:.*]] = scf.while (%[[ARG0:.*]] = %[[V]]) : (vector<2x4xf32>) -> vector<2x4xf32> {
-  // BW-128: %[[C0:.*]] = arith.constant 0 : i32
-  // BW-128: %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
-  // BW-128: scf.condition(%[[COND]]) %[[ARG0]] : vector<2x4xf32>
-  // BW-128: } do {
-  // BW-128: ^bb0(%[[ARG1:.*]]: vector<2x4xf32>):
-  // BW-128: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<2x4xf32>
-  // BW-128: scf.yield %[[ADD]] : vector<2x4xf32>
-  // BW-128: }
-  // BW-128: return %[[RESULT]] : vector<2x4xf32>
-
-  // BW-0: %[[V:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf32>
-  // BW-0: %[[RESULT:.*]] = scf.while (%[[ARG0:.*]] = %[[V]]) : (vector<2x4xf32>) -> vector<2x4xf32> {
-  // BW-0: %[[C0:.*]] = arith.constant 0 : i32
-  // BW-0: %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
-  // BW-0: scf.condition(%[[COND]]) %[[ARG0]] : vector<2x4xf32>
-  // BW-0: } do {
-  // BW-0: ^bb0(%[[ARG1:.*]]: vector<2x4xf32>):
-  // BW-0: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<2x4xf32>
-  // BW-0: scf.yield %[[ADD]] : vector<2x4xf32>
-  // BW-0: }
-  // BW-0: return %[[RESULT]] : vector<2x4xf32>
-  %v = arith.constant dense<1.0> : vector<2x4xf32>
-  %result = scf.while (%arg0 = %v) : (vector<2x4xf32>) -> vector<2x4xf32> {
-    %c0 = arith.constant 0 : i32
-    %cond = arith.cmpi slt, %c0, %c0 : i32
-    scf.condition(%cond) %arg0 : vector<2x4xf32>
-  } do {
-  ^bb0(%arg1: vector<2x4xf32>):
-    %add = arith.addf %arg1, %arg1 : vector<2x4xf32>
-    scf.yield %add : vector<2x4xf32>
-  }
-  return %result : vector<2x4xf32>
-}
-
 // -----
 // ALL-LABEL: test_vector_splat
 // ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 40b0a2321a2b2..aea116cffc3a8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -852,7 +852,7 @@ struct TestVectorLinearize final
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert();
+                     arith::ArithDialect>();
   }
 
   Option targetVectorBitwidth{

From 03789ec19d9c802f99a23a81e093c5571c9f71fd Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Fri, 18 Apr 2025 17:15:49 +0000
Subject: [PATCH 4/7] Modify create_mask pattern

---
 .../Vector/Transforms/VectorLinearize.cpp     | 23 +++++++++++++++++--
 mlir/test/Dialect/Vector/linearize.mlir       | 20 ++++++++++++----
 2 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 06ba40da3b0b0..7028285c0a91d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -816,8 +816,27 @@
     if (!dstTy)
       return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
 
-    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
-        createMaskOp, dstTy, adaptor.getOperands().back());
+    // Compare the first operand with 0. If it is less than or equal to 0,
+    // create a zero mask; otherwise drop the first operand and create the
+    // mask from the second operand.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero =
+        rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+    auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+        createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+        zero);
+    auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+        createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+    // Use a select operation to choose between the masks.
+    auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+        createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+        createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+    auto result = rewriter.create<mlir::arith::SelectOp>(
+        createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+
+    rewriter.replaceOp(createMaskOp, result.getResult());
     return success();
   }
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 2ea4751393ebf..f7a767dbdc272 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -515,10 +515,22 @@ func.func @test_create_mask() -> vector<1x16xi1> {
   // BW-128: %[[C0:.*]] = arith.constant 0 : index
   // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
   // BW-128: %[[C20:.*]] = arith.constant 20 : index
-  // DEFAULT: %[[MASK:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
-  // BW-128: %[[MASK:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
-  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16xi1>
-  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16xi1>
+  // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0_0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // BW-128: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+  // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // DEFAULT: return %[[CAST]] : vector<1x16xi1>
+  // BW-128: return %[[CAST]] : vector<1x16xi1>
 
   // BW-0: %[[C0:.*]] = arith.constant 0 : index
   // BW-0: %[[C20:.*]] = arith.constant 20 : index

From 231371c66b5f5a0ab109003ba85bffdb9d962aae Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Mon, 21 Apr 2025 21:50:40 +0000
Subject: [PATCH 5/7] Address comments

---
 .../Vector/Transforms/VectorLinearize.cpp     | 49 +++++++++----------
 1 file changed, 23 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7028285c0a91d..3b3153b787bb9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@
 
 using namespace mlir;
 
+constexpr unsigned defaultTargetVectorBitWidth =
+    std::numeric_limits<unsigned>::max();
+
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
   // For BW-0, all operations are legal
   if (targetBitWidth == 0)
@@ -86,7 +89,7 @@ struct LinearizeConstantLike final
 
   LinearizeConstantLike(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
        targetVectorBitWidth(targetVectBitWidth) {}
@@ -140,7 +143,7 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -179,7 +182,7 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -295,7 +298,7 @@ struct LinearizeVectorInsertStridedSlice final
   using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
   LinearizeVectorInsertStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -317,11 +320,6 @@ struct LinearizeVectorInsertStridedSlice final
           insertOp,
           "InsertStridedSliceOp linearization only supports 2D source.");
 
-    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          insertOp,
-          "InsertStridedSliceOp linearization only supports static shapes.");
-
     if (srcTy.isScalable() || dstTy.isScalable())
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalable vectors are not supported.");
@@ -372,7 +370,7 @@ struct LinearizeVectorShuffle final
   using OpConversionPattern<vector::ShuffleOp>::OpConversionPattern;
   LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -445,7 +443,7 @@ struct LinearizeVectorExtract final
   using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
   LinearizeVectorExtract(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -513,7 +511,7 @@ struct LinearizeVectorInsert final
   using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
   LinearizeVectorInsert(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -593,7 +591,7 @@ struct LinearizeVectorBitCast final
   using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
   LinearizeVectorBitCast(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -618,6 +616,7 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+// clang-format off
 /// This pattern converts the LoadOp to a series of LoadOp & InsertOp
 /// that works on a linearized vector.
 /// Following,
 /// is converted to :
 /// %result = arith.constant dense<0.0> : vector<4x4xf32>
 /// %slice_0 = vector.load %base[%indices] : vector<4xf32>
-/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into
-/// vector<4x4xf32> %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
-/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into
-/// vector<4x4xf32>
+/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
 /// ...
 /// This unrolls the 2D vector load into multiple 1D vector loads and inserts
 /// them into the result vector. The pattern currently supports only 2D vectors
+// clang-format on
 struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
-  LinearizeVectorLoad(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+                      PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -702,7 +700,7 @@ struct LinearizeVectorStore final
 
   LinearizeVectorStore(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -758,7 +756,7 @@ struct LinearizeVectorSplat final
 
   LinearizeVectorSplat(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -794,7 +792,7 @@ struct LinearizeVectorCreateMask final
 
   LinearizeVectorCreateMask(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -907,8 +905,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
         if (isLessThanTargetBitWidth(op, targetBitWidth)) {
           auto srcTy = op.getSourceVectorType();
           auto dstTy = op.getDestVectorType();
-          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
-              srcTy.hasStaticShape() && dstTy.hasStaticShape())
+          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2)
            return false;
         }
         return true;

From e3788909066b3690104b0570feb0f05bb0140526 Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Wed, 23 Apr 2025 16:46:47 +0000
Subject: [PATCH 6/7] Fix formatting

---
 .../Dialect/Vector/Transforms/VectorLinearize.cpp | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3b3153b787bb9..ede77c6d0fa12 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -280,6 +280,7 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
+// clang-format off
 /// This pattern linearizes the InsertStridedSliceOp by extracting rows from the
 /// source vector using ExtractStridedSliceOp and inserting them into the
 /// destination vector using InsertStridedSliceOp.
@@ -288,11 +289,14 @@ struct LinearizeVectorExtractStridedSlice final
 /// Following,
 /// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
 /// vector<4x4xf32>
 /// is converted to :
 /// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
-///    : vector<4xf32> from vector<8xf32> %1 = vector.insert_strided_slice %0, %d
-/// {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32> %2 =
-/// vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} :
-/// vector<4xf32> from vector<8xf32> %3 = vector.insert_strided_slice %2, %1
-/// {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+///    : vector<4xf32> from vector<8xf32>
+/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
+///    : vector<4xf32> into vector<16xf32>
+/// %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]}
+///    : vector<4xf32> from vector<8xf32>
+/// %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]}
+///    : vector<4xf32> into vector<16xf32>
+// clang-format on
 struct LinearizeVectorInsertStridedSlice final
     : public OpConversionPattern<vector::InsertStridedSliceOp> {
   using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;

From b01686a98b346cf7daa810407bb94220aaa2074b Mon Sep 17 00:00:00 2001
From: nbpatel
Date: Wed, 23 Apr 2025 16:54:49 +0000
Subject: [PATCH 7/7] Remove unused variable

---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index ede77c6d0fa12..ee83414a502a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -908,7 +908,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       [=](vector::InsertStridedSliceOp op) -> bool {
         if (isLessThanTargetBitWidth(op, targetBitWidth)) {
           auto srcTy = op.getSourceVectorType();
-          auto dstTy = op.getDestVectorType();
          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2)
            return false;
         }
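
Note for reviewers: below is a small standalone reproducer for the new
InsertStridedSliceOp pattern, in the spirit of linearize.mlir. It is a sketch
only: the RUN line assumes the existing -test-vector-linearize test pass, and
the CHECK lines paraphrase the decomposition described in the pattern's doc
comment (rows at linearized offsets [0] and [4] over a vector<8xf32> source
and vector<16xf32> destination); exact value numbering and shape_cast
placement are illustrative, not captured output from this series.

// RUN: mlir-opt %s -test-vector-linearize | FileCheck %s
// Inserting a 2x4 slice at offsets [0, 0] of a 4x4 destination should be
// unrolled into one extract/insert_strided_slice pair per source row.
func.func @insert_2x4_into_4x4(%s: vector<2x4xf32>, %d: vector<4x4xf32>) -> vector<4x4xf32> {
  // CHECK: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32>
  // CHECK: vector.extract_strided_slice {{.*}} {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
  // CHECK: vector.insert_strided_slice {{.*}} {offsets = [0], strides = [1]} : vector<4xf32> into vector<16xf32>
  // CHECK: vector.extract_strided_slice {{.*}} {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
  // CHECK: vector.insert_strided_slice {{.*}} {offsets = [4], strides = [1]} : vector<4xf32> into vector<16xf32>
  // CHECK: vector.shape_cast {{.*}} : vector<16xf32> to vector<4x4xf32>
  %0 = vector.insert_strided_slice %s, %d {offsets = [0, 0], strides = [1, 1]}
      : vector<2x4xf32> into vector<4x4xf32>
  return %0 : vector<4x4xf32>
}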