diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..ee83414a502a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,7 +26,14 @@
 using namespace mlir;
 
+constexpr unsigned defaultTargetVectorBitWidth =
+    std::numeric_limits<unsigned>::max();
+
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+  // A target bit width of 0 (BW-0) disables linearization: all ops are legal.
+  if (targetBitWidth == 0)
+    return false;
+
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
@@ -82,7 +89,7 @@ struct LinearizeConstantLike final
   LinearizeConstantLike(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -136,7 +143,7 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -175,7 +182,7 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -273,6 +280,84 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
+// clang-format off
+/// This pattern linearizes InsertStridedSliceOp by extracting rows from the
+/// source vector with ExtractStridedSliceOp and inserting them into the
+/// destination vector with InsertStridedSliceOp.
+/// For example,
+///   vector.insert_strided_slice %s, %d {offsets=[0, 0]} : vector<2x4xf32> into vector<4x4xf32>
+/// is converted to:
+///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
+///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+// clang-format on
+struct LinearizeVectorInsertStridedSlice final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = insertOp.getLoc();
+    auto srcTy = insertOp.getSourceVectorType();
+    auto dstTy = insertOp.getDestVectorType();
+
+    if (insertOp.hasNonUnitStrides())
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "InsertStridedSliceOp linearization only supports unit strides.");
+
+    if (srcTy.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          insertOp,
+          "InsertStridedSliceOp linearization only supports 2D sources.");
+
+    if (srcTy.isScalable() || dstTy.isScalable())
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "scalable vectors are not supported.");
+
+    auto dstShape = dstTy.getShape();
+    auto dstStrides = dstShape.drop_front().vec();
+    dstStrides.push_back(1);
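+    // For example, a vector<4x4xf32> destination has row-major strides
+    // [4, 1], so offsets [1, 2] linearize to 1 * 4 + 2 * 1 = 6.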
+    int64_t linearizedOffset = 0;
+    for (auto [off, stride] :
+         llvm::zip_equal(insertOp.getOffsets(), dstStrides)) {
+      linearizedOffset += getConstantIntValue(off).value() * stride;
+    }
+
+    // Extract each row from the source and insert it into the destination.
+    auto srcShape = srcTy.getShape();
+    Value dstValue = adaptor.getDest();
+    for (auto i = 0; i < srcShape[0]; i++) {
+      auto srcOffset = i * srcShape[1];
+      auto value = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, adaptor.getValueToStore(), srcOffset, srcShape[1], 1);
+
+      auto dstOffset = linearizedOffset + i * dstShape.back();
+      dstValue = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, value, dstValue, dstOffset, 1);
+    }
+
+    rewriter.replaceOp(insertOp, dstValue);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
 /// vectors to a ShuffleOp that works on linearized vectors.
 /// Following,
@@ -289,7 +374,7 @@ struct LinearizeVectorShuffle final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -362,13 +447,18 @@ struct LinearizeVectorExtract final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtract(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Skip if the result is not a vector type.
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
+
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(extractOp,
@@ -425,7 +515,7 @@ struct LinearizeVectorInsert final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorInsert(
      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -451,12 +541,11 @@ struct LinearizeVectorInsert final
     auto srcTy = insertOp.getValueToStoreType();
     auto srcAsVec = dyn_cast<VectorType>(srcTy);
     uint64_t srcSize = 0;
-    if (srcAsVec) {
+    if (srcAsVec)
       srcSize = srcAsVec.getNumElements();
-    } else {
+    else
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalars are not supported.");
-    }
 
     auto dstShape = insertOp.getDestVectorType().getShape();
     const auto dstSize = insertOp.getDestVectorType().getNumElements();
@@ -506,7 +595,7 @@ struct LinearizeVectorBitCast final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorBitCast(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -531,12 +620,239 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+// clang-format off
+/// This pattern converts a LoadOp into a series of LoadOp & InsertOp that
+/// work on a linearized vector.
+/// For example,
+///   vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to:
+///   %result = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
+/// them into the result vector. The pattern currently supports only 2D
+/// vectors.
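+/// Note: "%indices + 1" above is shorthand for the outermost index advanced
+/// by one; the trailing indices are unchanged.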
+// clang-format on
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = loadOp->getLoc();
+    auto vecType = loadOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+
+    auto unrollCount = shape[0];
+    auto vecSize = shape[1];
+    auto newVecType = VectorType::get({vecSize}, vecType.getElementType());
+
+    llvm::SmallVector<Value, 4> indices = llvm::to_vector<4>(adaptor.getIndices());
+    Value xBaseIndex = indices[0];
+
+    // Construct the zero-initialized 2D result vector.
+    Value resultVec = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(vecType));
+    // Emit unrolled loads for each 1D vector slice.
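+    // For example, a vector<4x4xf16> load at [%i, %j] becomes four
+    // vector<4xf16> loads at [%i, %j], [%i + 1, %j], [%i + 2, %j] and
+    // [%i + 3, %j].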
+    for (auto i = 0; i < unrollCount; i++) {
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      auto vec = rewriter.create<vector::LoadOp>(loc, newVecType,
+                                                 adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts a StoreOp into a series of StoreOp & ExtractOp that
+/// work on a linearized vector.
+/// For example,
+///   vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to:
+///   %slice_0 = vector.extract %source[0] : vector<4xf32>
+///   vector.store %slice_0, %base[%indices] : vector<4xf32>
+///   %slice_1 = vector.extract %source[1] : vector<4xf32>
+///   vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+///   ...
+/// This unrolls the 2D vector store into multiple 1D vector stores by
+/// extracting slices from the source vector and storing them into the
+/// destination. The pattern currently supports only 2D vectors.
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorStore(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = storeOp->getLoc();
+    auto vecType = storeOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+
+    auto unrollCount = shape[0];
+    llvm::SmallVector<Value, 4> indices = llvm::to_vector<4>(adaptor.getIndices());
+    Value xBaseIndex = indices[0];
+
+    auto vec = rewriter.create<vector::ShapeCastOp>(loc, vecType,
+                                                    adaptor.getValueToStore());
+
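+    // Extract each row from the reshaped source and store it at the
+    // correspondingly advanced outermost index, mirroring LinearizeVectorLoad.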
+    for (auto i = 0; i < unrollCount; i++) {
+      auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+                                       indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts a SplatOp to work on a linearized vector.
+/// For example,
+///   vector.splat %value : vector<4x4xf32>
+/// is converted to:
+///   %out_1d = vector.splat %value : vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// It ensures that the operation is compatible with the target vector
+/// bit width and replaces the original operation with a new SplatOp
+/// that operates on the converted type.
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorSplat(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+                                                 dstTy);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts a CreateMaskOp to work on a linearized vector. It
+/// ensures that the operation is compatible with the target vector bit width
+/// and replaces the original operation with a new CreateMaskOp that operates
+/// on the converted type. The pattern currently supports only 2D masks with a
+/// unit outer dimension.
+/// For example,
+///   vector.create_mask %outer, %inner : vector<1x4xi1>
+/// is converted to:
+///   %out_1d = vector.create_mask %inner : vector<4xi1>
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // Compare the first operand with 0. If it is less than or equal to 0,
+    // create an all-false mask; otherwise drop the unit outer dimension and
+    // create a 1D mask from the second operand.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero =
+        rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+    auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+        createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+        zero);
+    auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+        createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+    // Use a select operation to choose between the masks.
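+    // For example, vector.create_mask %c0, %c20 : vector<1x16xi1> with
+    // %c0 == 0 selects the all-false mask, since its outer bound is zero.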
+    auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+        createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+        createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+    auto result = rewriter.create<mlir::arith::SelectOp>(
+        createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+
+    rewriter.replaceOp(createMaskOp, result.getResult());
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth) {
+  typeConverter.addConversion([](Type type) -> Type { return type; });
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
       return type;
@@ -555,9 +871,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+  target.addLegalOp<mlir::vector::ShapeCastOp>();
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
+        if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
+                 vector::SplatOp, vector::CreateMaskOp>(op) ||
             op->hasTrait<OpTrait::ConstantLike>() ||
             op->hasTrait<OpTrait::Vectorizable>())) {
          return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -567,9 +885,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     return std::nullopt;
   });
 
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+  patterns
+      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+           LinearizeVectorLoad, LinearizeVectorStore, LinearizeVectorSplat,
+           LinearizeVectorCreateMask>(typeConverter, patterns.getContext(),
+                                      targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
@@ -583,7 +903,19 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
                     .getRank() == 1)
                : true;
   });
+
+  target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+      [=](vector::InsertStridedSliceOp op) -> bool {
+        if (isLessThanTargetBitWidth(op, targetBitWidth)) {
+          auto srcTy = op.getSourceVectorType();
+          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2)
+            return false;
+        }
+        return true;
+      });
+
-  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
+  patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+               LinearizeVectorInsertStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..f7a767dbdc272 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,190 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: test_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
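+// DEFAULT and BW-128 unroll the 2D load into four 1D row loads and merge the
+// results with shuffles; BW-0 disables linearization and keeps the 2D load.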
+func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: return %[[CAST]] : vector<4x4xf16>
+  // BW-128: return %[[CAST]] : vector<4x4xf16>
+
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return %[[LOAD]] : vector<4x4xf16>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = vector.load %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_store
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>, %[[ARG_1:.*]]: vector<4x4xf16>) {
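+// DEFAULT and BW-128 unroll the 2D store into four row extracts (shuffles)
+// followed by 1D stores; BW-0 keeps the 2D store.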
+func.func @test_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: return
+  // BW-128: return
+
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: vector.store %[[ARG_1]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg1, %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
+
+// -----
+// ALL-LABEL: test_create_mask
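+// DEFAULT and BW-128 lower the 2D mask to a 1D mask that is selected against
+// an all-false constant when the outer bound is non-positive; BW-0 keeps the
+// 2D create_mask.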
+func.func @test_create_mask() -> vector<1x16xi1> {
+  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
+  // BW-128: %[[C20:.*]] = arith.constant 20 : index
+  // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0_0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // BW-128: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+  // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // DEFAULT: return %[[CAST]] : vector<1x16xi1>
+  // BW-128: return %[[CAST]] : vector<1x16xi1>
+
+  // BW-0: %[[C0:.*]] = arith.constant 0 : index
+  // BW-0: %[[C20:.*]] = arith.constant 20 : index
+  // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1>
+  %c0 = arith.constant 0 : index
+  %c20 = arith.constant 20 : index
+  %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
+
+// -----
+// ALL-LABEL: test_vector_insert_2d_idx
+// ALL-SAME: (%[[ARG:.*]]: vector<4x8xf16>) -> vector<8x16xf16>
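+// Only DEFAULT (unbounded bit width) linearizes this op into per-row shuffle
+// and 1D insert_strided_slice pairs; BW-128 and BW-0 leave the 2D op intact.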
+func.func @test_vector_insert_2d_idx(%arg0: vector<4x8xf16>) -> vector<8x16xf16> {
+  // DEFAULT: %[[V0:.*]] = vector.shape_cast %[[ARG]] : vector<4x8xf16> to vector<32xf16>
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+  // DEFAULT: %[[V1:.*]] = vector.shuffle %[[V0]], %[[V0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[CST]] {offsets = [0], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V3:.*]] = vector.shuffle %[[V0]], %[[V0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V4:.*]] = vector.insert_strided_slice %[[V3]], %[[V2]] {offsets = [16], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V5:.*]] = vector.shuffle %[[V0]], %[[V0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V6:.*]] = vector.insert_strided_slice %[[V5]], %[[V4]] {offsets = [32], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V7:.*]] = vector.shuffle %[[V0]], %[[V0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V8:.*]] = vector.insert_strided_slice %[[V7]], %[[V6]] {offsets = [48], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V9:.*]] = vector.shape_cast %[[V8]] : vector<128xf16> to vector<8x16xf16>
+  // DEFAULT: return %[[V9]] : vector<8x16xf16>
+
+  // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf16>
+  // BW-128: %[[V0:.*]] = vector.insert_strided_slice %[[ARG]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  // BW-128: return %[[V0]] : vector<8x16xf16>
+
+  // BW-0: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf16>
+  // BW-0: %[[V0:.*]] = vector.insert_strided_slice %[[ARG]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  // BW-0: return %[[V0]] : vector<8x16xf16>
+  %cst = arith.constant dense<0.0> : vector<8x16xf16>
+  %0 = vector.insert_strided_slice %arg0, %cst {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  return %0 : vector<8x16xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @test_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // BW-128: return %[[CAST]] : vector<4x2xi32>
+
+  // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+  // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+  %0 = vector.splat %arg0 : vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..aea116cffc3a8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -851,7 +851,8 @@ struct TestVectorLinearize final
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
+    registry.insert<vector::VectorDialect, memref::MemRefDialect,
+                    arith::ArithDialect>();
   }
   Option<unsigned> targetVectorBitwidth{