@@ -58,17 +58,25 @@ constexpr StringLiteral
58
58
struct EnableArmStreamingPass
59
59
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
60
60
EnableArmStreamingPass (ArmStreamingMode streamingMode, ArmZaMode zaMode,
61
- bool onlyIfRequiredByOps ) {
61
+ bool ifRequiredByOps, bool ifContainsScalableVectors ) {
62
62
this ->streamingMode = streamingMode;
63
63
this ->zaMode = zaMode;
64
- this ->onlyIfRequiredByOps = onlyIfRequiredByOps;
64
+ this ->ifRequiredByOps = ifRequiredByOps;
65
+ this ->ifContainsScalableVectors = ifContainsScalableVectors;
65
66
}
66
67
void runOnOperation () override {
67
- auto op = getOperation ();
68
+ auto function = getOperation ();
68
69
69
- if (onlyIfRequiredByOps) {
70
+ if (ifRequiredByOps && ifContainsScalableVectors) {
71
+ function->emitOpError (
72
+ " enable-arm-streaming: `if-required-by-ops` and "
73
+ " `if-contains-scalable-vectors` are mutually exclusive" );
74
+ return signalPassFailure ();
75
+ }
76
+
77
+ if (ifRequiredByOps) {
70
78
bool foundTileOp = false ;
71
- op .walk ([&](Operation *op) {
79
+ function .walk ([&](Operation *op) {
72
80
if (llvm::isa<ArmSMETileOpInterface>(op)) {
73
81
foundTileOp = true ;
74
82
return WalkResult::interrupt ();
@@ -79,27 +87,46 @@ struct EnableArmStreamingPass
79
87
return ;
80
88
}
81
89
82
- if (op->getAttr (kEnableArmStreamingIgnoreAttr ) ||
90
+ if (ifContainsScalableVectors) {
91
+ bool foundScalableVector = false ;
92
+ auto isScalableVector = [&](Type type) {
93
+ if (auto vectorType = dyn_cast<VectorType>(type))
94
+ return vectorType.isScalable ();
95
+ return false ;
96
+ };
97
+ function.walk ([&](Operation *op) {
98
+ if (llvm::any_of (op->getOperandTypes (), isScalableVector) ||
99
+ llvm::any_of (op->getResultTypes (), isScalableVector)) {
100
+ foundScalableVector = true ;
101
+ return WalkResult::interrupt ();
102
+ }
103
+ return WalkResult::advance ();
104
+ });
105
+ if (!foundScalableVector)
106
+ return ;
107
+ }
108
+
109
+ if (function->getAttr (kEnableArmStreamingIgnoreAttr ) ||
83
110
streamingMode == ArmStreamingMode::Disabled)
84
111
return ;
85
112
86
113
auto unitAttr = UnitAttr::get (&getContext ());
87
114
88
- op ->setAttr (stringifyArmStreamingMode (streamingMode), unitAttr);
115
+ function ->setAttr (stringifyArmStreamingMode (streamingMode), unitAttr);
89
116
90
117
// The pass currently only supports enabling ZA when in streaming-mode, but
91
118
// ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
92
119
// streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
93
120
// supporting this later.
94
121
if (zaMode != ArmZaMode::Disabled)
95
- op ->setAttr (stringifyArmZaMode (zaMode), unitAttr);
122
+ function ->setAttr (stringifyArmZaMode (zaMode), unitAttr);
96
123
}
97
124
};
98
125
} // namespace
99
126
100
127
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass (
101
128
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
102
- bool onlyIfRequiredByOps ) {
103
- return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
104
- onlyIfRequiredByOps );
129
+ bool ifRequiredByOps, bool ifContainsScalableVectors ) {
130
+ return std::make_unique<EnableArmStreamingPass>(
131
+ streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors );
105
132
}
0 commit comments