diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index b9fdede535112..b9cef003fa365 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -293,6 +293,10 @@ struct LinearizeVectorExtract final LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Skip if result is not a vector type + if (!isa(extractOp.getType())) + return rewriter.notifyMatchFailure(extractOp, + "scalar extract is not supported."); Type dstTy = getTypeConverter()->convertType(extractOp.getType()); assert(dstTy && "expected 1-D vector type"); @@ -415,6 +419,32 @@ struct LinearizeVectorBitCast final } }; +/// This pattern converts the SplatOp to work on a linearized vector. +/// Following, +/// vector.splat %value : vector<4x4xf32> +/// is converted to: +/// %out_1d = vector.splat %value : vector<16xf32> +/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> +struct LinearizeVectorSplat final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstTy = getTypeConverter()->convertType(splatOp.getType()); + if (!dstTy) + return rewriter.notifyMatchFailure(splatOp, "cannot convert type."); + rewriter.replaceOpWithNewOp(splatOp, adaptor.getInput(), + dstTy); + return success(); + } +}; + } // namespace /// Return true if the operation `op` does not support scalable vectors and @@ -501,7 +531,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( const TypeConverter &typeConverter, const ConversionTarget &target, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); + LinearizeVectorBitCast, LinearizeVectorSplat>( + typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 06eaf58b225ae..20169c15eb2c1 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -413,3 +413,37 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16> return %1 : vector<[4]x4xf16> } + +// ----- +// ALL-LABEL: linearize_vector_splat +// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32> +func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> { + // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32> + // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32> + // DEFAULT: return %[[CAST]] : vector<4x2xi32> + // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32> + // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32> + // BW-128: return %[[CAST]] : vector<4x2xi32> + + // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32> + // BW-0: return %[[SPLAT]] : vector<4x2xi32> + %0 = vector.splat %arg0 : vector<4x2xi32> + return %0 : vector<4x2xi32> +} + +// ----- +// ALL-LABEL: linearize_scalable_vector_splat +// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32> +func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { + // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32> + // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32> + // DEFAULT: return %[[CAST]] : vector<4x[2]xi32> + // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32> + // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32> + // BW-128: return %[[CAST]] : vector<4x[2]xi32> + + // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x[2]xi32> + // BW-0: return %[[SPLAT]] : vector<4x[2]xi32> + %0 = vector.splat %arg0 : vector<4x[2]xi32> + return %0 : vector<4x[2]xi32> +}