-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][vector] Avoid use of vector.splat in transforms #150279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Avoid use of vector.splat in transforms #150279
Conversation
@llvm/pr-subscribers-mlir Author: James Newling (newling) ChangesThis is part of vector.splat deprecation Instead of creating vector::SplatOp, create vector::BroadcastOp Full diff: https://github.com/llvm/llvm-project/pull/150279.diff 6 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index cb8e566869cfd..0a48eab3e6666 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -28,7 +28,10 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Progressive lowering of BroadcastOp.
+
+/// Convert a vector.broadcast without a scalar operand to a lower rank
+/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
+/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,20 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();
- // Scalar to any vector can use splat.
- if (!srcType) {
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
- return success();
- }
+ // A broadcast from a scalar is considered to be in the lowered form.
+ if (!srcType)
+ return failure();
// Determine rank of source and destination.
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
- // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
- Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+ SmallVector<int64_t> fullRankPosition(srcRank, 0);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
+ fullRankPosition);
+ assert(!isa<VectorType>(ext.getType()) && "expected scalar");
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index e9109322ed3d8..7122c53e49780 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
read, "vector type is not rank 1, can't create masked load, needs "
"VectorToSCF");
- Value fill = vector::SplatOp::create(
+ Value fill = vector::BroadcastOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
res = vector::MaskedLoadOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72352d72bfe77..cbb9d4bbf0b1f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
// Extract/insert on a lower ranked extract strided slice op.
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, dstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 73ca327bb49c5..abcc49144187e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -940,7 +940,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, castDstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
SmallVector<int64_t> sliceShape = {castDstLastDim};
SmallVector<int64_t> strides = {1};
@@ -966,6 +966,23 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
std::function<bool(BitCastOp)> controlFn;
};
+/// If `value` is the result of a splat or broadcast operation, return the input
+/// of the splat/broadcast operation.
+static Value getBroadcastLikeSource(Value value) {
+
+ Operation *op = value.getDefiningOp();
+ if (!op)
+ return {};
+
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+ return broadcast.getSource();
+
+ if (auto splat = dyn_cast<vector::SplatOp>(op))
+ return splat.getInput();
+
+ return {};
+}
+
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
@@ -1007,26 +1024,23 @@ struct ReorderElementwiseOpsOnBroadcast final
}
// Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
- return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Value lhsSource = getBroadcastLikeSource(op->getOperand(0));
+ if (!lhsSource)
+ return rewriter.notifyMatchFailure(
+ op, "operand #0 not the result of a broadcast");
+ Type lhsBcastOrSplatType = lhsSource.getType();
// Make sure that all operands are broadcast from identical types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
+ if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
+ if (auto source = getBroadcastLikeSource(val))
+ return source.getType() == lhsBcastOrSplatType;
return false;
})) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "not all operands are broadcasts from the sametype");
}
// Collect the source values before broadcasting
@@ -1240,15 +1254,17 @@ class StoreOpFromSplatOrBroadcast final
return rewriter.notifyMatchFailure(
op, "only 1-element vectors are supported");
- Operation *splat = op.getValueToStore().getDefiningOp();
- if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
- return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+ Value toStore = op.getValueToStore();
+ Value source = getBroadcastLikeSource(toStore);
+ if (!source)
+ return rewriter.notifyMatchFailure(
+ op, "value to store is not from a broadcast");
// Checking for single use so we can remove splat.
+ Operation *splat = toStore.getDefiningOp();
if (!splat->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
- Value source = splat->getOperand(0);
Value base = op.getBase();
ValueRange indices = op.getIndices();
@@ -1298,13 +1314,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// Add in an offset if requested.
if (off) {
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
- Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+ Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
indices = arith::AddIOp::create(rewriter, loc, ov, indices);
}
// Construct the vector comparison.
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
- vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+ vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
indices, bounds);
}
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 8e167a520260f..d5e344393b217 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>
func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>
func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>
func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
// CHECK: return %[[T1]] : vector<4xf32>
func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
+// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
+// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK: return %[[T15]] : vector<4x3xf32>
diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
index 059d955f77313..5a8125ed67173 100644
--- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -5,11 +5,11 @@
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
+// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
|
@llvm/pr-subscribers-mlir-vector Author: James Newling (newling) ChangesThis is part of vector.splat deprecation Instead of creating vector::SplatOp, create vector::BroadcastOp Full diff: https://github.com/llvm/llvm-project/pull/150279.diff 6 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index cb8e566869cfd..0a48eab3e6666 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -28,7 +28,10 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Progressive lowering of BroadcastOp.
+
+/// Convert a vector.broadcast without a scalar operand to a lower rank
+/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
+/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,20 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();
- // Scalar to any vector can use splat.
- if (!srcType) {
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
- return success();
- }
+ // A broadcast from a scalar is considered to be in the lowered form.
+ if (!srcType)
+ return failure();
// Determine rank of source and destination.
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
- // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
- Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+ SmallVector<int64_t> fullRankPosition(srcRank, 0);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
+ fullRankPosition);
+ assert(!isa<VectorType>(ext.getType()) && "expected scalar");
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index e9109322ed3d8..7122c53e49780 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
read, "vector type is not rank 1, can't create masked load, needs "
"VectorToSCF");
- Value fill = vector::SplatOp::create(
+ Value fill = vector::BroadcastOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
res = vector::MaskedLoadOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72352d72bfe77..cbb9d4bbf0b1f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
// Extract/insert on a lower ranked extract strided slice op.
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, dstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 73ca327bb49c5..abcc49144187e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -940,7 +940,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, castDstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
SmallVector<int64_t> sliceShape = {castDstLastDim};
SmallVector<int64_t> strides = {1};
@@ -966,6 +966,23 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
std::function<bool(BitCastOp)> controlFn;
};
+/// If `value` is the result of a splat or broadcast operation, return the input
+/// of the splat/broadcast operation.
+static Value getBroadcastLikeSource(Value value) {
+
+ Operation *op = value.getDefiningOp();
+ if (!op)
+ return {};
+
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+ return broadcast.getSource();
+
+ if (auto splat = dyn_cast<vector::SplatOp>(op))
+ return splat.getInput();
+
+ return {};
+}
+
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
@@ -1007,26 +1024,23 @@ struct ReorderElementwiseOpsOnBroadcast final
}
// Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
- return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Value lhsSource = getBroadcastLikeSource(op->getOperand(0));
+ if (!lhsSource)
+ return rewriter.notifyMatchFailure(
+ op, "operand #0 not the result of a broadcast");
+ Type lhsBcastOrSplatType = lhsSource.getType();
// Make sure that all operands are broadcast from identical types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
+ if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
+ if (auto source = getBroadcastLikeSource(val))
+ return source.getType() == lhsBcastOrSplatType;
return false;
})) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "not all operands are broadcasts from the sametype");
}
// Collect the source values before broadcasting
@@ -1240,15 +1254,17 @@ class StoreOpFromSplatOrBroadcast final
return rewriter.notifyMatchFailure(
op, "only 1-element vectors are supported");
- Operation *splat = op.getValueToStore().getDefiningOp();
- if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
- return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+ Value toStore = op.getValueToStore();
+ Value source = getBroadcastLikeSource(toStore);
+ if (!source)
+ return rewriter.notifyMatchFailure(
+ op, "value to store is not from a broadcast");
// Checking for single use so we can remove splat.
+ Operation *splat = toStore.getDefiningOp();
if (!splat->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
- Value source = splat->getOperand(0);
Value base = op.getBase();
ValueRange indices = op.getIndices();
@@ -1298,13 +1314,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// Add in an offset if requested.
if (off) {
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
- Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+ Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
indices = arith::AddIOp::create(rewriter, loc, ov, indices);
}
// Construct the vector comparison.
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
- vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+ vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
indices, bounds);
}
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 8e167a520260f..d5e344393b217 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>
func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>
func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>
func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
// CHECK: return %[[T1]] : vector<4xf32>
func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
+// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
+// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK: return %[[T15]] : vector<4x3xf32>
diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
index 059d955f77313..5a8125ed67173 100644
--- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -5,11 +5,11 @@
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
+// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Minor comments
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource()); | ||
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext); | ||
SmallVector<int64_t> fullRankPosition(srcRank, 0); | ||
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could we use rewriter.create<...>
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aren't we moving to dialect::Op::create
instead? #147168
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so. It'll take some getting used to..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG % nits
It's really great to see this progressing so quickly, thanks!
10979e2
to
3d1e82b
Compare
Thanks for your thorough and useful reviewing, as usual :-D |
This is part of vector.splat deprecation
Reference: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/5
Instead of creating vector::SplatOp, create vector::BroadcastOp