Skip to content

Commit 3adac49

Browse files
authored
[CIR] Backport VecSplatOp simplifier (#1704)
Backporting the VecSplatOp simplifier from the upstream
1 parent 19c36a6 commit 3adac49

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,31 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
141141
}
142142
};
143143

144+
struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
145+
using OpRewritePattern<VecSplatOp>::OpRewritePattern;
146+
LogicalResult matchAndRewrite(VecSplatOp op,
147+
PatternRewriter &rewriter) const override {
148+
mlir::Value splatValue = op.getValue();
149+
auto constant =
150+
mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp());
151+
if (!constant)
152+
return mlir::failure();
153+
154+
auto value = constant.getValue();
155+
if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
156+
!mlir::isa_and_nonnull<cir::FPAttr>(value))
157+
return mlir::failure();
158+
159+
cir::VectorType resultType = op.getResult().getType();
160+
SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
161+
auto constVecAttr = cir::ConstVectorAttr::get(
162+
resultType, mlir::ArrayAttr::get(getContext(), elements));
163+
164+
rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
165+
return mlir::success();
166+
}
167+
};
168+
144169
//===----------------------------------------------------------------------===//
145170
// CIRSimplifyPass
146171
//===----------------------------------------------------------------------===//
@@ -155,7 +180,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
155180
// clang-format off
156181
patterns.add<
157182
SimplifyTernary,
158-
SimplifySelect
183+
SimplifySelect,
184+
SimplifyVecSplat
159185
>(patterns.getContext());
160186
// clang-format on
161187
}
@@ -168,7 +194,7 @@ void CIRSimplifyPass::runOnOperation() {
168194
// Collect operations to apply patterns.
169195
llvm::SmallVector<Operation *, 16> ops;
170196
getOperation()->walk([&](Operation *op) {
171-
if (isa<TernaryOp, SelectOp>(op))
197+
if (isa<TernaryOp, SelectOp, VecSplatOp>(op))
172198
ops.push_back(op);
173199
});
174200

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: cir-opt %s -cir-simplify -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @fold_splat_vector_op_test() -> !cir.vector<!s32i x 4> {
7+
%v = cir.const #cir.int<3> : !s32i
8+
%vec = cir.vec.splat %v : !s32i, !cir.vector<!s32i x 4>
9+
cir.return %vec : !cir.vector<!s32i x 4>
10+
}
11+
12+
// CHECK: cir.func @fold_splat_vector_op_test() -> !cir.vector<!s32i x 4> {
13+
// CHECK-NEXT: %0 = cir.const #cir.const_vector<[#cir.int<3> : !s32i, #cir.int<3> : !s32i,
14+
// CHECK-SAME: #cir.int<3> : !s32i, #cir.int<3> : !s32i]> : !cir.vector<!s32i x 4>
15+
// CHECK-NEXT: cir.return %0 : !cir.vector<!s32i x 4>
16+
}

0 commit comments

Comments
 (0)