Skip to content

[mlir][vector] VectorLinearize: ub.poison support #128612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 13, 2025

Conversation

Hardcode84
Copy link
Contributor

@Hardcode84 Hardcode84 commented Feb 25, 2025

Unify arith.constant and up.poison using OpTraitConversionPattern<OpTrait::ConstantLike>.

@llvmbot
Copy link
Member

llvmbot commented Feb 25, 2025

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/128612.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+34-4)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+16)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3ecd585c5a26d..65bd982319e45 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/Attributes.h"
@@ -97,6 +98,35 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   unsigned targetVectorBitWidth;
 };
 
+struct LinearizePoison final : OpConversionPattern<ub::PoisonOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizePoison(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto resType = getTypeConverter()->convertType<VectorType>(op.getType());
+
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, resType);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 struct LinearizeVectorizable final
     : OpTraitConversionPattern<OpTrait::Vectorizable> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
@@ -525,7 +555,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
+        if ((isa<arith::ConstantOp, ub::PoisonOp, vector::BitCastOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -534,9 +564,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns
-      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
-          typeConverter, patterns.getContext(), targetBitWidth);
+  patterns.add<LinearizeConstant, LinearizePoison, LinearizeVectorizable,
+               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+                                       targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..f859ffd0e19d7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
 
 // -----
 
+// ALL-LABEL: test_linearize_poison
+func.func @test_linearize_poison() -> vector<2x2xf32> {
+  // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+  // BW-128: %[[P:.*]] = ub.poison : vector<4xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+  // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
+  %0 = ub.poison : vector<2x2xf32>
+  // ALL: return %[[RES]] : vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// -----
+
 // ALL-LABEL: test_partial_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
 func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {

@llvmbot
Copy link
Member

llvmbot commented Feb 25, 2025

@llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/128612.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+34-4)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+16)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3ecd585c5a26d..65bd982319e45 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/Attributes.h"
@@ -97,6 +98,35 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   unsigned targetVectorBitWidth;
 };
 
+struct LinearizePoison final : OpConversionPattern<ub::PoisonOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizePoison(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto resType = getTypeConverter()->convertType<VectorType>(op.getType());
+
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, resType);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 struct LinearizeVectorizable final
     : OpTraitConversionPattern<OpTrait::Vectorizable> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
@@ -525,7 +555,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
+        if ((isa<arith::ConstantOp, ub::PoisonOp, vector::BitCastOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -534,9 +564,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns
-      .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
-          typeConverter, patterns.getContext(), targetBitWidth);
+  patterns.add<LinearizeConstant, LinearizePoison, LinearizeVectorizable,
+               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+                                       targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..f859ffd0e19d7 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
 
 // -----
 
+// ALL-LABEL: test_linearize_poison
+func.func @test_linearize_poison() -> vector<2x2xf32> {
+  // DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
+  // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+  // BW-128: %[[P:.*]] = ub.poison : vector<4xf32>
+  // BW-128: %[[RES:.*]] = vector.shape_cast %[[P]] : vector<4xf32> to vector<2x2xf32>
+
+  // BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
+  %0 = ub.poison : vector<2x2xf32>
+  // ALL: return %[[RES]] : vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// -----
+
 // ALL-LABEL: test_partial_linearize
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
 func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {

private:
unsigned targetVectorBitWidth;
};

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious... Could we add OpTrait::Vectorizable to the poison op and reuse the rewrite below? A scalar poison value should be trivially vectorizable...

Copy link
Contributor Author

@Hardcode84 Hardcode84 Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we can try to match ConstantLike as both arith constant and poison have it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's even better :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());
getTypeConverter()->convertType<VectorType>(op->getResult(0).getType());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: take a reference to the type converter and reuse it across the function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


dstElementsAttr = dstElementsAttr.reshape(resType);
FailureOr<Operation *> newOp =
convertOpResultTypes(op, {}, *getTypeConverter(), rewriter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: {} -> /*paramName=*/{}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder if we could also refactor the convertOpResultTypes code from DenseElementsAttr and PoisonAttr. Perhaps we could move the early exist to the top and the refactor some common code after that?

Copy link
Contributor Author

@Hardcode84 Hardcode84 Mar 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, moved attr conversion to separate function

Comment on lines 105 to 117
(*newOp)->setAttr(attrName, dstElementsAttr);
rewriter.replaceOp(op, *newOp);
return success();
}

if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value)) {
FailureOr<Operation *> newOp =
convertOpResultTypes(op, {}, *getTypeConverter(), rewriter);
if (failed(newOp))
return failure();

(*newOp)->setAttr(attrName, poisonAttr);
rewriter.replaceOp(op, *newOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: deref newOp only once?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -32,6 +32,22 @@ func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {

// -----

// ALL-LABEL: test_linearize_poison
func.func @test_linearize_poison() -> vector<2x2xf32> {
// DEFAULT: %[[P:.*]] = ub.poison : vector<4xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: P -> POISON per new WIP guidelines :)
llvm/mlir-www#216

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@Hardcode84
Copy link
Contributor Author

ping

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@Hardcode84 Hardcode84 merged commit 02fae68 into llvm:main Mar 13, 2025
11 checks passed
@Hardcode84 Hardcode84 deleted the poison-linearize branch March 13, 2025 11:18
frederik-h pushed a commit to frederik-h/llvm-project that referenced this pull request Mar 18, 2025
Unify `arith.constant` and `up.poison` using
`OpTraitConversionPattern<OpTrait::ConstantLike>`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants