diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 448962be09..4f7af698fc 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -1573,9 +1573,12 @@ void LayoutRematerialization::backwardRematerialization( // it to account for extra cost due to synchronisation. // FIXME: measure cost of smem load/store and synchronisation on Intel GPUs, // and refine this model further. (#5476) - int64_t convertLayoutCost = 32 * convertLayoutBytes * 2; + int64_t convertLayoutCost = 32 * convertLayoutBytes; int64_t rematerialisationCost = 0; + // Track total bytes in slice for normalization + int64_t totalSliceBytes = 0; + // Evaluate single-use status for every operation in slice for (Operation *op : sliceOps) { auto dialect = op->getDialect(); @@ -1589,7 +1592,9 @@ void LayoutRematerialization::backwardRematerialization( } else if (isa(op) || isa(op)) { // optimistically assume L1-cached: for (Value result : op->getResults()) { - rematerialisationCost += 8 * getByteCount(result); + int64_t bytes = getByteCount(result); + totalSliceBytes += bytes; + rematerialisationCost += 8 * bytes; } } else if (isa(dialect)) { // this is an arithmetic operation; we distinguish between cheap @@ -1599,7 +1604,9 @@ void LayoutRematerialization::backwardRematerialization( // multiple instructions. int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1; for (Value result : op->getResults()) { - rematerialisationCost += multiplier * getByteCount(result); + int64_t bytes = getByteCount(result); + totalSliceBytes += bytes; + rematerialisationCost += multiplier * bytes; } } else if (isa(op)) { // Reduce op introduce much cost. @@ -1612,17 +1619,37 @@ void LayoutRematerialization::backwardRematerialization( "slice"); return; } + // Add input tensor bytes for reduce operations + for (Value operand : op->getOperands()) { + if (isa(operand.getType())) { + totalSliceBytes += getByteCount(operand); + } + } rematerialisationCost += helper.getIntraWarpSizeWithUniqueData(); rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData(); } } + // Normalize rematerialization cost relative to convert layout data size. + // If slice processes similar amount of data as convert, costs should be + // comparable. + double normalizationFactor = 1.0; + if (totalSliceBytes > 0 && convertLayoutBytes > 0) { + normalizationFactor = + static_cast(convertLayoutBytes) / totalSliceBytes; + // Clamp to avoid extreme values + normalizationFactor = std::clamp(normalizationFactor, 0.1, 10.0); + } + + int64_t normalizedRematerialisationCost = + static_cast(rematerialisationCost * normalizationFactor); + LLVM_DEBUG({ DBGS() << " convert layout cost: " << convertLayoutCost << "\n"; - DBGS() << " rematerialisation cost: " << rematerialisationCost << "\n"; + DBGS() << " rematerialisation cost: " << normalizedRematerialisationCost + << "\n"; }); - - if (rematerialisationCost > convertLayoutCost) { + if (normalizedRematerialisationCost > convertLayoutCost) { LDBG(" skipped rematerialization due to higher cost"); return; }