-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] VectorLinearize: ub.poison
support
#128612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesFull diff: https://github.com/llvm/llvm-project/pull/128612.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3ecd585c5a26d..65bd982319e45 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
@@ -97,6 +98,35 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
unsigned targetVectorBitWidth;
};
+struct LinearizePoison final : OpConversionPattern<ub::PoisonOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizePoison(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+ LogicalResult
+ matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto resType = getTypeConverter()->convertType<VectorType>(op.getType());
+
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+ if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ loc, "Can't flatten since targetBitWidth <= OpSize");
+
+ rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, resType);
+ return success();
+ }
+
+private:
+ unsigned targetVectorBitWidth;
+};
+
struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;
@@ -525,7 +555,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
+ if ((isa<arith::ConstantOp, ub::PoisonOp, vector::BitCastOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
@@ -534,9 +564,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns
- .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ patterns.add<LinearizeConstant, LinearizePoison, LinearizeVectorizable,
+ LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+ targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..f859ffd0e19d7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// -----
+// ALL-LABEL: test_linearize_poison
+func.func @test_linearize_poison() -> vector<2x2xf32> {
+ // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+ // BW-128: %[[P:.*]] = ub.poison : vector<4xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+ // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
+ %0 = ub.poison : vector<2x2xf32>
+ // ALL: return %[[RES]] : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// -----
+
// ALL-LABEL: test_partial_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
|
@llvm/pr-subscribers-mlir-vector Author: Ivan Butygin (Hardcode84) ChangesFull diff: https://github.com/llvm/llvm-project/pull/128612.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3ecd585c5a26d..65bd982319e45 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
@@ -97,6 +98,35 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
unsigned targetVectorBitWidth;
};
+struct LinearizePoison final : OpConversionPattern<ub::PoisonOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizePoison(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+ LogicalResult
+ matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto resType = getTypeConverter()->convertType<VectorType>(op.getType());
+
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+ if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ loc, "Can't flatten since targetBitWidth <= OpSize");
+
+ rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, resType);
+ return success();
+ }
+
+private:
+ unsigned targetVectorBitWidth;
+};
+
struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;
@@ -525,7 +555,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
+ if ((isa<arith::ConstantOp, ub::PoisonOp, vector::BitCastOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
@@ -534,9 +564,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns
- .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ patterns.add<LinearizeConstant, LinearizePoison, LinearizeVectorizable,
+ LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+ targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..f859ffd0e19d7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// -----
+// ALL-LABEL: test_linearize_poison
+func.func @test_linearize_poison() -> vector<2x2xf32> {
+ // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+ // BW-128: %[[P:.*]] = ub.poison : vector<4xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+ // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
+ %0 = ub.poison : vector<2x2xf32>
+ // ALL: return %[[RES]] : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// -----
+
// ALL-LABEL: test_partial_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
|
private: | ||
unsigned targetVectorBitWidth; | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm curious... Could we add OpTrait::Vectorizable
to the poison op and reuse the rewrite below? A scalar poison value should be trivially vectorizable...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, we can try to match ConstantLike
as both arith constant and poison have it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's even better :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
auto resType = | ||
getTypeConverter()->convertType<VectorType>(constOp.getType()); | ||
getTypeConverter()->convertType<VectorType>(op->getResult(0).getType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: take a reference to the type converter and reuse it across the function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
dstElementsAttr = dstElementsAttr.reshape(resType); | ||
FailureOr<Operation *> newOp = | ||
convertOpResultTypes(op, {}, *getTypeConverter(), rewriter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: {}
-> /*paramName=*/{}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wonder if we could also refactor the convertOpResultTypes
code from DenseElementsAttr and PoisonAttr. Perhaps we could move the early exist to the top and the refactor some common code after that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, moved attr conversion to separate function
(*newOp)->setAttr(attrName, dstElementsAttr); | ||
rewriter.replaceOp(op, *newOp); | ||
return success(); | ||
} | ||
|
||
if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value)) { | ||
FailureOr<Operation *> newOp = | ||
convertOpResultTypes(op, {}, *getTypeConverter(), rewriter); | ||
if (failed(newOp)) | ||
return failure(); | ||
|
||
(*newOp)->setAttr(attrName, poisonAttr); | ||
rewriter.replaceOp(op, *newOp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: deref newOp
only once?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> { | |||
|
|||
// ----- | |||
|
|||
// ALL-LABEL: test_linearize_poison | |||
func.func @test_linearize_poison() -> vector<2x2xf32> { | |||
// DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: P
-> POISON
per new WIP guidelines :)
llvm/mlir-www#216
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Unify `arith.constant` and `up.poison` using `OpTraitConversionPattern<OpTrait::ConstantLike>`.
Unify
arith.constant
andup.poison
usingOpTraitConversionPattern<OpTrait::ConstantLike>
.