Skip to content

Commit b24ec52

Browse files
masahirawnhenry
andauthored
[Coalesce] Fix the default order to be row major (#5707)
Taking over triton-lang/triton#4914 due to an inactivity As discussed there, when there are multiple "contiguity of 1" in the `contiguity` array, doing argsort on it means that the resulting `order` becomes ascending for those elements. In the unit test, `order = [2, 1, 0]` becomes `[0, 1, 2]`, which is odd. This convention seems arbitrary, so it is better to pick the row-major ordering by default in such case to be consistent with the rest of code. The current convention is "correct", but we get an additional `convert_layout`. Moreover, this order is inherited to the SMEM allocated during SWP, which could be problematic for other ops. For example, in my case I was getting the order `[4, 0, 1, 2, 3]` in SMEM for 5D blocked scales because only the innermost axis had a contiguity 4 while the rest were 1. @ThomasRaoux @pawelszczerbuk @Jokeren @rawnhenry --------- Co-authored-by: Rawn Henry <[email protected]> Co-authored-by: Masahiro Masuda <[email protected]>
1 parent e854fcd commit b24ec52

File tree

5 files changed

+25
-8
lines changed

5 files changed

+25
-8
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
3333
// Return true if the Load uses block pointer.
3434
bool isLoadFromTensorPtr(triton::LoadOp op);
3535

36-
// Return an array of indices enumerating the elements of 'arr' in descending
37-
// order (so that result[i] is the index of the i-th largest element of 'arr')
38-
SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr);
36+
// Gets the order of a tensor from its contiguity. Places the dimensions with
37+
// the largest contiguity as the inner most dimension. If the contiguity is
38+
// all ones, returns the order {dim - 1, dim - 2, ..., 0}
39+
SmallVector<unsigned, 4>
40+
getOrderFromContiguity(const SmallVector<int64_t> &contiguity);
3941

4042
// Return the operand used to access the memory in the operation
4143
Value getMemAccessPtr(Operation *op);

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
3838
});
3939

4040
auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
41-
SmallVector<unsigned> order = argSort(contiguity);
41+
SmallVector<unsigned> order = getOrderFromContiguity(contiguity);
4242
LDBG("order=[" << triton::join(order, ", ") << "]");
4343

4444
auto matchesShape = [&refTensorType](const Value &val) {
@@ -55,8 +55,8 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
5555
Value val = getMemAccessPtr(use);
5656
if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
5757
continue;
58-
auto currOrder =
59-
argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
58+
auto currOrder = getOrderFromContiguity(
59+
axisInfoAnalysis.getAxisInfo(val)->getContiguity());
6060
if (order == currOrder) {
6161
LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
6262
memAccessesSameOrder.insert(use);

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) {
341341
int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
342342
tt::AxisInfo::DimVectorT contiguity =
343343
axisInfo.getAxisInfo(src)->getContiguity();
344-
SmallVector<unsigned> order = argSort(contiguity);
344+
SmallVector<unsigned> order = getOrderFromContiguity(contiguity);
345345
unsigned currPerThread = getNumElementsPerThread(loadOp, order, axisInfo);
346346
SmallVector<unsigned> sizePerThread(order.size(), 1);
347347
sizePerThread[order[0]] = currPerThread;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,11 @@ bool isLoadFromTensorPtr(triton::LoadOp op) {
8989
return mlir::triton::isTensorPointerType(op.getPtr().getType());
9090
}
9191

92-
SmallVector<unsigned, 4> argSort(const SmallVector<int64_t> &arr) {
92+
SmallVector<unsigned, 4>
93+
getOrderFromContiguity(const SmallVector<int64_t> &arr) {
9394
SmallVector<unsigned, 4> ret(arr.size());
9495
std::iota(ret.begin(), ret.end(), 0);
96+
std::reverse(ret.begin(), ret.end());
9597
std::stable_sort(ret.begin(), ret.end(),
9698
[&](unsigned x, unsigned y) { return arr[x] > arr[y]; });
9799
return ret;

test/TritonGPU/coalesce.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,16 @@ module {
160160
tt.return
161161
}
162162
}
163+
164+
// -----
165+
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
166+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
167+
tt.func public @load_3D_contig_1(%arg: !tt.ptr<i8> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
168+
%50 = tt.splat %arg : !tt.ptr<i8> -> tensor<32x4x4x!tt.ptr<i8>, #blocked>
169+
// This checks that the pass picks the row-major ordering by default for elements with contiguity 1.
170+
// CHECK: #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
171+
// CHECK: tt.load %1 : tensor<32x4x4x!tt.ptr<i8>, #blocked>
172+
%108 = tt.load %50 : tensor<32x4x4x!tt.ptr<i8>, #blocked>
173+
tt.return
174+
}
175+
}

0 commit comments

Comments
 (0)