Skip to content

Commit f7d91fa

Browse files
authored
[mlir][ArmSME] Add option to only enable streaming mode/ZA if required (#73931)
This adds a `only-if-required-by-ops` flag to the `enable-arm-streaming` pass. This flag defaults to `false` (which preserves the original behaviour), however, if set to `true` the pass will only add the selected ZA/streaming mode to functions that contain ops that implement `ArmSMETileOpInterface`. This simplifies enabling these modes, as we can now first try lowering ops to ArmSME, then only if we succeed, add the relevant function attributes.
1 parent 1ee41b4 commit f7d91fa

File tree

4 files changed

+44
-5
lines changed

4 files changed

+44
-5
lines changed

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace arm_sme {
2727
/// Pass to enable Armv9 Streaming SVE mode.
2828
std::unique_ptr<Pass> createEnableArmStreamingPass(
2929
const ArmStreamingMode = ArmStreamingMode::Streaming,
30-
const ArmZaMode = ArmZaMode::Disabled);
30+
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
3131

3232
/// Pass that allocates tile IDs to ArmSME operations.
3333
std::unique_ptr<Pass> createTileAllocationPass();

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ def EnableArmStreaming
7373
"new-za",
7474
"The function has ZA state. The ZA state is "
7575
"created on entry and destroyed on exit.")
76-
)}]>
76+
)}]>,
77+
Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
78+
/*default=*/"false",
79+
"Only apply the selected streaming/ZA modes if the function "
80+
" contains ops that require them.">
7781
];
7882
let dependentDialects = ["func::FuncDialect"];
7983
}

mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
//
3434
//===----------------------------------------------------------------------===//
3535

36+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
3637
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
3738
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
3839

@@ -56,12 +57,28 @@ constexpr StringLiteral
5657

5758
struct EnableArmStreamingPass
5859
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
59-
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
60+
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
61+
bool onlyIfRequiredByOps) {
6062
this->streamingMode = streamingMode;
6163
this->zaMode = zaMode;
64+
this->onlyIfRequiredByOps = onlyIfRequiredByOps;
6265
}
6366
void runOnOperation() override {
6467
auto op = getOperation();
68+
69+
if (onlyIfRequiredByOps) {
70+
bool foundTileOp = false;
71+
op.walk([&](Operation *op) {
72+
if (llvm::isa<ArmSMETileOpInterface>(op)) {
73+
foundTileOp = true;
74+
return WalkResult::interrupt();
75+
}
76+
return WalkResult::advance();
77+
});
78+
if (!foundTileOp)
79+
return;
80+
}
81+
6582
if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
6683
streamingMode == ArmStreamingMode::Disabled)
6784
return;
@@ -81,6 +98,8 @@ struct EnableArmStreamingPass
8198
} // namespace
8299

83100
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
84-
const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
85-
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
101+
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
102+
bool onlyIfRequiredByOps) {
103+
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
104+
onlyIfRequiredByOps);
86105
}

mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
22
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
33
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
4+
// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
45

56
// CHECK-LABEL: @arm_streaming
67
// CHECK-SAME: attributes {arm_streaming}
@@ -17,3 +18,18 @@ func.func @arm_streaming() { return }
1718
// CHECK-ENABLE-ZA-LABEL: @not_arm_streaming
1819
// CHECK-ENABLE-ZA-SAME: attributes {enable_arm_streaming_ignore}
1920
func.func @not_arm_streaming() attributes {enable_arm_streaming_ignore} { return }
21+
22+
// CHECK-LABEL: @requires_arm_streaming
23+
// CHECK-SAME: attributes {arm_streaming}
24+
// IF-REQUIRED: @requires_arm_streaming
25+
// IF-REQUIRED-SAME: attributes {arm_streaming}
26+
func.func @requires_arm_streaming() {
27+
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
28+
return
29+
}
30+
31+
// CHECK-LABEL: @does_not_require_arm_streaming
32+
// CHECK-SAME: attributes {arm_streaming}
33+
// IF-REQUIRED: @does_not_require_arm_streaming
34+
// IF-REQUIRED-NOT: arm_streaming
35+
func.func @does_not_require_arm_streaming() { return }

0 commit comments

Comments
 (0)