@@ -261,10 +261,8 @@ loopScheduling(scf::ForOp forOp,
261261 return 1 ;
262262 };
263263
264- std::optional<int64_t > ubConstant =
265- getConstantIntValue (forOp.getUpperBound ());
266- std::optional<int64_t > lbConstant =
267- getConstantIntValue (forOp.getLowerBound ());
264+ std::optional<int64_t > ubConstant = getConstantIntValue (forOp.getUpperBound ());
265+ std::optional<int64_t > lbConstant = getConstantIntValue (forOp.getLowerBound ());
268266 DenseMap<Operation *, unsigned > opCycles;
269267 std::map<unsigned , std::vector<Operation *>> wrappedSchedule;
270268 for (Operation &op : forOp.getBody ()->getOperations ()) {
@@ -449,6 +447,113 @@ void transform::TakeAssumedBranchOp::getEffects(
449447// LoopFuseSiblingOp
450448// ===----------------------------------------------------------------------===//
451449
450+ // / Check if `target` and `source` are siblings, in the context that `target`
451+ // / is being fused into `source`.
452+ // /
453+ // / This is a simple check that just checks if both operations are in the same
454+ // / block and some checks to ensure that the fused IR does not violate
455+ // / dominance.
456+ static DiagnosedSilenceableFailure isOpSibling (Operation *target,
457+ Operation *source) {
458+ // Check if both operations are same.
459+ if (target == source)
460+ return emitSilenceableFailure (source)
461+ << " target and source need to be different loops" ;
462+
463+ // Check if both operations are in the same block.
464+ if (target->getBlock () != source->getBlock ())
465+ return emitSilenceableFailure (source)
466+ << " target and source are not in the same block" ;
467+
468+ // Check if fusion will violate dominance.
469+ DominanceInfo domInfo (source);
470+ if (target->isBeforeInBlock (source)) {
471+ // Since `target` is before `source`, all users of results of `target`
472+ // need to be dominated by `source`.
473+ for (Operation *user : target->getUsers ()) {
474+ if (!domInfo.properlyDominates (source, user, /* enclosingOpOk=*/ false )) {
475+ return emitSilenceableFailure (target)
476+ << " user of results of target should be properly dominated by "
477+ " source" ;
478+ }
479+ }
480+ } else {
481+ // Since `target` is after `source`, all values used by `target` need
482+ // to dominate `source`.
483+
484+ // Check if operands of `target` are dominated by `source`.
485+ for (Value operand : target->getOperands ()) {
486+ Operation *operandOp = operand.getDefiningOp ();
487+ // Operands without defining operations are block arguments. When `target`
488+ // and `source` occur in the same block, these operands dominate `source`.
489+ if (!operandOp)
490+ continue ;
491+
492+ // Operand's defining operation should properly dominate `source`.
493+ if (!domInfo.properlyDominates (operandOp, source,
494+ /* enclosingOpOk=*/ false ))
495+ return emitSilenceableFailure (target)
496+ << " operands of target should be properly dominated by source" ;
497+ }
498+
499+ // Check if values used by `target` are dominated by `source`.
500+ bool failed = false ;
501+ OpOperand *failedValue = nullptr ;
502+ visitUsedValuesDefinedAbove (target->getRegions (), [&](OpOperand *operand) {
503+ Operation *operandOp = operand->get ().getDefiningOp ();
504+ if (operandOp && !domInfo.properlyDominates (operandOp, source,
505+ /* enclosingOpOk=*/ false )) {
506+ // `operand` is not an argument of an enclosing block and the defining
507+ // op of `operand` is outside `target` but does not dominate `source`.
508+ failed = true ;
509+ failedValue = operand;
510+ }
511+ });
512+
513+ if (failed)
514+ return emitSilenceableFailure (failedValue->getOwner ())
515+ << " values used inside regions of target should be properly "
516+ " dominated by source" ;
517+ }
518+
519+ return DiagnosedSilenceableFailure::success ();
520+ }
521+
522+ // / Check if `target` scf.forall can be fused into `source` scf.forall.
523+ // /
524+ // / This simply checks if both loops have the same bounds, steps and mapping.
525+ // / No attempt is made at checking that the side effects of `target` and
526+ // / `source` are independent of each other.
527+ static bool isForallWithIdenticalConfiguration (Operation *target,
528+ Operation *source) {
529+ auto targetOp = dyn_cast<scf::ForallOp>(target);
530+ auto sourceOp = dyn_cast<scf::ForallOp>(source);
531+ if (!targetOp || !sourceOp)
532+ return false ;
533+
534+ return targetOp.getMixedLowerBound () == sourceOp.getMixedLowerBound () &&
535+ targetOp.getMixedUpperBound () == sourceOp.getMixedUpperBound () &&
536+ targetOp.getMixedStep () == sourceOp.getMixedStep () &&
537+ targetOp.getMapping () == sourceOp.getMapping ();
538+ }
539+
540+ // / Check if `target` scf.for can be fused into `source` scf.for.
541+ // /
542+ // / This simply checks if both loops have the same bounds and steps. No attempt
543+ // / is made at checking that the side effects of `target` and `source` are
544+ // / independent of each other.
545+ static bool isForWithIdenticalConfiguration (Operation *target,
546+ Operation *source) {
547+ auto targetOp = dyn_cast<scf::ForOp>(target);
548+ auto sourceOp = dyn_cast<scf::ForOp>(source);
549+ if (!targetOp || !sourceOp)
550+ return false ;
551+
552+ return targetOp.getLowerBound () == sourceOp.getLowerBound () &&
553+ targetOp.getUpperBound () == sourceOp.getUpperBound () &&
554+ targetOp.getStep () == sourceOp.getStep ();
555+ }
556+
452557DiagnosedSilenceableFailure
453558transform::LoopFuseSiblingOp::apply (transform::TransformRewriter &rewriter,
454559 transform::TransformResults &results,
@@ -464,32 +569,25 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
464569 << " source handle (got " << llvm::range_size (sourceOps) << " )" ;
465570 }
466571
467- auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin ());
468- auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin ());
469- if (!target || !source)
470- return emitSilenceableFailure (target->getLoc ())
471- << " target or source is not a loop op" ;
572+ Operation *target = *targetOps.begin ();
573+ Operation *source = *sourceOps.begin ();
472574
473- // Check if loops can be fused
474- Diagnostic diag (target. getLoc (), DiagnosticSeverity::Error );
475- if (!mlir::checkFusionStructuralLegality (target, source, diag))
476- return DiagnosedSilenceableFailure::silenceableFailure ( std::move ( diag)) ;
575+ // Check if the target and source are siblings.
576+ DiagnosedSilenceableFailure diag = isOpSibling (target, source );
577+ if (!diag. succeeded ( ))
578+ return diag;
477579
478580 Operation *fusedLoop;
479- // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
480- // and scf.parallel.
481- if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
581+ // / TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
582+ if (isForWithIdenticalConfiguration (target, source)) {
482583 fusedLoop = fuseIndependentSiblingForLoops (
483584 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
484- } else if (isa<scf::ForallOp> (target) && isa<scf::ForallOp>( source)) {
585+ } else if (isForallWithIdenticalConfiguration (target, source)) {
485586 fusedLoop = fuseIndependentSiblingForallLoops (
486587 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
487- } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
488- fusedLoop = fuseIndependentSiblingParallelLoops (
489- cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
490588 } else
491589 return emitSilenceableFailure (target->getLoc ())
492- << " unsupported loop type for fusion " ;
590+ << " operations cannot be fused " ;
493591
494592 assert (fusedLoop && " failed to fuse operations" );
495593
0 commit comments