@@ -141,6 +141,31 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
141141 }
142142};
143143
144+ struct SimplifyVecSplat : public OpRewritePattern <VecSplatOp> {
145+ using OpRewritePattern<VecSplatOp>::OpRewritePattern;
146+ LogicalResult matchAndRewrite (VecSplatOp op,
147+ PatternRewriter &rewriter) const override {
148+ mlir::Value splatValue = op.getValue ();
149+ auto constant =
150+ mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp ());
151+ if (!constant)
152+ return mlir::failure ();
153+
154+ auto value = constant.getValue ();
155+ if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
156+ !mlir::isa_and_nonnull<cir::FPAttr>(value))
157+ return mlir::failure ();
158+
159+ cir::VectorType resultType = op.getResult ().getType ();
160+ SmallVector<mlir::Attribute, 16 > elements (resultType.getSize (), value);
161+ auto constVecAttr = cir::ConstVectorAttr::get (
162+ resultType, mlir::ArrayAttr::get (getContext (), elements));
163+
164+ rewriter.replaceOpWithNewOp <cir::ConstantOp>(op, constVecAttr);
165+ return mlir::success ();
166+ }
167+ };
168+
144169// ===----------------------------------------------------------------------===//
145170// CIRSimplifyPass
146171// ===----------------------------------------------------------------------===//
@@ -155,7 +180,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
155180 // clang-format off
156181 patterns.add <
157182 SimplifyTernary,
158- SimplifySelect
183+ SimplifySelect,
184+ SimplifyVecSplat
159185 >(patterns.getContext ());
160186 // clang-format on
161187}
@@ -168,7 +194,7 @@ void CIRSimplifyPass::runOnOperation() {
168194 // Collect operations to apply patterns.
169195 llvm::SmallVector<Operation *, 16 > ops;
170196 getOperation ()->walk ([&](Operation *op) {
171- if (isa<TernaryOp, SelectOp>(op))
197+ if (isa<TernaryOp, SelectOp, VecSplatOp >(op))
172198 ops.push_back (op);
173199 });
174200
0 commit comments