Skip to content

Commit b00e0c1

Browse files
authored
[MLIR][Analysis] Consolidate topological sort utilities (#92563)
This PR attempts to consolidate the different topological sort utilities into one place. It adds them to the analysis folder because the `SliceAnalysis` uses some of these. There are now two different sorting strategies: 1. Sort only according to SSA use-def chains 2. Sort while taking regions into account. This requires a much more elaborate traversal and cannot be applied on graph regions that easily. This additionally reimplements the region aware topological sorting because the previous implementation had an exponential space complexity. I'm open to suggestions on how to combine this further or how to fuse the test passes.
1 parent 874a5da commit b00e0c1

26 files changed

+298
-213
lines changed

mlir/include/mlir/Analysis/SliceAnalysis.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,6 @@ SetVector<Operation *>
223223
getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
224224
const ForwardSliceOptions &forwardSliceOptions = {});
225225

226-
/// Multi-root DAG topological sort.
227-
/// Performs a topological sort of the Operation in the `toSort` SetVector.
228-
/// Returns a topologically sorted SetVector.
229-
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
230-
231226
/// Utility to match a generic reduction given a list of iteration-carried
232227
/// arguments, `iterCarriedArgs` and the position of the potential reduction
233228
/// argument within the list, `redPos`. If a reduction is matched, returns the

mlir/include/mlir/Transforms/TopologicalSortUtils.h renamed to mlir/include/mlir/Analysis/TopologicalSortUtils.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
10-
#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
9+
#ifndef MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
10+
#define MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
1111

1212
#include "mlir/IR/Block.h"
1313

@@ -104,6 +104,14 @@ bool computeTopologicalSorting(
104104
MutableArrayRef<Operation *> ops,
105105
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
106106

107+
/// Gets a list of blocks that is sorted according to dominance. This sort is
108+
/// stable.
109+
SetVector<Block *> getBlocksSortedByDominance(Region &region);
110+
111+
/// Sorts all operations in `toSort` topologically while also considering region
112+
/// semantics. Does not support multi-sets.
113+
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
114+
107115
} // end namespace mlir
108116

109-
#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
117+
#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H

mlir/include/mlir/Transforms/RegionUtils.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
8787
LogicalResult runRegionDCE(RewriterBase &rewriter,
8888
MutableArrayRef<Region> regions);
8989

90-
/// Get a list of blocks that is sorted according to dominance. This sort is
91-
/// stable.
92-
SetVector<Block *> getBlocksSortedByDominance(Region &region);
93-
9490
} // namespace mlir
9591

9692
#endif // MLIR_TRANSFORMS_REGIONUTILS_H_

mlir/lib/Analysis/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
66
Liveness.cpp
77
CFGLoopInfo.cpp
88
SliceAnalysis.cpp
9+
TopologicalSortUtils.cpp
910

1011
AliasAnalysis/LocalAliasAnalysis.cpp
1112

@@ -28,6 +29,7 @@ add_mlir_library(MLIRAnalysis
2829
Liveness.cpp
2930
CFGLoopInfo.cpp
3031
SliceAnalysis.cpp
32+
TopologicalSortUtils.cpp
3133

3234
AliasAnalysis/LocalAliasAnalysis.cpp
3335

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Analysis/SliceAnalysis.h"
14-
#include "mlir/IR/BuiltinOps.h"
14+
#include "mlir/Analysis/TopologicalSortUtils.h"
15+
#include "mlir/IR/Block.h"
1516
#include "mlir/IR/Operation.h"
1617
#include "mlir/Interfaces/SideEffectInterfaces.h"
1718
#include "mlir/Support/LLVM.h"
@@ -164,62 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
164165
return topologicalSort(slice);
165166
}
166167

167-
namespace {
168-
/// DFS post-order implementation that maintains a global count to work across
169-
/// multiple invocations, to help implement topological sort on multi-root DAGs.
170-
/// We traverse all operations but only record the ones that appear in
171-
/// `toSort` for the final result.
172-
struct DFSState {
173-
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
174-
const SetVector<Operation *> &toSort;
175-
SmallVector<Operation *, 16> topologicalCounts;
176-
DenseSet<Operation *> seen;
177-
};
178-
} // namespace
179-
180-
static void dfsPostorder(Operation *root, DFSState *state) {
181-
SmallVector<Operation *> queue(1, root);
182-
std::vector<Operation *> ops;
183-
while (!queue.empty()) {
184-
Operation *current = queue.pop_back_val();
185-
ops.push_back(current);
186-
for (Operation *op : current->getUsers())
187-
queue.push_back(op);
188-
for (Region &region : current->getRegions()) {
189-
for (Operation &op : region.getOps())
190-
queue.push_back(&op);
191-
}
192-
}
193-
194-
for (Operation *op : llvm::reverse(ops)) {
195-
if (state->seen.insert(op).second && state->toSort.count(op) > 0)
196-
state->topologicalCounts.push_back(op);
197-
}
198-
}
199-
200-
SetVector<Operation *>
201-
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
202-
if (toSort.empty()) {
203-
return toSort;
204-
}
205-
206-
// Run from each root with global count and `seen` set.
207-
DFSState state(toSort);
208-
for (auto *s : toSort) {
209-
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
210-
dfsPostorder(s, &state);
211-
}
212-
213-
// Reorder and return.
214-
SetVector<Operation *> res;
215-
for (auto it = state.topologicalCounts.rbegin(),
216-
eit = state.topologicalCounts.rend();
217-
it != eit; ++it) {
218-
res.insert(*it);
219-
}
220-
return res;
221-
}
222-
223168
/// Returns true if `value` (transitively) depends on iteration-carried values
224169
/// of the given `ancestorOp`.
225170
static bool dependsOnCarriedVals(Value value,

mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp renamed to mlir/lib/Analysis/TopologicalSortUtils.cpp

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===//
1+
//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Transforms/TopologicalSortUtils.h"
9+
#include "mlir/Analysis/TopologicalSortUtils.h"
10+
#include "mlir/IR/Block.h"
1011
#include "mlir/IR/OpDefinition.h"
12+
#include "mlir/IR/RegionGraphTraits.h"
13+
14+
#include "llvm/ADT/PostOrderIterator.h"
15+
#include "llvm/ADT/SetVector.h"
1116

1217
using namespace mlir;
1318

@@ -146,3 +151,135 @@ bool mlir::computeTopologicalSorting(
146151

147152
return allOpsScheduled;
148153
}
154+
155+
SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
156+
// For each block that has not been visited yet (i.e. that has no
157+
// predecessors), add it to the list as well as its successors.
158+
SetVector<Block *> blocks;
159+
for (Block &b : region) {
160+
if (blocks.count(&b) == 0) {
161+
llvm::ReversePostOrderTraversal<Block *> traversal(&b);
162+
blocks.insert(traversal.begin(), traversal.end());
163+
}
164+
}
165+
assert(blocks.size() == region.getBlocks().size() &&
166+
"some blocks are not sorted");
167+
168+
return blocks;
169+
}
170+
171+
namespace {
172+
class TopoSortHelper {
173+
public:
174+
explicit TopoSortHelper(const SetVector<Operation *> &toSort)
175+
: toSort(toSort) {}
176+
177+
/// Executes the topological sort of the operations this instance was
178+
/// constructed with. This function will destroy the internal state of the
179+
/// instance.
180+
SetVector<Operation *> sort() {
181+
if (toSort.size() <= 1) {
182+
// Note: Creates a copy on purpose.
183+
return toSort;
184+
}
185+
186+
// First, find the root region to start the traversal through the IR. This
187+
// additionally enriches the internal caches with all relevant ancestor
188+
// regions and blocks.
189+
Region *rootRegion = findCommonAncestorRegion();
190+
assert(rootRegion && "expected all ops to have a common ancestor");
191+
192+
// Sort all elements in `toSort` by traversing the IR in the appropriate
193+
// order.
194+
SetVector<Operation *> result = topoSortRegion(*rootRegion);
195+
assert(result.size() == toSort.size() &&
196+
"expected all operations to be present in the result");
197+
return result;
198+
}
199+
200+
private:
201+
/// Computes the closest common ancestor region of all operations in `toSort`.
202+
Region *findCommonAncestorRegion() {
203+
// Map to count the number of times a region was encountered.
204+
DenseMap<Region *, size_t> regionCounts;
205+
size_t expectedCount = toSort.size();
206+
207+
// Walk the region tree for each operation towards the root and add to the
208+
// region count.
209+
Region *res = nullptr;
210+
for (Operation *op : toSort) {
211+
Region *current = op->getParentRegion();
212+
// Store the block as an ancestor block.
213+
ancestorBlocks.insert(op->getBlock());
214+
while (current) {
215+
// Insert or update the count and compare it.
216+
if (++regionCounts[current] == expectedCount) {
217+
res = current;
218+
break;
219+
}
220+
ancestorBlocks.insert(current->getParentOp()->getBlock());
221+
current = current->getParentRegion();
222+
}
223+
}
224+
auto firstRange = llvm::make_first_range(regionCounts);
225+
ancestorRegions.insert(firstRange.begin(), firstRange.end());
226+
return res;
227+
}
228+
229+
/// Performs the dominance respecting IR walk to collect the topological order
230+
/// of the operation to sort.
231+
SetVector<Operation *> topoSortRegion(Region &rootRegion) {
232+
using StackT = PointerUnion<Region *, Block *, Operation *>;
233+
234+
SetVector<Operation *> result;
235+
// Stack that stores the different IR constructs to traverse.
236+
SmallVector<StackT> stack;
237+
stack.push_back(&rootRegion);
238+
239+
// Traverse the IR in a dominance respecting pre-order walk.
240+
while (!stack.empty()) {
241+
StackT current = stack.pop_back_val();
242+
if (auto *region = dyn_cast<Region *>(current)) {
243+
// A region's blocks need to be traversed in dominance order.
244+
SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
245+
for (Block *block : llvm::reverse(sortedBlocks)) {
246+
// Only add blocks to the stack that are ancestors of the operations
247+
// to sort.
248+
if (ancestorBlocks.contains(block))
249+
stack.push_back(block);
250+
}
251+
continue;
252+
}
253+
254+
if (auto *block = dyn_cast<Block *>(current)) {
255+
// Add all of the blocks operations to the stack.
256+
for (Operation &op : llvm::reverse(*block))
257+
stack.push_back(&op);
258+
continue;
259+
}
260+
261+
auto *op = cast<Operation *>(current);
262+
if (toSort.contains(op))
263+
result.insert(op);
264+
265+
// Add all the subregions that are ancestors of the operations to sort.
266+
for (Region &subRegion : op->getRegions())
267+
if (ancestorRegions.contains(&subRegion))
268+
stack.push_back(&subRegion);
269+
}
270+
return result;
271+
}
272+
273+
/// Operations to sort.
274+
const SetVector<Operation *> &toSort;
275+
/// Set containing all the ancestor regions of the operations to sort.
276+
DenseSet<Region *> ancestorRegions;
277+
/// Set containing all the ancestor blocks of the operations to sort.
278+
DenseSet<Block *> ancestorBlocks;
279+
};
280+
} // namespace
281+
282+
SetVector<Operation *>
283+
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
284+
return TopoSortHelper(toSort).sort();
285+
}

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <type_traits>
1616

1717
#include "mlir/Analysis/SliceAnalysis.h"
18+
#include "mlir/Analysis/TopologicalSortUtils.h"
1819
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1920
#include "mlir/Dialect/Arith/IR/Arith.h"
2021
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
1414
#include "mlir/Analysis/SliceAnalysis.h"
15+
#include "mlir/Analysis/TopologicalSortUtils.h"
1516
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1617
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1718
#include "mlir/Dialect/Affine/Analysis/Utils.h"

mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
//===----------------------------------------------------------------------===//
4747

4848
#include "mlir/Analysis/Liveness.h"
49+
#include "mlir/Analysis/TopologicalSortUtils.h"
4950
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
5051
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
5152
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"

mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15+
#include "mlir/Analysis/TopologicalSortUtils.h"
1516
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1617
#include "mlir/Dialect/OpenACC/OpenACC.h"
1718
#include "mlir/IR/BuiltinOps.h"
1819
#include "mlir/IR/Operation.h"
1920
#include "mlir/Support/LLVM.h"
2021
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
2122
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
22-
#include "mlir/Transforms/RegionUtils.h"
2323

2424
#include "llvm/ADT/TypeSwitch.h"
2525
#include "llvm/Frontend/OpenMP/OMPConstants.h"

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14+
#include "mlir/Analysis/TopologicalSortUtils.h"
1415
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1617
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "AttrKindDetail.h"
1717
#include "DebugTranslation.h"
1818
#include "LoopAnnotationTranslation.h"
19+
#include "mlir/Analysis/TopologicalSortUtils.h"
1920
#include "mlir/Dialect/DLTI/DLTI.h"
2021
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2122
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
@@ -33,7 +34,6 @@
3334
#include "mlir/Support/LogicalResult.h"
3435
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
3536
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
36-
#include "mlir/Transforms/RegionUtils.h"
3737

3838
#include "llvm/ADT/PostOrderIterator.h"
3939
#include "llvm/ADT/SetVector.h"

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
#include "mlir/Transforms/Mem2Reg.h"
1010
#include "mlir/Analysis/DataLayoutAnalysis.h"
1111
#include "mlir/Analysis/SliceAnalysis.h"
12+
#include "mlir/Analysis/TopologicalSortUtils.h"
1213
#include "mlir/IR/Builders.h"
1314
#include "mlir/IR/Dominance.h"
1415
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/IR/Value.h"
1617
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1718
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1819
#include "mlir/Transforms/Passes.h"
19-
#include "mlir/Transforms/RegionUtils.h"
2020
#include "llvm/ADT/STLExtras.h"
2121
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
2222

mlir/lib/Transforms/SROA.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Transforms/SROA.h"
1010
#include "mlir/Analysis/DataLayoutAnalysis.h"
1111
#include "mlir/Analysis/SliceAnalysis.h"
12+
#include "mlir/Analysis/TopologicalSortUtils.h"
1213
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1314
#include "mlir/Transforms/Passes.h"
1415

mlir/lib/Transforms/TopologicalSort.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
#include "mlir/Transforms/Passes.h"
1010

11+
#include "mlir/Analysis/TopologicalSortUtils.h"
1112
#include "mlir/IR/RegionKindInterface.h"
12-
#include "mlir/Transforms/TopologicalSortUtils.h"
1313

1414
namespace mlir {
1515
#define GEN_PASS_DEF_TOPOLOGICALSORT

0 commit comments

Comments
 (0)