Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1573,9 +1573,12 @@ void LayoutRematerialization::backwardRematerialization(
// it to account for extra cost due to synchronisation.
// FIXME: measure cost of smem load/store and synchronisation on Intel GPUs,
// and refine this model further. (#5476)
int64_t convertLayoutCost = 32 * convertLayoutBytes * 2;
int64_t convertLayoutCost = 32 * convertLayoutBytes;
int64_t rematerialisationCost = 0;

// Track total bytes in slice for normalization
int64_t totalSliceBytes = 0;

// Evaluate single-use status for every operation in slice
for (Operation *op : sliceOps) {
auto dialect = op->getDialect();
Expand All @@ -1589,7 +1592,9 @@ void LayoutRematerialization::backwardRematerialization(
} else if (isa<LoadOp>(op) || isa<LocalLoadOp>(op)) {
// optimistically assume L1-cached:
for (Value result : op->getResults()) {
rematerialisationCost += 8 * getByteCount(result);
int64_t bytes = getByteCount(result);
totalSliceBytes += bytes;
rematerialisationCost += 8 * bytes;
}
} else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
// this is an arithmetic operation; we distinguish between cheap
Expand All @@ -1599,7 +1604,9 @@ void LayoutRematerialization::backwardRematerialization(
// multiple instructions.
int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
for (Value result : op->getResults()) {
rematerialisationCost += multiplier * getByteCount(result);
int64_t bytes = getByteCount(result);
totalSliceBytes += bytes;
rematerialisationCost += multiplier * bytes;
}
} else if (isa<ReduceOp>(op)) {
// Reduce op introduce much cost.
Expand All @@ -1612,17 +1619,37 @@ void LayoutRematerialization::backwardRematerialization(
"slice");
return;
}
// Add input tensor bytes for reduce operations
for (Value operand : op->getOperands()) {
if (isa<RankedTensorType>(operand.getType())) {
totalSliceBytes += getByteCount(operand);
}
}
rematerialisationCost += helper.getIntraWarpSizeWithUniqueData();
rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData();
}
}

// Normalize rematerialization cost relative to convert layout data size.
// If slice processes similar amount of data as convert, costs should be
// comparable.
double normalizationFactor = 1.0;
if (totalSliceBytes > 0 && convertLayoutBytes > 0) {
normalizationFactor =
static_cast<double>(convertLayoutBytes) / totalSliceBytes;
// Clamp to avoid extreme values
normalizationFactor = std::clamp(normalizationFactor, 0.1, 10.0);
}

int64_t normalizedRematerialisationCost =
static_cast<int64_t>(rematerialisationCost * normalizationFactor);

LLVM_DEBUG({
DBGS() << " convert layout cost: " << convertLayoutCost << "\n";
DBGS() << " rematerialisation cost: " << rematerialisationCost << "\n";
DBGS() << " rematerialisation cost: " << normalizedRematerialisationCost
<< "\n";
});

if (rematerialisationCost > convertLayoutCost) {
if (normalizedRematerialisationCost > convertLayoutCost) {
LDBG(" skipped rematerialization due to higher cost");
return;
}
Expand Down
Loading