Skip to content

Conversation

newling
Copy link
Contributor

@newling newling commented Jul 23, 2025

This is part of vector.splat deprecation
Reference: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/5

Instead of creating vector::SplatOp, create vector::BroadcastOp

@newling newling changed the title [mlir][vector] Towards deprecating vector.splat [mlir][vector] Deprecate use of vector.splat in transforms Jul 23, 2025
@newling newling changed the title [mlir][vector] Deprecate use of vector.splat in transforms [mlir][vector] Avoid use of vector.splat in transforms Jul 23, 2025
@newling newling marked this pull request as ready for review July 23, 2025 18:54
@llvmbot
Copy link
Member

llvmbot commented Jul 23, 2025

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

This is part of vector.splat deprecation
Reference: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/5

Instead of creating vector::SplatOp, create vector::BroadcastOp


Full diff: https://github.com/llvm/llvm-project/pull/150279.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+12-9)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+36-20)
  • (modified) mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir (+12-12)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index cb8e566869cfd..0a48eab3e6666 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -28,7 +28,10 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Progressive lowering of BroadcastOp.
+
+/// Convert a vector.broadcast without a scalar operand to a lower rank
+/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
+/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,20 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
     Type eltType = dstType.getElementType();
 
-    // Scalar to any vector can use splat.
-    if (!srcType) {
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
-      return success();
-    }
+    // A broadcast from a scalar is considered to be in the lowered form.
+    if (!srcType)
+      return failure();
 
     // Determine rank of source and destination.
     int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
 
-    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
     if (srcRank <= 1 && dstRank == 1) {
-      Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+      SmallVector<int64_t> fullRankPosition(srcRank, 0);
+      Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
+                                            fullRankPosition);
+      assert(!isa<VectorType>(ext.getType()) && "expected scalar");
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index e9109322ed3d8..7122c53e49780 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
             read, "vector type is not rank 1, can't create masked load, needs "
                   "VectorToSCF");
 
-      Value fill = vector::SplatOp::create(
+      Value fill = vector::BroadcastOp::create(
           rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
       res = vector::MaskedLoadOp::create(
           rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72352d72bfe77..cbb9d4bbf0b1f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
     // Extract/insert on a lower ranked extract strided slice op.
     Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                            rewriter.getZeroAttr(elemType));
-    Value res = SplatOp::create(rewriter, loc, dstType, zero);
+    Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
          off += stride, ++idx) {
       Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 73ca327bb49c5..abcc49144187e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -940,7 +940,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 
     Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                            rewriter.getZeroAttr(elemType));
-    Value res = SplatOp::create(rewriter, loc, castDstType, zero);
+    Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
 
     SmallVector<int64_t> sliceShape = {castDstLastDim};
     SmallVector<int64_t> strides = {1};
@@ -966,6 +966,23 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
+/// If `value` is the result of a splat or broadcast operation, return the input
+/// of the splat/broadcast operation.
+static Value getBroadcastLikeSource(Value value) {
+
+  Operation *op = value.getDefiningOp();
+  if (!op)
+    return {};
+
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+    return broadcast.getSource();
+
+  if (auto splat = dyn_cast<vector::SplatOp>(op))
+    return splat.getInput();
+
+  return {};
+}
+
 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
 ///
 /// Example:
@@ -1007,26 +1024,23 @@ struct ReorderElementwiseOpsOnBroadcast final
     }
 
     // Get the type of the lhs operand
-    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
-    if (!lhsBcastOrSplat ||
-        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
-      return failure();
-    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+    Value lhsSource = getBroadcastLikeSource(op->getOperand(0));
+    if (!lhsSource)
+      return rewriter.notifyMatchFailure(
+          op, "operand #0 not the result of a broadcast");
+    Type lhsBcastOrSplatType = lhsSource.getType();
 
     // Make sure that all operands are broadcast from identical types:
     //  * scalar (`vector.broadcast` + `vector.splat`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
-          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
-          if (bcast)
-            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
-          auto splat = val.getDefiningOp<vector::SplatOp>();
-          if (splat)
-            return (splat.getOperand().getType() == lhsBcastOrSplatType);
+    if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
+          if (auto source = getBroadcastLikeSource(val))
+            return source.getType() == lhsBcastOrSplatType;
           return false;
         })) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "not all operands are broadcasts from the sametype");
     }
 
     // Collect the source values before broadcasting
@@ -1240,15 +1254,17 @@ class StoreOpFromSplatOrBroadcast final
       return rewriter.notifyMatchFailure(
           op, "only 1-element vectors are supported");
 
-    Operation *splat = op.getValueToStore().getDefiningOp();
-    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
-      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+    Value toStore = op.getValueToStore();
+    Value source = getBroadcastLikeSource(toStore);
+    if (!source)
+      return rewriter.notifyMatchFailure(
+          op, "value to store is not from a broadcast");
 
     // Checking for single use so we can remove splat.
+    Operation *splat = toStore.getDefiningOp();
     if (!splat->hasOneUse())
       return rewriter.notifyMatchFailure(op, "expected single op use");
 
-    Value source = splat->getOperand(0);
     Value base = op.getBase();
     ValueRange indices = op.getIndices();
 
@@ -1298,13 +1314,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   // Add in an offset if requested.
   if (off) {
     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
-    Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+    Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
     indices = arith::AddIOp::create(rewriter, loc, ov, indices);
   }
   // Construct the vector comparison.
   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
   Value bounds =
-      vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+      vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
   return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
                                indices, bounds);
 }
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 8e167a520260f..d5e344393b217 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func @broadcast_vec1d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
 // CHECK:      return %[[T0]] : vector<2xf32>
 
 func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
 
 // CHECK-LABEL: func @broadcast_vec2d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
 // CHECK:      return %[[T0]] : vector<2x3xf32>
 
 func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
 
 // CHECK-LABEL: func @broadcast_vec3d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
 // CHECK:      return %[[T0]] : vector<2x3x4xf32>
 
 func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
 // CHECK-LABEL: func @broadcast_stretch
 // CHECK-SAME: %[[A:.*0]]: vector<1xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
 // CHECK:      return %[[T1]] : vector<4xf32>
 
 func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
 // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
 // CHECK:      %[[U0:.*]] = ub.poison : vector<4x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
+// CHECK:      %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
 // CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
+// CHECK:      %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
 // CHECK:      %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      return %[[T15]] : vector<4x3xf32>
 
diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
index 059d955f77313..5a8125ed67173 100644
--- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -5,11 +5,11 @@
 // CHECK-SAME: %[[B:.*1]]: vector<3xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK:      %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
 // CHECK:      %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
 // CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
 // CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK:      %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
 // CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<3xi32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
 // CHECK:      %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
+// CHECK:      %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
 // CHECK:      %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
 // CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
 // CHECK:      %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
 // CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK:      %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
+// CHECK:      %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
 // CHECK:      %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
 // CHECK-LABEL: func @axpy_fp(
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
 // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
 // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
 // CHECK-LABEL: func @axpy_int(
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: return %[[T1]] : vector<16xi32>
 func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
 // CHECK: return %[[T2]] : vector<16xi32>

@llvmbot
Copy link
Member

llvmbot commented Jul 23, 2025

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

This is part of vector.splat deprecation
Reference: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/5

Instead of creating vector::SplatOp, create vector::BroadcastOp


Full diff: https://github.com/llvm/llvm-project/pull/150279.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+12-9)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+36-20)
  • (modified) mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir (+12-12)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index cb8e566869cfd..0a48eab3e6666 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -28,7 +28,10 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
-/// Progressive lowering of BroadcastOp.
+
+/// Convert a vector.broadcast without a scalar operand to a lower rank
+/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
+/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,20 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
     Type eltType = dstType.getElementType();
 
-    // Scalar to any vector can use splat.
-    if (!srcType) {
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
-      return success();
-    }
+    // A broadcast from a scalar is considered to be in the lowered form.
+    if (!srcType)
+      return failure();
 
     // Determine rank of source and destination.
     int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
 
-    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
     if (srcRank <= 1 && dstRank == 1) {
-      Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
-      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+      SmallVector<int64_t> fullRankPosition(srcRank, 0);
+      Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
+                                            fullRankPosition);
+      assert(!isa<VectorType>(ext.getType()) && "expected scalar");
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index e9109322ed3d8..7122c53e49780 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
             read, "vector type is not rank 1, can't create masked load, needs "
                   "VectorToSCF");
 
-      Value fill = vector::SplatOp::create(
+      Value fill = vector::BroadcastOp::create(
           rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
       res = vector::MaskedLoadOp::create(
           rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72352d72bfe77..cbb9d4bbf0b1f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -303,7 +303,7 @@ class DecomposeNDExtractStridedSlice
     // Extract/insert on a lower ranked extract strided slice op.
     Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                            rewriter.getZeroAttr(elemType));
-    Value res = SplatOp::create(rewriter, loc, dstType, zero);
+    Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
          off += stride, ++idx) {
       Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 73ca327bb49c5..abcc49144187e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -940,7 +940,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 
     Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                            rewriter.getZeroAttr(elemType));
-    Value res = SplatOp::create(rewriter, loc, castDstType, zero);
+    Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
 
     SmallVector<int64_t> sliceShape = {castDstLastDim};
     SmallVector<int64_t> strides = {1};
@@ -966,6 +966,23 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
+/// If `value` is the result of a splat or broadcast operation, return the input
+/// of the splat/broadcast operation.
+static Value getBroadcastLikeSource(Value value) {
+
+  Operation *op = value.getDefiningOp();
+  if (!op)
+    return {};
+
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+    return broadcast.getSource();
+
+  if (auto splat = dyn_cast<vector::SplatOp>(op))
+    return splat.getInput();
+
+  return {};
+}
+
 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
 ///
 /// Example:
@@ -1007,26 +1024,23 @@ struct ReorderElementwiseOpsOnBroadcast final
     }
 
     // Get the type of the lhs operand
-    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
-    if (!lhsBcastOrSplat ||
-        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
-      return failure();
-    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+    Value lhsSource = getBroadcastLikeSource(op->getOperand(0));
+    if (!lhsSource)
+      return rewriter.notifyMatchFailure(
+          op, "operand #0 not the result of a broadcast");
+    Type lhsBcastOrSplatType = lhsSource.getType();
 
     // Make sure that all operands are broadcast from identical types:
     //  * scalar (`vector.broadcast` + `vector.splat`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
-          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
-          if (bcast)
-            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
-          auto splat = val.getDefiningOp<vector::SplatOp>();
-          if (splat)
-            return (splat.getOperand().getType() == lhsBcastOrSplatType);
+    if (!llvm::all_of(op->getOperands(), [lhsBcastOrSplatType](Value val) {
+          if (auto source = getBroadcastLikeSource(val))
+            return source.getType() == lhsBcastOrSplatType;
           return false;
         })) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "not all operands are broadcasts from the sametype");
     }
 
     // Collect the source values before broadcasting
@@ -1240,15 +1254,17 @@ class StoreOpFromSplatOrBroadcast final
       return rewriter.notifyMatchFailure(
           op, "only 1-element vectors are supported");
 
-    Operation *splat = op.getValueToStore().getDefiningOp();
-    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
-      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+    Value toStore = op.getValueToStore();
+    Value source = getBroadcastLikeSource(toStore);
+    if (!source)
+      return rewriter.notifyMatchFailure(
+          op, "value to store is not from a broadcast");
 
     // Checking for single use so we can remove splat.
+    Operation *splat = toStore.getDefiningOp();
     if (!splat->hasOneUse())
       return rewriter.notifyMatchFailure(op, "expected single op use");
 
-    Value source = splat->getOperand(0);
     Value base = op.getBase();
     ValueRange indices = op.getIndices();
 
@@ -1298,13 +1314,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   // Add in an offset if requested.
   if (off) {
     Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
-    Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+    Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
     indices = arith::AddIOp::create(rewriter, loc, ov, indices);
   }
   // Construct the vector comparison.
   Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
   Value bounds =
-      vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+      vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
   return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
                                indices, bounds);
 }
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 8e167a520260f..d5e344393b217 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func @broadcast_vec1d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
 // CHECK:      return %[[T0]] : vector<2xf32>
 
 func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
 
 // CHECK-LABEL: func @broadcast_vec2d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
 // CHECK:      return %[[T0]] : vector<2x3xf32>
 
 func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
 
 // CHECK-LABEL: func @broadcast_vec3d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK:      %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
+// CHECK:      %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
 // CHECK:      return %[[T0]] : vector<2x3x4xf32>
 
 func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
 // CHECK-LABEL: func @broadcast_stretch
 // CHECK-SAME: %[[A:.*0]]: vector<1xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
 // CHECK:      return %[[T1]] : vector<4xf32>
 
 func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
 // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
 // CHECK:      %[[U0:.*]] = ub.poison : vector<4x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
+// CHECK:      %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
 // CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
-// CHECK:      %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
+// CHECK:      %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
 // CHECK:      %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
 // CHECK:      return %[[T15]] : vector<4x3xf32>
 
diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
index 059d955f77313..5a8125ed67173 100644
--- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -5,11 +5,11 @@
 // CHECK-SAME: %[[B:.*1]]: vector<3xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK:      %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
 // CHECK:      %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
 // CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
 // CHECK:      %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
 // CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK:      %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
+// CHECK:      %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
 // CHECK:      %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
 // CHECK:      %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<3xi32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
 // CHECK:      %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK:      %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
+// CHECK:      %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
 // CHECK:      %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
 // CHECK:      %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK:      %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
 // CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
 // CHECK:      %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
 // CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK:      %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
+// CHECK:      %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
 // CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
 // CHECK:      %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
 // CHECK:      %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
 // CHECK-LABEL: func @axpy_fp(
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
 // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
 // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
 // CHECK-LABEL: func @axpy_int(
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: return %[[T1]] : vector<16xi32>
 func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
 // CHECK: return %[[T2]] : vector<16xi32>

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Minor comments

Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
SmallVector<int64_t> fullRankPosition(srcRank, 0);
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we use rewriter.create<...> here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't we moving to dialect::Op::create instead? #147168

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. It'll take some getting used to..

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG % nits

It's really great to see this progressing so quickly, thanks!

@newling newling force-pushed the deprecate-vector-splat-vector-transforms branch from 10979e2 to 3d1e82b Compare July 31, 2025 00:16
@newling
Copy link
Contributor Author

newling commented Jul 31, 2025

LG % nits

It's really great to see this progressing so quickly, thanks!

Thanks for your thorough and useful reviewing, as usual :-D

@newling newling merged commit 671eaf8 into llvm:main Jul 31, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants