@@ -1184,13 +1184,39 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop,
11841184 matchOperands (firstPloop.getStep (), secondPloop.getStep ());
11851185}
11861186
1187+ // ===----------------------------------------------------------------------===//
1188+ // Fusion related helpers
1189+ // ===----------------------------------------------------------------------===//
1190+
11871191// / Verify there are no nested ParallelOps.
11881192static bool hasNestedParallelOp (scf::ParallelOp ploop) {
11891193 auto walkResult = ploop.getBody ()->walk (
11901194 [](scf::ParallelOp) { return WalkResult::interrupt (); });
11911195 return walkResult.wasInterrupted ();
11921196}
11931197
1198+ template <typename LoopTy>
1199+ static bool checkFusionStructuralLegality (Operation *target,
1200+ Operation *source) {
1201+ static_assert (llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
1202+ scf::ParallelOp>::value,
1203+ " applies to only `forall`, `for` and `parallel`" );
1204+ auto targetOp = dyn_cast<LoopTy>(target);
1205+ auto sourceOp = dyn_cast<LoopTy>(source);
1206+ if (!targetOp || !sourceOp)
1207+ return false ;
1208+
1209+ if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
1210+ return targetOp.getMixedLowerBound () == sourceOp.getMixedLowerBound () &&
1211+ targetOp.getMixedUpperBound () == sourceOp.getMixedUpperBound () &&
1212+ targetOp.getMixedStep () == sourceOp.getMixedStep () &&
1213+ targetOp.getMapping () == sourceOp.getMapping ();
1214+ else
1215+ return targetOp.getLowerBound () == sourceOp.getLowerBound () &&
1216+ targetOp.getUpperBound () == sourceOp.getUpperBound () &&
1217+ targetOp.getStep () == sourceOp.getStep ();
1218+ }
1219+
11941220static bool isFusionLegal (scf::ParallelOp firstPloop,
11951221 scf::ParallelOp secondPloop,
11961222 const IRMapping &firstToSecondPloopIndices,
0 commit comments