@@ -1573,9 +1573,12 @@ void LayoutRematerialization::backwardRematerialization(
   // it to account for extra cost due to synchronisation.
   // FIXME: measure cost of smem load/store and synchronisation on Intel GPUs,
   // and refine this model further. (#5476)
-  int64_t convertLayoutCost = 32 * convertLayoutBytes * 2;
+  int64_t convertLayoutCost = 32 * convertLayoutBytes;
   int64_t rematerialisationCost = 0;
 
+  // Track total bytes in slice for normalization
+  int64_t totalSliceBytes = 0;
+
   // Evaluate single-use status for every operation in slice
   for (Operation *op : sliceOps) {
     auto dialect = op->getDialect();
@@ -1589,7 +1592,9 @@ void LayoutRematerialization::backwardRematerialization(
     } else if (isa<LoadOp>(op) || isa<LocalLoadOp>(op)) {
       // optimistically assume L1-cached:
       for (Value result : op->getResults()) {
-        rematerialisationCost += 8 * getByteCount(result);
+        int64_t bytes = getByteCount(result);
+        totalSliceBytes += bytes;
+        rematerialisationCost += 8 * bytes;
       }
     } else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
       // this is an arithmetic operation; we distinguish between cheap
@@ -1599,7 +1604,9 @@ void LayoutRematerialization::backwardRematerialization(
       // multiple instructions.
       int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
       for (Value result : op->getResults()) {
-        rematerialisationCost += multiplier * getByteCount(result);
+        int64_t bytes = getByteCount(result);
+        totalSliceBytes += bytes;
+        rematerialisationCost += multiplier * bytes;
       }
     } else if (isa<ReduceOp>(op)) {
       // Reduce ops introduce significant cost.
@@ -1612,17 +1619,37 @@ void LayoutRematerialization::backwardRematerialization(
              "slice");
         return;
       }
+      // Add input tensor bytes for reduce operations
+      for (Value operand : op->getOperands()) {
+        if (isa<RankedTensorType>(operand.getType())) {
+          totalSliceBytes += getByteCount(operand);
+        }
+      }
       rematerialisationCost += helper.getIntraWarpSizeWithUniqueData();
       rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData();
     }
   }
 
+  // Normalize the rematerialisation cost relative to the amount of data the
+  // layout conversion moves: if the slice processes a similar amount of data
+  // to the convert, the two costs should be directly comparable.
+  double normalizationFactor = 1.0;
+  if (totalSliceBytes > 0 && convertLayoutBytes > 0) {
+    normalizationFactor =
+        static_cast<double>(convertLayoutBytes) / totalSliceBytes;
+    // Clamp to avoid extreme values
+    normalizationFactor = std::clamp(normalizationFactor, 0.1, 10.0);
+  }
+
+  int64_t normalizedRematerialisationCost =
+      static_cast<int64_t>(rematerialisationCost * normalizationFactor);
+
   LLVM_DEBUG({
     DBGS() << "convert layout cost: " << convertLayoutCost << "\n";
-    DBGS() << "rematerialisation cost: " << rematerialisationCost << "\n";
+    DBGS() << "rematerialisation cost: " << normalizedRematerialisationCost
+           << "\n";
   });
-
-  if (rematerialisationCost > convertLayoutCost) {
+  if (normalizedRematerialisationCost > convertLayoutCost) {
     LDBG("skipped rematerialization due to higher cost");
     return;
   }
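
For reference, below is a minimal standalone sketch of the normalized cost comparison this patch introduces, covering just the arithmetic rather than the actual pass. The 32x convert-byte weight, the 8x load multiplier, and the 0.1 to 10.0 clamp are taken from the diff above; the helper function name and the sample byte counts are hypothetical, purely for illustration.

// Standalone sketch of the normalized cost heuristic; not the real pass.
#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring the comparison in the diff: returns true
// when the (normalized) cost of recomputing the slice exceeds the cost of
// the layout conversion it would eliminate.
bool shouldSkipRematerialization(int64_t convertLayoutBytes,
                                 int64_t rematerialisationCost,
                                 int64_t totalSliceBytes) {
  // Cost of the layout conversion (weight taken from the diff).
  int64_t convertLayoutCost = 32 * convertLayoutBytes;

  // Scale the slice cost by how much data it touches relative to the
  // convert, clamped so one extreme slice cannot dominate the decision.
  double normalizationFactor = 1.0;
  if (totalSliceBytes > 0 && convertLayoutBytes > 0) {
    normalizationFactor =
        static_cast<double>(convertLayoutBytes) / totalSliceBytes;
    normalizationFactor = std::clamp(normalizationFactor, 0.1, 10.0);
  }
  int64_t normalizedCost =
      static_cast<int64_t>(rematerialisationCost * normalizationFactor);
  return normalizedCost > convertLayoutCost;
}

int main() {
  // Example: a convert moving 4 KiB vs. a slice of loads touching 64 KiB.
  // The raw slice cost 8 * 65536 = 524288 dwarfs the convert cost
  // 32 * 4096 = 131072, but after normalization (4096/65536 = 0.0625,
  // clamped to 0.1) it becomes 52428, so rematerialization proceeds.
  std::cout << shouldSkipRematerialization(4096, 8 * 65536, 65536) << "\n"; // 0
  // Same raw cost on a 4 KiB slice keeps factor 1.0, so it is skipped.
  std::cout << shouldSkipRematerialization(4096, 8 * 65536, 4096) << "\n";  // 1
}

The clamp bounds the correction to one order of magnitude in either direction, so a pathologically large or small slice cannot flip the decision on its own.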