@@ -900,6 +900,11 @@ void MFMASmallGemmOpt::applyIGLPStrategy(
900
900
}
901
901
}
902
902
903
+ static unsigned DSWCount = 0 ; // DS instructions taking the mayStore() branch in applyIGLPStrategy; only incremented while HasDSWCounts is false.
904
+ static unsigned DSWWithPermCount = 0 ; // Assigned from DSWithPerms.size() in applyIGLPStrategy; asserted to be >= DSWWithSharedVMEMCount.
905
+ static unsigned DSWWithSharedVMEMCount = 0 ; // Incremented by 2 for each pair of DS_WRITEs found to depend on the same VMEM_READs.
906
+ static bool HasDSWCounts = false ; // Set once the counts above are computed; while true, applyIGLPStrategy skips recounting. NOTE(review): no visible code resets this flag or the counters, so values from the first invocation appear to persist -- confirm this is intended.
907
+
903
908
class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
904
909
private:
905
910
// Whether the DS_READ is a predecessor of the first four MFMAs in the region
@@ -1076,9 +1081,12 @@ class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1076
1081
Cache->push_back (Pred.getSUnit ());
1077
1082
}
1078
1083
}
1084
+
1085
+ // If the other group has no PERM preds, then this group won't share any with it.
1086
+ if (!Cache->size ())
1087
+ return false ;
1079
1088
}
1080
1089
1081
- assert (Cache->size ());
1082
1090
auto DAG = SyncPipe[0 ].DAG ;
1083
1091
// Does the previous DS_WRITE share a V_PERM predecessor with this
1084
1092
// VMEM_READ
@@ -1109,9 +1117,6 @@ void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1109
1117
DenseMap<int , SUnitsToCandidateSGsMap> &SyncedInstrs,
1110
1118
DenseMap<int , SmallVector<SchedGroup, 4 >> &SyncedSchedGroups) {
1111
1119
unsigned MFMACount = 0 ;
1112
- unsigned DSWCount = 0 ;
1113
- unsigned DSWWithPermCount = 0 ;
1114
- unsigned DSWWithSharedVMEMCount = 0 ;
1115
1120
unsigned DSRCount = 0 ;
1116
1121
SmallVector<SUnit *, 6 > DSWithPerms;
1117
1122
for (auto &SU : DAG->SUnits ) {
@@ -1121,7 +1126,7 @@ void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1121
1126
else if (TII->isDS (*I)) {
1122
1127
if (I->mayLoad ())
1123
1128
++DSRCount;
1124
- else if (I->mayStore ()) {
1129
+ else if (I->mayStore () && !HasDSWCounts ) {
1125
1130
++DSWCount;
1126
1131
for (auto Pred : SU.Preds ) {
1127
1132
if (Pred.getSUnit ()->getInstr ()->getOpcode () ==
@@ -1133,58 +1138,62 @@ void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1133
1138
}
1134
1139
}
1135
1140
}
1136
- DSWWithPermCount = DSWithPerms.size ();
1137
- auto I = DSWithPerms.begin ();
1138
- auto E = DSWithPerms.end ();
1139
-
1140
- // Get the count of DS_WRITES with V_PERM predecessors which
1141
- // have loop carried dependencies (WAR) on the same VMEM_READs.
1142
- // We consider partial overlap as a miss -- in other words,
1143
- // for a given DS_W, we only consider another DS_W as matching
1144
- // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1145
- // for every V_PERM pred of this DS_W.
1146
- DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1147
- SmallVector<SUnit *, 6 > Counted;
1148
- for (; I != E; I++) {
1149
- SUnit *Cand = nullptr ;
1150
- bool MissedAny = false ;
1151
- for (auto &Pred : (*I)->Preds ) {
1152
- if (Pred.getSUnit ()->getInstr ()->getOpcode () != AMDGPU::V_PERM_B32_e64)
1153
- continue ;
1154
1141
1155
- if (Cand && llvm::is_contained (Counted, Cand))
1156
- break ;
1157
-
1158
- for (auto &Succ : Pred.getSUnit ()->Succs ) {
1159
- auto MI = Succ.getSUnit ()->getInstr ();
1160
- if (!TII->isVMEM (*MI) || !MI->mayLoad ())
1142
+ if (!HasDSWCounts) {
1143
+ DSWWithPermCount = DSWithPerms.size ();
1144
+ auto I = DSWithPerms.begin ();
1145
+ auto E = DSWithPerms.end ();
1146
+
1147
+ // Get the count of DS_WRITES with V_PERM predecessors which
1148
+ // have loop carried dependencies (WAR) on the same VMEM_READs.
1149
+ // We consider partial overlap as a miss -- in other words,
1150
+ // for a given DS_W, we only consider another DS_W as matching
1151
+ // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1152
+ // for every V_PERM pred of this DS_W.
1153
+ DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1154
+ SmallVector<SUnit *, 6 > Counted;
1155
+ for (; I != E; I++) {
1156
+ SUnit *Cand = nullptr ;
1157
+ bool MissedAny = false ;
1158
+ for (auto &Pred : (*I)->Preds ) {
1159
+ if (Pred.getSUnit ()->getInstr ()->getOpcode () != AMDGPU::V_PERM_B32_e64)
1161
1160
continue ;
1162
1161
1163
- if (MissedAny || !VMEMLookup.size ()) {
1164
- MissedAny = true ;
1165
- VMEMLookup[MI] = *I;
1166
- continue ;
1167
- }
1162
+ if (Cand && llvm::is_contained (Counted, Cand))
1163
+ break ;
1168
1164
1169
- if (!VMEMLookup.contains (MI)) {
1170
- MissedAny = true ;
1171
- VMEMLookup[MI] = *I;
1172
- continue ;
1173
- }
1165
+ for (auto &Succ : Pred.getSUnit ()->Succs ) {
1166
+ auto MI = Succ.getSUnit ()->getInstr ();
1167
+ if (!TII->isVMEM (*MI) || !MI->mayLoad ())
1168
+ continue ;
1174
1169
1175
- Cand = VMEMLookup[MI];
1176
- if (llvm::is_contained (Counted, Cand)) {
1177
- MissedAny = true ;
1178
- break ;
1170
+ if (MissedAny || !VMEMLookup.size ()) {
1171
+ MissedAny = true ;
1172
+ VMEMLookup[MI] = *I;
1173
+ continue ;
1174
+ }
1175
+
1176
+ if (!VMEMLookup.contains (MI)) {
1177
+ MissedAny = true ;
1178
+ VMEMLookup[MI] = *I;
1179
+ continue ;
1180
+ }
1181
+
1182
+ Cand = VMEMLookup[MI];
1183
+ if (llvm::is_contained (Counted, Cand)) {
1184
+ MissedAny = true ;
1185
+ break ;
1186
+ }
1179
1187
}
1180
1188
}
1181
- }
1182
- if (!MissedAny && Cand) {
1183
- DSWWithSharedVMEMCount += 2 ;
1184
- Counted.push_back (Cand );
1185
- Counted. push_back (*I);
1189
+ if (!MissedAny && Cand) {
1190
+ DSWWithSharedVMEMCount += 2 ;
1191
+ Counted. push_back (Cand) ;
1192
+ Counted.push_back (*I );
1193
+ }
1186
1194
}
1187
1195
}
1196
+ HasDSWCounts = true ;
1188
1197
1189
1198
assert (DSWWithSharedVMEMCount <= DSWWithPermCount);
1190
1199
SchedGroup *SG;
0 commit comments