Skip to content

Commit 02fae68

Browse files
authored
[mlir][vector] VectorLinearize: ub.poison support (#128612)
Unify `arith.constant` and `up.poison` using `OpTraitConversionPattern<OpTrait::ConstantLike>`.
1 parent 95d28fe commit 02fae68

File tree

2 files changed

+72
-23
lines changed

2 files changed

+72
-23
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Dialect/Arith/IR/Arith.h"
14+
#include "mlir/Dialect/UB/IR/UBOps.h"
1415
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1516
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1617
#include "mlir/IR/Attributes.h"
@@ -56,40 +57,71 @@ static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
5657
return trailingVecDimBitWidth <= targetBitWidth;
5758
}
5859

60+
static FailureOr<Attribute>
61+
linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
62+
VectorType resType, Attribute value) {
63+
if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
64+
if (resType.isScalable() && !isa<SplatElementsAttr>(value))
65+
return rewriter.notifyMatchFailure(
66+
loc,
67+
"Cannot linearize a constant scalable vector that's not a splat");
68+
69+
return dstElementsAttr.reshape(resType);
70+
}
71+
72+
if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
73+
return poisonAttr;
74+
75+
return rewriter.notifyMatchFailure(loc, "unsupported attr type");
76+
}
77+
5978
namespace {
60-
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
61-
using OpConversionPattern::OpConversionPattern;
62-
LinearizeConstant(
79+
struct LinearizeConstantLike final
80+
: OpTraitConversionPattern<OpTrait::ConstantLike> {
81+
using OpTraitConversionPattern::OpTraitConversionPattern;
82+
83+
LinearizeConstantLike(
6384
const TypeConverter &typeConverter, MLIRContext *context,
6485
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
6586
PatternBenefit benefit = 1)
66-
: OpConversionPattern(typeConverter, context, benefit),
87+
: OpTraitConversionPattern(typeConverter, context, benefit),
6788
targetVectorBitWidth(targetVectBitWidth) {}
6889
LogicalResult
69-
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
90+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
7091
ConversionPatternRewriter &rewriter) const override {
71-
Location loc = constOp.getLoc();
92+
Location loc = op->getLoc();
93+
if (op->getNumResults() != 1)
94+
return rewriter.notifyMatchFailure(loc, "expected 1 result");
95+
96+
const TypeConverter &converter = *getTypeConverter();
7297
auto resType =
73-
getTypeConverter()->convertType<VectorType>(constOp.getType());
98+
converter.convertType<VectorType>(op->getResult(0).getType());
7499

75100
if (!resType)
76101
return rewriter.notifyMatchFailure(loc, "can't convert return type");
77102

78-
if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
79-
return rewriter.notifyMatchFailure(
80-
loc,
81-
"Cannot linearize a constant scalable vector that's not a splat");
82-
83-
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
103+
if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
84104
return rewriter.notifyMatchFailure(
85105
loc, "Can't flatten since targetBitWidth <= OpSize");
86-
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
87-
if (!dstElementsAttr)
88-
return rewriter.notifyMatchFailure(loc, "unsupported attr type");
89106

90-
dstElementsAttr = dstElementsAttr.reshape(resType);
91-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
92-
dstElementsAttr);
107+
StringAttr attrName = rewriter.getStringAttr("value");
108+
Attribute value = op->getAttr(attrName);
109+
if (!value)
110+
return rewriter.notifyMatchFailure(loc, "no 'value' attr");
111+
112+
FailureOr<Attribute> newValue =
113+
linearizeConstAttr(loc, rewriter, resType, value);
114+
if (failed(newValue))
115+
return failure();
116+
117+
FailureOr<Operation *> convertResult =
118+
convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
119+
if (failed(convertResult))
120+
return failure();
121+
122+
Operation *newOp = *convertResult;
123+
newOp->setAttr(attrName, *newValue);
124+
rewriter.replaceOp(op, newOp);
93125
return success();
94126
}
95127

@@ -525,7 +557,8 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
525557
typeConverter.addTargetMaterialization(materializeCast);
526558
target.markUnknownOpDynamicallyLegal(
527559
[=](Operation *op) -> std::optional<bool> {
528-
if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
560+
if ((isa<vector::BitCastOp>(op) ||
561+
op->hasTrait<OpTrait::ConstantLike>() ||
529562
op->hasTrait<OpTrait::Vectorizable>())) {
530563
return (isLessThanTargetBitWidth(op, targetBitWidth)
531564
? typeConverter.isLegal(op)
@@ -534,9 +567,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
534567
return std::nullopt;
535568
});
536569

537-
patterns
538-
.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
539-
typeConverter, patterns.getContext(), targetBitWidth);
570+
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
571+
LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
572+
targetBitWidth);
540573
}
541574

542575
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
3232

3333
// -----
3434

35+
// ALL-LABEL: test_linearize_poison
36+
func.func @test_linearize_poison() -> vector<2x2xf32> {
37+
// DEFAULT: %[[POISON:.*]] = ub.poison : vector<4xf32>
38+
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
39+
40+
// BW-128: %[[POISON:.*]] = ub.poison : vector<4xf32>
41+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
42+
43+
// BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
44+
%0 = ub.poison : vector<2x2xf32>
45+
// ALL: return %[[RES]] : vector<2x2xf32>
46+
return %0 : vector<2x2xf32>
47+
}
48+
49+
// -----
50+
3551
// ALL-LABEL: test_partial_linearize
3652
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
3753
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {

0 commit comments

Comments
 (0)