11
11
// ===----------------------------------------------------------------------===//
12
12
13
13
#include " mlir/Dialect/Arith/IR/Arith.h"
14
+ #include " mlir/Dialect/UB/IR/UBOps.h"
14
15
#include " mlir/Dialect/Vector/IR/VectorOps.h"
15
16
#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16
17
#include " mlir/IR/Attributes.h"
@@ -56,40 +57,71 @@ static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
56
57
return trailingVecDimBitWidth <= targetBitWidth;
57
58
}
58
59
60
+ static FailureOr<Attribute>
61
+ linearizeConstAttr (Location loc, ConversionPatternRewriter &rewriter,
62
+ VectorType resType, Attribute value) {
63
+ if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
64
+ if (resType.isScalable () && !isa<SplatElementsAttr>(value))
65
+ return rewriter.notifyMatchFailure (
66
+ loc,
67
+ " Cannot linearize a constant scalable vector that's not a splat" );
68
+
69
+ return dstElementsAttr.reshape (resType);
70
+ }
71
+
72
+ if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
73
+ return poisonAttr;
74
+
75
+ return rewriter.notifyMatchFailure (loc, " unsupported attr type" );
76
+ }
77
+
59
78
namespace {
60
- struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
61
- using OpConversionPattern::OpConversionPattern;
62
- LinearizeConstant (
79
+ struct LinearizeConstantLike final
80
+ : OpTraitConversionPattern<OpTrait::ConstantLike> {
81
+ using OpTraitConversionPattern::OpTraitConversionPattern;
82
+
83
+ LinearizeConstantLike (
63
84
const TypeConverter &typeConverter, MLIRContext *context,
64
85
unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
65
86
PatternBenefit benefit = 1 )
66
- : OpConversionPattern (typeConverter, context, benefit),
87
+ : OpTraitConversionPattern (typeConverter, context, benefit),
67
88
targetVectorBitWidth (targetVectBitWidth) {}
68
89
LogicalResult
69
- matchAndRewrite (arith::ConstantOp constOp, OpAdaptor adaptor ,
90
+ matchAndRewrite (Operation *op, ArrayRef<Value> operands ,
70
91
ConversionPatternRewriter &rewriter) const override {
71
- Location loc = constOp.getLoc ();
92
+ Location loc = op->getLoc ();
93
+ if (op->getNumResults () != 1 )
94
+ return rewriter.notifyMatchFailure (loc, " expected 1 result" );
95
+
96
+ const TypeConverter &converter = *getTypeConverter ();
72
97
auto resType =
73
- getTypeConverter ()-> convertType <VectorType>(constOp .getType ());
98
+ converter. convertType <VectorType>(op-> getResult ( 0 ) .getType ());
74
99
75
100
if (!resType)
76
101
return rewriter.notifyMatchFailure (loc, " can't convert return type" );
77
102
78
- if (resType.isScalable () && !isa<SplatElementsAttr>(constOp.getValue ()))
79
- return rewriter.notifyMatchFailure (
80
- loc,
81
- " Cannot linearize a constant scalable vector that's not a splat" );
82
-
83
- if (!isLessThanTargetBitWidth (constOp, targetVectorBitWidth))
103
+ if (!isLessThanTargetBitWidth (op, targetVectorBitWidth))
84
104
return rewriter.notifyMatchFailure (
85
105
loc, " Can't flatten since targetBitWidth <= OpSize" );
86
- auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue ());
87
- if (!dstElementsAttr)
88
- return rewriter.notifyMatchFailure (loc, " unsupported attr type" );
89
106
90
- dstElementsAttr = dstElementsAttr.reshape (resType);
91
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(constOp, resType,
92
- dstElementsAttr);
107
+ StringAttr attrName = rewriter.getStringAttr (" value" );
108
+ Attribute value = op->getAttr (attrName);
109
+ if (!value)
110
+ return rewriter.notifyMatchFailure (loc, " no 'value' attr" );
111
+
112
+ FailureOr<Attribute> newValue =
113
+ linearizeConstAttr (loc, rewriter, resType, value);
114
+ if (failed (newValue))
115
+ return failure ();
116
+
117
+ FailureOr<Operation *> convertResult =
118
+ convertOpResultTypes (op, /* operands=*/ {}, converter, rewriter);
119
+ if (failed (convertResult))
120
+ return failure ();
121
+
122
+ Operation *newOp = *convertResult;
123
+ newOp->setAttr (attrName, *newValue);
124
+ rewriter.replaceOp (op, newOp);
93
125
return success ();
94
126
}
95
127
@@ -525,7 +557,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
525
557
typeConverter.addTargetMaterialization (materializeCast);
526
558
target.markUnknownOpDynamicallyLegal (
527
559
[=](Operation *op) -> std::optional<bool > {
528
- if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
560
+ if ((isa<vector::BitCastOp>(op) ||
561
+ op->hasTrait <OpTrait::ConstantLike>() ||
529
562
op->hasTrait <OpTrait::Vectorizable>())) {
530
563
return (isLessThanTargetBitWidth (op, targetBitWidth)
531
564
? typeConverter.isLegal (op)
@@ -534,9 +567,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
534
567
return std::nullopt;
535
568
});
536
569
537
- patterns
538
- . add <LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
539
- typeConverter, patterns. getContext (), targetBitWidth);
570
+ patterns. add <LinearizeConstantLike, LinearizeVectorizable,
571
+ LinearizeVectorBitCast>(typeConverter, patterns. getContext (),
572
+ targetBitWidth);
540
573
}
541
574
542
575
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments