Skip to content

[mlir][ArmSME] Add option to only enable streaming mode for scalable code #94759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ namespace arm_sme {
/// Pass to enable Armv9 Streaming SVE mode.
std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
const ArmZaMode = ArmZaMode::Disabled, bool ifRequiredByOps = false,
bool ifContainsScalableVectors = false);

/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
/// variants.
Expand Down
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,14 @@ def EnableArmStreaming
"not be used for input and/or output and the "
"function must return with ZA unchanged")
)}]>,
Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
Option<"ifRequiredByOps", "if-required-by-ops", "bool",
/*default=*/"false",
"Only apply the selected streaming/ZA modes if the function "
" contains ops that require them.">
"Only apply the selected streaming/ZA modes if the function contains"
" ops that implement the ArmSMETileOpInterface.">,
Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
"bool", /*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
" operations that use scalable vector types.">
];
let dependentDialects = ["func::FuncDialect"];
}
Expand Down
49 changes: 38 additions & 11 deletions mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,25 @@ constexpr StringLiteral
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
bool onlyIfRequiredByOps) {
bool ifRequiredByOps, bool ifContainsScalableVectors) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
this->onlyIfRequiredByOps = onlyIfRequiredByOps;
this->ifRequiredByOps = ifRequiredByOps;
this->ifContainsScalableVectors = ifContainsScalableVectors;
}
void runOnOperation() override {
auto op = getOperation();
auto function = getOperation();

if (onlyIfRequiredByOps) {
if (ifRequiredByOps && ifContainsScalableVectors) {
function->emitOpError(
"enable-arm-streaming: `if-required-by-ops` and "
"`if-contains-scalable-vectors` are mutually exclusive");
return signalPassFailure();
}

if (ifRequiredByOps) {
bool foundTileOp = false;
op.walk([&](Operation *op) {
function.walk([&](Operation *op) {
if (llvm::isa<ArmSMETileOpInterface>(op)) {
foundTileOp = true;
return WalkResult::interrupt();
Expand All @@ -79,27 +87,46 @@ struct EnableArmStreamingPass
return;
}

if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
if (ifContainsScalableVectors) {
bool foundScalableVector = false;
auto isScalableVector = [&](Type type) {
if (auto vectorType = dyn_cast<VectorType>(type))
return vectorType.isScalable();
return false;
};
function.walk([&](Operation *op) {
if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
llvm::any_of(op->getResultTypes(), isScalableVector)) {
foundScalableVector = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (!foundScalableVector)
return;
}

if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
streamingMode == ArmStreamingMode::Disabled)
return;

auto unitAttr = UnitAttr::get(&getContext());

op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);

// The pass currently only supports enabling ZA when in streaming-mode, but
// ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
// streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
// supporting this later.
if (zaMode != ArmZaMode::Disabled)
op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
}
};
} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
bool onlyIfRequiredByOps) {
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
onlyIfRequiredByOps);
bool ifRequiredByOps, bool ifContainsScalableVectors) {
return std::make_unique<EnableArmStreamingPass>(
streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
}
4 changes: 4 additions & 0 deletions mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics

// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
func.func @test() { return }
17 changes: 16 additions & 1 deletion mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE

// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
Expand Down Expand Up @@ -38,3 +39,17 @@ func.func @requires_arm_streaming() {
// IF-REQUIRED: @does_not_require_arm_streaming
// IF-REQUIRED-NOT: arm_streaming
func.func @does_not_require_arm_streaming() { return }

// IF-SCALABLE-LABEL: @contains_scalable_vectors
// IF-SCALABLE-SAME: attributes {arm_streaming}
func.func @contains_scalable_vectors(%vec: vector<[4]xf32>) -> vector<[4]xf32> {
%0 = arith.addf %vec, %vec : vector<[4]xf32>
return %0 : vector<[4]xf32>
}

// IF-SCALABLE-LABEL: @no_scalable_vectors
// IF-SCALABLE-NOT: arm_streaming
func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
%0 = arith.addf %vec, %vec : vector<4xf32>
return %0 : vector<4xf32>
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// RUN: -arm-sme-vector-legalization -canonicalize -cse \
// RUN: -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void buildTestLowerToArmSME(OpPassManager &pm,
// Enable streaming-mode and ZA.
pm.addPass(arm_sme::createEnableArmStreamingPass(
arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
/*onlyIfRequiredByOps=*/true));
/*ifRequiredByOps=*/true));

// Convert SCF to CF (required for ArmSME tile allocation).
pm.addPass(createConvertSCFToCFPass());
Expand Down
Loading