Skip to content

Commit d319fc4

Browse files
authored
[mlir][ArmSME] Add option to only enable streaming mode for scalable code (#94759)
This adds a new option, `-enable-arm-streaming=if-contains-scalable-vectors`, which applies the selected streaming/ZA modes only if the function contains operations that use scalable vector types. As an NFC change, this patch also drops the `only-` prefix from the existing `only-if-required-by-ops` option, which is now spelled `if-required-by-ops`.
1 parent 5c268cf commit d319fc4

File tree

7 files changed

+69
-18
lines changed

7 files changed

+69
-18
lines changed

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ namespace arm_sme {
2727
/// Pass to enable Armv9 Streaming SVE mode.
2828
std::unique_ptr<Pass> createEnableArmStreamingPass(
2929
const ArmStreamingMode = ArmStreamingMode::Streaming,
30-
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
30+
const ArmZaMode = ArmZaMode::Disabled, bool ifRequiredByOps = false,
31+
bool ifContainsScalableVectors = false);
3132

3233
/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
3334
/// variants.

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,14 @@ def EnableArmStreaming
116116
"not be used for input and/or output and the "
117117
"function must return with ZA unchanged")
118118
)}]>,
119-
Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
119+
Option<"ifRequiredByOps", "if-required-by-ops", "bool",
120120
/*default=*/"false",
121-
"Only apply the selected streaming/ZA modes if the function "
122-
" contains ops that require them.">
121+
"Only apply the selected streaming/ZA modes if the function contains"
122+
" ops that implement the ArmSMETileOpInterface.">,
123+
Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
124+
"bool", /*default=*/"false",
125+
"Only apply the selected streaming/ZA modes if the function contains"
126+
" operations that use scalable vector types.">
123127
];
124128
let dependentDialects = ["func::FuncDialect"];
125129
}

mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,25 @@ constexpr StringLiteral
5858
struct EnableArmStreamingPass
5959
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
6060
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
61-
bool onlyIfRequiredByOps) {
61+
bool ifRequiredByOps, bool ifContainsScalableVectors) {
6262
this->streamingMode = streamingMode;
6363
this->zaMode = zaMode;
64-
this->onlyIfRequiredByOps = onlyIfRequiredByOps;
64+
this->ifRequiredByOps = ifRequiredByOps;
65+
this->ifContainsScalableVectors = ifContainsScalableVectors;
6566
}
6667
void runOnOperation() override {
67-
auto op = getOperation();
68+
auto function = getOperation();
6869

69-
if (onlyIfRequiredByOps) {
70+
if (ifRequiredByOps && ifContainsScalableVectors) {
71+
function->emitOpError(
72+
"enable-arm-streaming: `if-required-by-ops` and "
73+
"`if-contains-scalable-vectors` are mutually exclusive");
74+
return signalPassFailure();
75+
}
76+
77+
if (ifRequiredByOps) {
7078
bool foundTileOp = false;
71-
op.walk([&](Operation *op) {
79+
function.walk([&](Operation *op) {
7280
if (llvm::isa<ArmSMETileOpInterface>(op)) {
7381
foundTileOp = true;
7482
return WalkResult::interrupt();
@@ -79,27 +87,46 @@ struct EnableArmStreamingPass
7987
return;
8088
}
8189

82-
if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
90+
if (ifContainsScalableVectors) {
91+
bool foundScalableVector = false;
92+
auto isScalableVector = [&](Type type) {
93+
if (auto vectorType = dyn_cast<VectorType>(type))
94+
return vectorType.isScalable();
95+
return false;
96+
};
97+
function.walk([&](Operation *op) {
98+
if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
99+
llvm::any_of(op->getResultTypes(), isScalableVector)) {
100+
foundScalableVector = true;
101+
return WalkResult::interrupt();
102+
}
103+
return WalkResult::advance();
104+
});
105+
if (!foundScalableVector)
106+
return;
107+
}
108+
109+
if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
83110
streamingMode == ArmStreamingMode::Disabled)
84111
return;
85112

86113
auto unitAttr = UnitAttr::get(&getContext());
87114

88-
op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
115+
function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
89116

90117
// The pass currently only supports enabling ZA when in streaming-mode, but
91118
// ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
92119
// streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
93120
// supporting this later.
94121
if (zaMode != ArmZaMode::Disabled)
95-
op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
122+
function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
96123
}
97124
};
98125
} // namespace
99126

100127
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
101128
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
102-
bool onlyIfRequiredByOps) {
103-
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
104-
onlyIfRequiredByOps);
129+
bool ifRequiredByOps, bool ifContainsScalableVectors) {
130+
return std::make_unique<EnableArmStreamingPass>(
131+
streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
105132
}
mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
2+
3+
// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
4+
func.func @test() { return }

mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
33
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
44
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
5-
// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
5+
// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
6+
// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
67

78
// CHECK-LABEL: @arm_streaming
89
// CHECK-SAME: attributes {arm_streaming}
@@ -38,3 +39,17 @@ func.func @requires_arm_streaming() {
3839
// IF-REQUIRED: @does_not_require_arm_streaming
3940
// IF-REQUIRED-NOT: arm_streaming
4041
func.func @does_not_require_arm_streaming() { return }
42+
43+
// IF-SCALABLE-LABEL: @contains_scalable_vectors
44+
// IF-SCALABLE-SAME: attributes {arm_streaming}
45+
func.func @contains_scalable_vectors(%vec: vector<[4]xf32>) -> vector<[4]xf32> {
46+
%0 = arith.addf %vec, %vec : vector<[4]xf32>
47+
return %0 : vector<[4]xf32>
48+
}
49+
50+
// IF-SCALABLE-LABEL: @no_scalable_vectors
51+
// IF-SCALABLE-NOT: arm_streaming
52+
func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
53+
%0 = arith.addf %vec, %vec : vector<4xf32>
54+
return %0 : vector<4xf32>
55+
}

mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// RUN: -arm-sme-vector-legalization -canonicalize -cse \
55
// RUN: -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
66
// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
7-
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
7+
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
88
// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
99
// RUN: -test-lower-to-llvm | \
1010
// RUN: %mcr_aarch64_cmd \

mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ void buildTestLowerToArmSME(OpPassManager &pm,
7474
// Enable streaming-mode and ZA.
7575
pm.addPass(arm_sme::createEnableArmStreamingPass(
7676
arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
77-
/*onlyIfRequiredByOps=*/true));
77+
/*ifRequiredByOps=*/true));
7878

7979
// Convert SCF to CF (required for ArmSME tile allocation).
8080
pm.addPass(createConvertSCFToCFPass());

0 commit comments

Comments
 (0)