Skip to content

Commit 8142a52

Browse files
gtawdziurdz
authored andcommitted
[TritonIntelGPU] Normalize remat cost relative to convert data size
1 parent 56f1fd0 commit 8142a52

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,9 +1573,12 @@ void LayoutRematerialization::backwardRematerialization(
15731573
// it to account for extra cost due to synchronisation.
15741574
// FIXME: measure cost of smem load/store and synchronisation on Intel GPUs,
15751575
// and refine this model further. (#5476)
1576-
int64_t convertLayoutCost = 32 * convertLayoutBytes * 2;
1576+
int64_t convertLayoutCost = 32 * convertLayoutBytes;
15771577
int64_t rematerialisationCost = 0;
15781578

1579+
// Track total bytes in slice for normalization
1580+
int64_t totalSliceBytes = 0;
1581+
15791582
// Evaluate single-use status for every operation in slice
15801583
for (Operation *op : sliceOps) {
15811584
auto dialect = op->getDialect();
@@ -1589,7 +1592,9 @@ void LayoutRematerialization::backwardRematerialization(
15891592
} else if (isa<LoadOp>(op) || isa<LocalLoadOp>(op)) {
15901593
// optimistically assume L1-cached:
15911594
for (Value result : op->getResults()) {
1592-
rematerialisationCost += 8 * getByteCount(result);
1595+
int64_t bytes = getByteCount(result);
1596+
totalSliceBytes += bytes;
1597+
rematerialisationCost += 8 * bytes;
15931598
}
15941599
} else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
15951600
// this is an arithmetic operation; we distinguish between cheap
@@ -1599,7 +1604,9 @@ void LayoutRematerialization::backwardRematerialization(
15991604
// multiple instructions.
16001605
int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
16011606
for (Value result : op->getResults()) {
1602-
rematerialisationCost += multiplier * getByteCount(result);
1607+
int64_t bytes = getByteCount(result);
1608+
totalSliceBytes += bytes;
1609+
rematerialisationCost += multiplier * bytes;
16031610
}
16041611
} else if (isa<ReduceOp>(op)) {
16051612
// Reduce op introduce much cost.
@@ -1612,17 +1619,37 @@ void LayoutRematerialization::backwardRematerialization(
16121619
"slice");
16131620
return;
16141621
}
1622+
// Add input tensor bytes for reduce operations
1623+
for (Value operand : op->getOperands()) {
1624+
if (isa<RankedTensorType>(operand.getType())) {
1625+
totalSliceBytes += getByteCount(operand);
1626+
}
1627+
}
16151628
rematerialisationCost += helper.getIntraWarpSizeWithUniqueData();
16161629
rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData();
16171630
}
16181631
}
16191632

1633+
// Normalize rematerialization cost relative to convert layout data size.
1634+
// If slice processes similar amount of data as convert, costs should be
1635+
// comparable.
1636+
double normalizationFactor = 1.0;
1637+
if (totalSliceBytes > 0 && convertLayoutBytes > 0) {
1638+
normalizationFactor =
1639+
static_cast<double>(convertLayoutBytes) / totalSliceBytes;
1640+
// Clamp to avoid extreme values
1641+
normalizationFactor = std::clamp(normalizationFactor, 0.1, 10.0);
1642+
}
1643+
1644+
int64_t normalizedRematerialisationCost =
1645+
static_cast<int64_t>(rematerialisationCost * normalizationFactor);
1646+
16201647
LLVM_DEBUG({
16211648
DBGS() << " convert layout cost: " << convertLayoutCost << "\n";
1622-
DBGS() << " rematerialisation cost: " << rematerialisationCost << "\n";
1649+
DBGS() << " rematerialisation cost: " << normalizedRematerialisationCost
1650+
<< "\n";
16231651
});
1624-
1625-
if (rematerialisationCost > convertLayoutCost) {
1652+
if (normalizedRematerialisationCost > convertLayoutCost) {
16261653
LDBG(" skipped rematerialization due to higher cost");
16271654
return;
16281655
}

0 commit comments

Comments
 (0)