Skip to content

[MLIR][Analysis] Consolidate topological sort utilities #92563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions mlir/include/mlir/Analysis/SliceAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,6 @@ SetVector<Operation *>
getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
const ForwardSliceOptions &forwardSliceOptions = {});

/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
/// Returns a topologically sorted SetVector.
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);

/// Utility to match a generic reduction given a list of iteration-carried
/// arguments, `iterCarriedArgs` and the position of the potential reduction
/// argument within the list, `redPos`. If a reduction is matched, returns the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
#ifndef MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
#define MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H

#include "mlir/IR/Block.h"

Expand Down Expand Up @@ -104,6 +104,14 @@ bool computeTopologicalSorting(
MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);

/// Gets a list of blocks that is sorted according to dominance. This sort is
/// stable.
SetVector<Block *> getBlocksSortedByDominance(Region &region);

/// Sorts all operations in `toSort` topologically while also considering region
/// semantics. Does not support multi-sets.
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);

} // end namespace mlir

#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
4 changes: 0 additions & 4 deletions mlir/include/mlir/Transforms/RegionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,6 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
LogicalResult runRegionDCE(RewriterBase &rewriter,
MutableArrayRef<Region> regions);

/// Get a list of blocks that is sorted according to dominance. This sort is
/// stable.
SetVector<Block *> getBlocksSortedByDominance(Region &region);

} // namespace mlir

#endif // MLIR_TRANSFORMS_REGIONUTILS_H_
2 changes: 2 additions & 0 deletions mlir/lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
TopologicalSortUtils.cpp

AliasAnalysis/LocalAliasAnalysis.cpp

Expand All @@ -28,6 +29,7 @@ add_mlir_library(MLIRAnalysis
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
TopologicalSortUtils.cpp

AliasAnalysis/LocalAliasAnalysis.cpp

Expand Down
59 changes: 2 additions & 57 deletions mlir/lib/Analysis/SliceAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -164,62 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
return topologicalSort(slice);
}

namespace {
/// DFS post-order implementation that maintains a global count to work across
/// multiple invocations, to help implement topological sort on multi-root DAGs.
/// We traverse all operations but only record the ones that appear in
/// `toSort` for the final result.
struct DFSState {
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
const SetVector<Operation *> &toSort;
SmallVector<Operation *, 16> topologicalCounts;
DenseSet<Operation *> seen;
};
} // namespace

static void dfsPostorder(Operation *root, DFSState *state) {
SmallVector<Operation *> queue(1, root);
std::vector<Operation *> ops;
while (!queue.empty()) {
Operation *current = queue.pop_back_val();
ops.push_back(current);
for (Operation *op : current->getUsers())
queue.push_back(op);
for (Region &region : current->getRegions()) {
for (Operation &op : region.getOps())
queue.push_back(&op);
}
}

for (Operation *op : llvm::reverse(ops)) {
if (state->seen.insert(op).second && state->toSort.count(op) > 0)
state->topologicalCounts.push_back(op);
}
}

SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
if (toSort.empty()) {
return toSort;
}

// Run from each root with global count and `seen` set.
DFSState state(toSort);
for (auto *s : toSort) {
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
dfsPostorder(s, &state);
}

// Reorder and return.
SetVector<Operation *> res;
for (auto it = state.topologicalCounts.rbegin(),
eit = state.topologicalCounts.rend();
it != eit; ++it) {
res.insert(*it);
}
return res;
}

/// Returns true if `value` (transitively) depends on iteration-carried values
/// of the given `ancestorOp`.
static bool dependsOnCarriedVals(Value value,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===//
//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/TopologicalSortUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionGraphTraits.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

Expand Down Expand Up @@ -146,3 +151,135 @@ bool mlir::computeTopologicalSorting(

return allOpsScheduled;
}

SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
// For each block that has not been visited yet (i.e. that has no
// predecessors), add it to the list as well as its successors.
SetVector<Block *> blocks;
for (Block &b : region) {
if (blocks.count(&b) == 0) {
llvm::ReversePostOrderTraversal<Block *> traversal(&b);
blocks.insert(traversal.begin(), traversal.end());
}
}
assert(blocks.size() == region.getBlocks().size() &&
"some blocks are not sorted");

return blocks;
}

namespace {
class TopoSortHelper {
public:
explicit TopoSortHelper(const SetVector<Operation *> &toSort)
: toSort(toSort) {}

/// Executes the topological sort of the operations this instance was
/// constructed with. This function will destroy the internal state of the
/// instance.
SetVector<Operation *> sort() {
Comment on lines +178 to +179
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// constructed with. This function will destroy the internal state of the
/// instance.
/// constructed with.

nit: We could implement this but currently it is not implemented?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a purely internal class that is not accessible from the outside. I therefore decided to not add clearing logic.

if (toSort.size() <= 1) {
// Note: Creates a copy on purpose.
return toSort;
}

// First, find the root region to start the traversal through the IR. This
// additionally enriches the internal caches with all relevant ancestor
// regions and blocks.
Region *rootRegion = findCommonAncestorRegion();
assert(rootRegion && "expected all ops to have a common ancestor");

// Sort all elements in `toSort` by traversing the IR in the appropriate
// order.
SetVector<Operation *> result = topoSortRegion(*rootRegion);
assert(result.size() == toSort.size() &&
"expected all operations to be present in the result");
return result;
}

private:
/// Computes the closest common ancestor region of all operations in `toSort`.
Region *findCommonAncestorRegion() {
// Map to count the number of times a region was encountered.
DenseMap<Region *, size_t> regionCounts;
size_t expectedCount = toSort.size();

// Walk the region tree for each operation towards the root and add to the
// region count.
Region *res = nullptr;
for (Operation *op : toSort) {
Region *current = op->getParentRegion();
// Store the block as an ancestor block.
ancestorBlocks.insert(op->getBlock());
while (current) {
// Insert or update the count and compare it.
if (++regionCounts[current] == expectedCount) {
res = current;
break;
}
ancestorBlocks.insert(current->getParentOp()->getBlock());
current = current->getParentRegion();
}
}
auto firstRange = llvm::make_first_range(regionCounts);
ancestorRegions.insert(firstRange.begin(), firstRange.end());
return res;
}

/// Performs the dominance respecting IR walk to collect the topological order
/// of the operation to sort.
SetVector<Operation *> topoSortRegion(Region &rootRegion) {
using StackT = PointerUnion<Region *, Block *, Operation *>;

SetVector<Operation *> result;
// Stack that stores the different IR constructs to traverse.
SmallVector<StackT> stack;
stack.push_back(&rootRegion);

// Traverse the IR in a dominance respecting pre-order walk.
while (!stack.empty()) {
StackT current = stack.pop_back_val();
if (auto *region = dyn_cast<Region *>(current)) {
// A region's blocks need to be traversed in dominance order.
SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
for (Block *block : llvm::reverse(sortedBlocks)) {
// Only add blocks to the stack that are ancestors of the operations
// to sort.
if (ancestorBlocks.contains(block))
stack.push_back(block);
}
continue;
}

if (auto *block = dyn_cast<Block *>(current)) {
// Add all of the blocks operations to the stack.
for (Operation &op : llvm::reverse(*block))
stack.push_back(&op);
continue;
}

auto *op = cast<Operation *>(current);
if (toSort.contains(op))
result.insert(op);

// Add all the subregions that are ancestors of the operations to sort.
for (Region &subRegion : op->getRegions())
if (ancestorRegions.contains(&subRegion))
stack.push_back(&subRegion);
}
return result;
}

/// Operations to sort.
const SetVector<Operation *> &toSort;
/// Set containing all the ancestor regions of the operations to sort.
DenseSet<Region *> ancestorRegions;
/// Set containing all the ancestor blocks of the operations to sort.
DenseSet<Block *> ancestorBlocks;
};
} // namespace

SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
return TopoSortHelper(toSort).sort();
}
1 change: 1 addition & 0 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
//===----------------------------------------------------------------------===//

#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Transforms/RegionUtils.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "AttrKindDetail.h"
#include "DebugTranslation.h"
#include "LoopAnnotationTranslation.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
Expand All @@ -33,7 +34,6 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/RegionUtils.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/Mem2Reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Transforms/SROA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Transforms/SROA.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Transforms/TopologicalSort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Transforms/TopologicalSortUtils.h"

namespace mlir {
#define GEN_PASS_DEF_TOPOLOGICALSORT
Expand Down
Loading
Loading