@@ -442,34 +442,6 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
442442 return DiagnosedSilenceableFailure::success ();
443443}
444444
445- // / Check if `target` scf loop can be fused into `source` scf loop.
446- // / Applies for scf.for, scf.forall, and scf.parallel.
447- // /
448- // / This simply checks if both loops have the same bounds, steps and mapping.
449- // / No attempt is made at checking that the side effects of `target` and
450- // / `source` are independent of each other.
451- template <typename LoopTy>
452- static bool isLoopWithIdenticalConfiguration (Operation *target,
453- Operation *source) {
454- static_assert (llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
455- scf::ParallelOp>::value,
456- " applies to only `forall`, `for` and `parallel`" );
457- auto targetOp = dyn_cast<LoopTy>(target);
458- auto sourceOp = dyn_cast<LoopTy>(source);
459- if (!targetOp || !sourceOp)
460- return false ;
461-
462- if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
463- return targetOp.getMixedLowerBound () == sourceOp.getMixedLowerBound () &&
464- targetOp.getMixedUpperBound () == sourceOp.getMixedUpperBound () &&
465- targetOp.getMixedStep () == sourceOp.getMixedStep () &&
466- targetOp.getMapping () == sourceOp.getMapping ();
467- else
468- return targetOp.getLowerBound () == sourceOp.getLowerBound () &&
469- targetOp.getUpperBound () == sourceOp.getUpperBound () &&
470- targetOp.getStep () == sourceOp.getStep ();
471- }
472-
473445DiagnosedSilenceableFailure
474446transform::LoopFuseSiblingOp::apply (transform::TransformRewriter &rewriter,
475447 transform::TransformResults &results,
@@ -485,29 +457,37 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
485457 << " source handle (got " << llvm::range_size (sourceOps) << " )" ;
486458 }
487459
488- Operation *target = *targetOps.begin ();
489- Operation *source = *sourceOps.begin ();
460+ LoopLikeOpInterface target =
461+ dyn_cast<LoopLikeOpInterface>(*targetOps.begin ());
462+ LoopLikeOpInterface source =
463+ dyn_cast<LoopLikeOpInterface>(*sourceOps.begin ());
464+ if (!target || !source)
465+ return emitSilenceableFailure (target->getLoc ())
466+ << " target or source is not a loop op" ;
490467
491468 // Check if the target and source are siblings.
492469 DiagnosedSilenceableFailure diag = isOpSibling (target, source);
493470 if (!diag.succeeded ())
494471 return diag;
495472
473+ if (!mlir::checkFusionStructuralLegality (target, source))
474+ return emitSilenceableFailure (target->getLoc ())
475+ << " operations cannot be fused" ;
476+
496477 Operation *fusedLoop;
497478 // / TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
498- if (isLoopWithIdenticalConfiguration <scf::ForOp>(target, source)) {
479+ if (isa <scf::ForOp>(target) && isa<scf::ForOp>( source)) {
499480 fusedLoop = fuseIndependentSiblingForLoops (
500481 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
501- } else if (isLoopWithIdenticalConfiguration <scf::ForallOp>(target, source)) {
482+ } else if (isa <scf::ForallOp>(target) && isa<scf::ForallOp>( source)) {
502483 fusedLoop = fuseIndependentSiblingForallLoops (
503484 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
504- } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
505- source)) {
485+ } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
506486 fusedLoop = fuseIndependentSiblingParallelLoops (
507487 cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
508488 } else
509489 return emitSilenceableFailure (target->getLoc ())
510- << " operations cannot be fused " ;
490+ << " unsupported loop type for fusion " ;
511491
512492 assert (fusedLoop && " failed to fuse operations" );
513493
0 commit comments