Skip to content

[mlir][ArmSME] Add option to only enable streaming mode for scalable code #94759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 10, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jun 7, 2024

This adds a new option -enable-arm-streaming=if-contains-scalable-vectors, which only applies the selected streaming/ZA modes if the function contains scalable vector types.

As a NFC this patch also removes the only- prefix from the if-required-by-ops mode.

…code

This adds a new option `-enable-arm-streaming=if-contains-scalable-vectors`,
which only applies the selected streaming/ZA modes if the function
contains scalable vector types.

As a NFC this patch also removes the `only-` prefix from the
`if-required-by-ops` mode.
@llvmbot
Copy link
Member

llvmbot commented Jun 7, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir-linalg

Author: Benjamin Maxwell (MacDue)

Changes

This adds a new option -enable-arm-streaming=if-contains-scalable-vectors, which only applies the selected streaming/ZA modes if the function contains scalable vector types.

As a NFC this patch also removes the only- prefix from the if-required-by-ops mode.


Full diff: https://github.com/llvm/llvm-project/pull/94759.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+7-3)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+38-11)
  • (added) mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir (+4)
  • (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+16-1)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir (+1-1)
  • (modified) mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 156744ba57e7b..167e5b787d1af 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -27,7 +27,8 @@ namespace arm_sme {
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
-    const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
+    const ArmZaMode = ArmZaMode::Disabled, bool ifRequiredByOps = false,
+    bool ifContainsScalableVectors = false);
 
 /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
 /// variants.
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 869a031d6cae8..8aba121432bba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -116,10 +116,14 @@ def EnableArmStreaming
                             "not be used for input and/or output and the "
                             "function must return with ZA unchanged")
            )}]>,
-    Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+    Option<"ifRequiredByOps", "if-required-by-ops", "bool",
            /*default=*/"false",
-           "Only apply the selected streaming/ZA modes if the function "
-           " contains ops that require them.">
+           "Apply the selected streaming/ZA modes if the function contains ops "
+           "that require them.">,
+    Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+           "bool", /*default=*/"false",
+           "Apply the selected streaming/ZA modes if the function contains "
+           "operations that use scalable vector types.">
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 79a6caffb6ee0..fb4bb41d87488 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -58,17 +58,25 @@ constexpr StringLiteral
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
   EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
-                         bool onlyIfRequiredByOps) {
+                         bool ifRequiredByOps, bool ifContainsScalableVectors) {
     this->streamingMode = streamingMode;
     this->zaMode = zaMode;
-    this->onlyIfRequiredByOps = onlyIfRequiredByOps;
+    this->ifRequiredByOps = ifRequiredByOps;
+    this->ifContainsScalableVectors = ifContainsScalableVectors;
   }
   void runOnOperation() override {
-    auto op = getOperation();
+    auto function = getOperation();
 
-    if (onlyIfRequiredByOps) {
+    if (ifRequiredByOps && ifContainsScalableVectors) {
+      function->emitOpError(
+          "enable-arm-streaming: `if-required-by-ops` and "
+          "`if-contains-scalable-vectors` are mutually exclusive");
+      return signalPassFailure();
+    }
+
+    if (ifRequiredByOps) {
       bool foundTileOp = false;
-      op.walk([&](Operation *op) {
+      function.walk([&](Operation *op) {
         if (llvm::isa<ArmSMETileOpInterface>(op)) {
           foundTileOp = true;
           return WalkResult::interrupt();
@@ -79,27 +87,46 @@ struct EnableArmStreamingPass
         return;
     }
 
-    if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
+    if (ifContainsScalableVectors) {
+      bool foundScalableVector = false;
+      auto isScalableVector = [&](Type type) {
+        if (auto vectorType = dyn_cast<VectorType>(type))
+          return vectorType.isScalable();
+        return false;
+      };
+      function.walk([&](Operation *op) {
+        if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+            llvm::any_of(op->getResultTypes(), isScalableVector)) {
+          foundScalableVector = true;
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
+      if (!foundScalableVector)
+        return;
+    }
+
+    if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
         streamingMode == ArmStreamingMode::Disabled)
       return;
 
     auto unitAttr = UnitAttr::get(&getContext());
 
-    op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
+    function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
 
     // The pass currently only supports enabling ZA when in streaming-mode, but
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
     if (zaMode != ArmZaMode::Disabled)
-      op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
+      function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
   }
 };
 } // namespace
 
 std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
     const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
-    bool onlyIfRequiredByOps) {
-  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
-                                                  onlyIfRequiredByOps);
+    bool ifRequiredByOps, bool ifContainsScalableVectors) {
+  return std::make_unique<EnableArmStreamingPass>(
+      streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
 }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
new file mode 100644
index 0000000000000..da70b632d70c4
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+
+// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 6b58d8fdc41b0..2011802c5c8b2 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -2,7 +2,8 @@
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
@@ -38,3 +39,17 @@ func.func @requires_arm_streaming() {
 // IF-REQUIRED: @does_not_require_arm_streaming
 // IF-REQUIRED-NOT: arm_streaming
 func.func @does_not_require_arm_streaming() { return }
+
+// IF-SCALABLE-LABEL: @contains_scalable_vectors
+// IF-SCALABLE-SAME: attributes {arm_streaming}
+func.func @contains_scalable_vectors(%vec: vector<[4]xf32>) -> vector<[4]xf32> {
+  %0 = arith.addf %vec, %vec : vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @no_scalable_vectors
+// IF-SCALABLE-NOT: arm_streaming
+func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
+  %0 = arith.addf %vec, %vec : vector<4xf32>
+  return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
index 10ffed2688178..aabd9d2ce788e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
@@ -4,7 +4,7 @@
 // RUN:   -arm-sme-vector-legalization -canonicalize -cse \
 // RUN:   -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
 // RUN:   -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
 // RUN:   -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
 // RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index d3dabaf200fdc..a220791969d53 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -74,7 +74,7 @@ void buildTestLowerToArmSME(OpPassManager &pm,
   // Enable streaming-mode and ZA.
   pm.addPass(arm_sme::createEnableArmStreamingPass(
       arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
-      /*onlyIfRequiredByOps=*/true));
+      /*ifRequiredByOps=*/true));
 
   // Convert SCF to CF (required for ArmSME tile allocation).
   pm.addPass(createConvertSCFToCFPass());

@llvmbot
Copy link
Member

llvmbot commented Jun 7, 2024

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This adds a new option -enable-arm-streaming=if-contains-scalable-vectors, which only applies the selected streaming/ZA modes if the function contains scalable vector types.

As a NFC this patch also removes the only- prefix from the if-required-by-ops mode.


Full diff: https://github.com/llvm/llvm-project/pull/94759.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+2-1)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+7-3)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+38-11)
  • (added) mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir (+4)
  • (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+16-1)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir (+1-1)
  • (modified) mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 156744ba57e7b..167e5b787d1af 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -27,7 +27,8 @@ namespace arm_sme {
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
-    const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
+    const ArmZaMode = ArmZaMode::Disabled, bool ifRequiredByOps = false,
+    bool ifContainsScalableVectors = false);
 
 /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
 /// variants.
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 869a031d6cae8..8aba121432bba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -116,10 +116,14 @@ def EnableArmStreaming
                             "not be used for input and/or output and the "
                             "function must return with ZA unchanged")
            )}]>,
-    Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+    Option<"ifRequiredByOps", "if-required-by-ops", "bool",
            /*default=*/"false",
-           "Only apply the selected streaming/ZA modes if the function "
-           " contains ops that require them.">
+           "Apply the selected streaming/ZA modes if the function contains ops "
+           "that require them.">,
+    Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+           "bool", /*default=*/"false",
+           "Apply the selected streaming/ZA modes if the function contains "
+           "operations that use scalable vector types.">
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 79a6caffb6ee0..fb4bb41d87488 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -58,17 +58,25 @@ constexpr StringLiteral
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
   EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
-                         bool onlyIfRequiredByOps) {
+                         bool ifRequiredByOps, bool ifContainsScalableVectors) {
     this->streamingMode = streamingMode;
     this->zaMode = zaMode;
-    this->onlyIfRequiredByOps = onlyIfRequiredByOps;
+    this->ifRequiredByOps = ifRequiredByOps;
+    this->ifContainsScalableVectors = ifContainsScalableVectors;
   }
   void runOnOperation() override {
-    auto op = getOperation();
+    auto function = getOperation();
 
-    if (onlyIfRequiredByOps) {
+    if (ifRequiredByOps && ifContainsScalableVectors) {
+      function->emitOpError(
+          "enable-arm-streaming: `if-required-by-ops` and "
+          "`if-contains-scalable-vectors` are mutually exclusive");
+      return signalPassFailure();
+    }
+
+    if (ifRequiredByOps) {
       bool foundTileOp = false;
-      op.walk([&](Operation *op) {
+      function.walk([&](Operation *op) {
         if (llvm::isa<ArmSMETileOpInterface>(op)) {
           foundTileOp = true;
           return WalkResult::interrupt();
@@ -79,27 +87,46 @@ struct EnableArmStreamingPass
         return;
     }
 
-    if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
+    if (ifContainsScalableVectors) {
+      bool foundScalableVector = false;
+      auto isScalableVector = [&](Type type) {
+        if (auto vectorType = dyn_cast<VectorType>(type))
+          return vectorType.isScalable();
+        return false;
+      };
+      function.walk([&](Operation *op) {
+        if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+            llvm::any_of(op->getResultTypes(), isScalableVector)) {
+          foundScalableVector = true;
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
+      if (!foundScalableVector)
+        return;
+    }
+
+    if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
         streamingMode == ArmStreamingMode::Disabled)
       return;
 
     auto unitAttr = UnitAttr::get(&getContext());
 
-    op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
+    function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
 
     // The pass currently only supports enabling ZA when in streaming-mode, but
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
     if (zaMode != ArmZaMode::Disabled)
-      op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
+      function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
   }
 };
 } // namespace
 
 std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
     const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
-    bool onlyIfRequiredByOps) {
-  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
-                                                  onlyIfRequiredByOps);
+    bool ifRequiredByOps, bool ifContainsScalableVectors) {
+  return std::make_unique<EnableArmStreamingPass>(
+      streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
 }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
new file mode 100644
index 0000000000000..da70b632d70c4
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+
+// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 6b58d8fdc41b0..2011802c5c8b2 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -2,7 +2,8 @@
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
@@ -38,3 +39,17 @@ func.func @requires_arm_streaming() {
 // IF-REQUIRED: @does_not_require_arm_streaming
 // IF-REQUIRED-NOT: arm_streaming
 func.func @does_not_require_arm_streaming() { return }
+
+// IF-SCALABLE-LABEL: @contains_scalable_vectors
+// IF-SCALABLE-SAME: attributes {arm_streaming}
+func.func @contains_scalable_vectors(%vec: vector<[4]xf32>) -> vector<[4]xf32> {
+  %0 = arith.addf %vec, %vec : vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @no_scalable_vectors
+// IF-SCALABLE-NOT: arm_streaming
+func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
+  %0 = arith.addf %vec, %vec : vector<4xf32>
+  return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
index 10ffed2688178..aabd9d2ce788e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
@@ -4,7 +4,7 @@
 // RUN:   -arm-sme-vector-legalization -canonicalize -cse \
 // RUN:   -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
 // RUN:   -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
 // RUN:   -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
 // RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index d3dabaf200fdc..a220791969d53 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -74,7 +74,7 @@ void buildTestLowerToArmSME(OpPassManager &pm,
   // Enable streaming-mode and ZA.
   pm.addPass(arm_sme::createEnableArmStreamingPass(
       arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
-      /*onlyIfRequiredByOps=*/true));
+      /*ifRequiredByOps=*/true));
 
   // Convert SCF to CF (required for ArmSME tile allocation).
   pm.addPass(createConvertSCFToCFPass());

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, module the docs (see my comments inline), thanks!

@MacDue MacDue merged commit d319fc4 into llvm:main Jun 10, 2024
7 checks passed
@MacDue MacDue deleted the enable_streaming branch June 10, 2024 11:02
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants