From c64c0b7708a363960eacfe66e631fbef7ef5bb00 Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Fri, 17 May 2024 11:16:56 +0000 Subject: [PATCH 1/7] first impl of the new topo sort --- mlir/include/mlir/Analysis/SliceAnalysis.h | 1 + mlir/lib/Analysis/SliceAnalysis.cpp | 126 +++++++++----- mlir/test/Analysis/test-topoligical-sort.mlir | 53 ++++-- mlir/test/Dialect/Affine/slicing-utils.mlir | 160 +++++++++--------- mlir/test/lib/Analysis/TestSlice.cpp | 28 ++- 5 files changed, 211 insertions(+), 157 deletions(-) diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index d5cdf72c3889f..19571fc1946be 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -226,6 +226,7 @@ getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {}, /// Multi-root DAG topological sort. /// Performs a topological sort of the Operation in the `toSort` SetVector. /// Returns a topologically sorted SetVector. +/// Does not support multi-sets. 
SetVector topologicalSort(const SetVector &toSort); /// Utility to match a generic reduction given a list of iteration-carried diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 26fe8e3dc0819..f93183749dfd2 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -11,10 +11,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/IR/Block.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/RegionGraphTraits.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -164,60 +167,95 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, return topologicalSort(slice); } -namespace { -/// DFS post-order implementation that maintains a global count to work across -/// multiple invocations, to help implement topological sort on multi-root DAGs. -/// We traverse all operations but only record the ones that appear in -/// `toSort` for the final result. -struct DFSState { - DFSState(const SetVector &set) : toSort(set), seen() {} - const SetVector &toSort; - SmallVector topologicalCounts; - DenseSet seen; -}; -} // namespace - -static void dfsPostorder(Operation *root, DFSState *state) { - SmallVector queue(1, root); - std::vector ops; - while (!queue.empty()) { - Operation *current = queue.pop_back_val(); - ops.push_back(current); - for (Operation *op : current->getUsers()) - queue.push_back(op); - for (Region ®ion : current->getRegions()) { - for (Operation &op : region.getOps()) - queue.push_back(&op); +/// TODO: deduplicate +static SetVector getTopologicallySortedBlocks(Region ®ion) { + // For each block that has not been visited yet (i.e. that has no + // predecessors), add it to the list as well as its successors. 
+ SetVector blocks; + for (Block &b : region) { + if (blocks.count(&b) == 0) { + llvm::ReversePostOrderTraversal traversal(&b); + blocks.insert(traversal.begin(), traversal.end()); } } + assert(blocks.size() == region.getBlocks().size() && + "some blocks are not sorted"); - for (Operation *op : llvm::reverse(ops)) { - if (state->seen.insert(op).second && state->toSort.count(op) > 0) - state->topologicalCounts.push_back(op); + return blocks; +} + +/// Computes the common ancestor region of all operations in `ops`. Remembers +/// all the traversed regions in `traversedRegions`. +static Region *findCommonParentRegion(const SetVector &ops, + DenseSet &traversedRegions) { + // Map to count the number of times a region was encountered. + llvm::DenseMap regionCounts; + size_t expectedCount = ops.size(); + + // Walk the region tree for each operation towards the root and add to the + // region count. + Region *res = nullptr; + for (Operation *op : ops) { + Region *current = op->getParentRegion(); + while (current) { + // Insert or get the count. + auto it = regionCounts.try_emplace(current, 0).first; + size_t count = ++it->getSecond(); + if (count == expectedCount) { + res = current; + break; + } + current = current->getParentRegion(); + } + } + auto firstRange = llvm::make_first_range(regionCounts); + traversedRegions.insert(firstRange.begin(), firstRange.end()); + return res; +} + +/// Topologically traverses `region` and inserts all encountered operations in +/// `toSort` into the result. Recursively traverses regions when they are +/// present in `relevantRegions`.
+static void topoSortRegion(Region &region, + const DenseSet &relevantRegions, + const SetVector &toSort, + SetVector &result) { + SetVector sortedBlocks = getTopologicallySortedBlocks(region); + for (Block *block : sortedBlocks) { + for (Operation &op : *block) { + if (toSort.contains(&op)) + result.insert(&op); + for (Region &subRegion : op.getRegions()) { + // Skip regions that do not contain operations from `toSort`. + if (!relevantRegions.contains(&subRegion)) + continue; + topoSortRegion(subRegion, relevantRegions, toSort, result); + } + } } } SetVector mlir::topologicalSort(const SetVector &toSort) { - if (toSort.empty()) { + if (toSort.size() <= 1) return toSort; - } - // Run from each root with global count and `seen` set. - DFSState state(toSort); - for (auto *s : toSort) { - assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); - dfsPostorder(s, &state); - } - - // Reorder and return. - SetVector res; - for (auto it = state.topologicalCounts.rbegin(), - eit = state.topologicalCounts.rend(); - it != eit; ++it) { - res.insert(*it); - } - return res; + assert(llvm::all_of(toSort, + [&](Operation *op) { return toSort.count(op) == 1; }) && + "expected only unique set entries"); + + // First, find the root region to start the recursive traversal through the + // IR. + DenseSet relevantRegions; + Region *rootRegion = findCommonParentRegion(toSort, relevantRegions); + assert(rootRegion && "expected all ops to have a common ancestor"); + + // Sort all elements in `toSort` by recursively traversing the IR.
+ SetVector result; + topoSortRegion(*rootRegion, relevantRegions, toSort, result); + assert(result.size() == toSort.size() && + "expected all operations to be present in the result"); + return result; } /// Returns true if `value` (transitively) depends on iteration-carried values diff --git a/mlir/test/Analysis/test-topoligical-sort.mlir b/mlir/test/Analysis/test-topoligical-sort.mlir index 8608586402055..150aff854fc8f 100644 --- a/mlir/test/Analysis/test-topoligical-sort.mlir +++ b/mlir/test/Analysis/test-topoligical-sort.mlir @@ -1,21 +1,38 @@ -// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" 2>&1 | FileCheck %s +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" --split-input-file | FileCheck %s -// CHECK-LABEL: Testing : region -// CHECK: arith.addi {{.*}} : index -// CHECK-NEXT: scf.for -// CHECK: } {__test_sort_original_idx__ = 2 : i64} -// CHECK-NEXT: arith.addi {{.*}} : i32 -// CHECK-NEXT: arith.subi {{.*}} : i32 -func.func @region( - %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index, - %arg4 : i32, %arg5 : i32, %arg6 : i32, - %buffer : memref) { - %0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32 - %idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index - scf.for %arg7 = %idx to %arg2 step %arg3 { - %2 = arith.addi %0, %arg5 : i32 - %3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32 - memref.store %3, %buffer[] : memref - } {__test_sort_original_idx__ = 2} +// CHECK-LABEL: single_element +func.func @single_element() { + // CHECK: test_sort_index = 0 + return {test_to_sort} +} + +// ----- + +// CHECK-LABEL: @simple_region +func.func @simple_region(%cond: i1) { + // CHECK: test_sort_index = 0 + %0 = arith.constant {test_to_sort} 42 : i32 + scf.if %cond { + %1 = arith.addi %0, %0 : i32 + // CHECK: test_sort_index = 2 + %2 = arith.subi %0, %1 {test_to_sort} : i32 + // CHECK: test_sort_index = 1 + } {test_to_sort} + 
return +} + +// ----- + +// CHECK-LABEL: @multi_region +func.func @multi_region(%cond: i1) { + scf.if %cond { + // CHECK: test_sort_index = 0 + %0 = arith.constant {test_to_sort} 42 : i32 + } + + scf.if %cond { + // CHECK: test_sort_index = 1 + %0 = arith.constant {test_to_sort} 24 : i32 + } return } diff --git a/mlir/test/Dialect/Affine/slicing-utils.mlir b/mlir/test/Dialect/Affine/slicing-utils.mlir index 74379978fdf8c..0848a924b9d96 100644 --- a/mlir/test/Dialect/Affine/slicing-utils.mlir +++ b/mlir/test/Dialect/Affine/slicing-utils.mlir @@ -28,15 +28,15 @@ func.func @slicing_test() { // BWD: matched: %[[v1:.*]] {{.*}} backward static slice: // // FWDBWD: matched: %[[v1:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1 - // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3 + // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4 + // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %1 = 
"slicing-test-op" () : () -> i1 @@ -49,15 +49,15 @@ func.func @slicing_test() { // BWD: matched: %[[v2:.*]] {{.*}} backward static slice: // // FWDBWD-NEXT: matched: %[[v2:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1 - // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3 + // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4 + // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %2 = "slicing-test-op" () : () -> i2 @@ -69,15 +69,15 @@ func.func @slicing_test() { // BWD: matched: %[[v3:.*]] {{.*}} backward static slice: // // FWDBWD-NEXT: matched: %[[v3:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1 - // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-DAG: 
%[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3 + // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4 + // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %3 = "slicing-test-op" () : () -> i3 @@ -89,15 +89,15 @@ func.func @slicing_test() { // BWD: matched: %[[v4:.*]] {{.*}} backward static slice: // // FWDBWD-NEXT: matched: %[[v4:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1 - // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = 
"slicing-test-op"() : () -> i3 + // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4 + // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %4 = "slicing-test-op" () : () -> i4 @@ -111,15 +111,15 @@ func.func @slicing_test() { // BWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 // // FWDBWD-NEXT: matched: %[[v5:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1 - // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3 + // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4 + // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %5 = "slicing-test-op" (%1, 
%2) : (i1, i2) -> i5 @@ -132,15 +132,15 @@ func.func @slicing_test() { // BWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 // // FWDBWD-NEXT: matched: %[[v6:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1 - // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3 + // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4 + // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %6 = "slicing-test-op" (%3, %4) : (i3, i4) -> i6 @@ -153,15 +153,15 @@ func.func @slicing_test() { // BWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 // // FWDBWD-NEXT: matched: %[[v7:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: 
%[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3 + // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4 // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %7 = "slicing-test-op" (%1, %5) : (i1, i5) -> i7 @@ -177,15 +177,15 @@ func.func @slicing_test() { // BWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 // // FWDBWD-NEXT: matched: %[[v8:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1 - // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3 + // FWDBWD: 
%[[v4:.*]] = "slicing-test-op"() : () -> i4 + // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %8 = "slicing-test-op" (%5, %6) : (i5, i6) -> i8 @@ -202,15 +202,15 @@ func.func @slicing_test() { // BWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 // // FWDBWD-NEXT: matched: %[[v9:.*]] {{.*}} static slice: - // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4 - // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3 - // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 - // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2 - // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1 - // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 - // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 - // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 - // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 + // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1 + // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2 + // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3 + // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4 + // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5 + // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6 + // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7 + // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8 + // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9 %9 = "slicing-test-op" (%7, %8) : 
(i7, i8) -> i9 diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp index b445febde5971..06c41d8c4a110 100644 --- a/mlir/test/lib/Analysis/TestSlice.cpp +++ b/mlir/test/lib/Analysis/TestSlice.cpp @@ -1,4 +1,4 @@ -//===------------- TestSlice.cpp - Test slice related analisis ------------===// +//===- TestSlice.cpp - Test slice related analisis ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,12 +7,14 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" using namespace mlir; -static const StringLiteral kOrderMarker = "__test_sort_original_idx__"; +static const StringLiteral kToSortMark = "test_to_sort"; +static const StringLiteral kOrderIndex = "test_sort_index"; namespace { @@ -26,20 +28,16 @@ struct TestTopologicalSortPass return "Print operations in topological order"; } void runOnOperation() override { - std::map ops; - getOperation().walk([&ops](Operation *op) { - if (auto originalOrderAttr = op->getAttrOfType(kOrderMarker)) - ops[originalOrderAttr.getInt()] = op; + SetVector toSort; + getOperation().walk([&](Operation *op) { + if (op->hasAttrOfType(kToSortMark)) + toSort.insert(op); }); - SetVector sortedOp; - for (auto op : ops) - sortedOp.insert(op.second); - sortedOp = topologicalSort(sortedOp); - llvm::errs() << "Testing : " << getOperation().getName() << "\n"; - for (Operation *op : sortedOp) { - op->print(llvm::errs()); - llvm::errs() << "\n"; - } + + auto i32Type = IntegerType::get(&getContext(), 32); + SetVector sortedOps = topologicalSort(toSort); + for (auto [index, op] : llvm::enumerate(sortedOps)) + op->setAttr(kOrderIndex, IntegerAttr::get(i32Type, index)); } }; From 5e181bd00cbe1cab023e496f00d3340fbf31cdbe Mon Sep 17 
00:00:00 2001 From: Christian Ulmann Date: Fri, 17 May 2024 11:42:08 +0000 Subject: [PATCH 2/7] move topo utils to analysis --- .../mlir/{Transforms => Analysis}/TopologicalSortUtils.h | 6 +++--- mlir/lib/Analysis/CMakeLists.txt | 2 ++ .../{Transforms/Utils => Analysis}/TopologicalSortUtils.cpp | 4 ++-- mlir/lib/Transforms/TopologicalSort.cpp | 2 +- mlir/lib/Transforms/Utils/CMakeLists.txt | 1 - mlir/lib/Transforms/Utils/RegionUtils.cpp | 2 +- mlir/lib/Transforms/ViewOpGraph.cpp | 2 +- mlir/test/{Transforms => Analysis}/test-toposort.mlir | 0 mlir/test/lib/Analysis/CMakeLists.txt | 1 + .../lib/{Transforms => Analysis}/TestTopologicalSort.cpp | 2 +- mlir/test/lib/Transforms/CMakeLists.txt | 1 - utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 - 12 files changed, 12 insertions(+), 12 deletions(-) rename mlir/include/mlir/{Transforms => Analysis}/TopologicalSortUtils.h (95%) rename mlir/lib/{Transforms/Utils => Analysis}/TopologicalSortUtils.cpp (97%) rename mlir/test/{Transforms => Analysis}/test-toposort.mlir (100%) rename mlir/test/lib/{Transforms => Analysis}/TestTopologicalSort.cpp (98%) diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h similarity index 95% rename from mlir/include/mlir/Transforms/TopologicalSortUtils.h rename to mlir/include/mlir/Analysis/TopologicalSortUtils.h index 74e44b1dc485d..fb9441db119fd 100644 --- a/mlir/include/mlir/Transforms/TopologicalSortUtils.h +++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H -#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H +#ifndef MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H +#define MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H #include "mlir/IR/Block.h" @@ -106,4 +106,4 @@ bool computeTopologicalSorting( } // end namespace mlir -#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H +#endif // 
MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 005814ddbec79..38d8415d81c72 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES Liveness.cpp CFGLoopInfo.cpp SliceAnalysis.cpp + TopologicalSortUtils.cpp AliasAnalysis/LocalAliasAnalysis.cpp @@ -28,6 +29,7 @@ add_mlir_library(MLIRAnalysis Liveness.cpp CFGLoopInfo.cpp SliceAnalysis.cpp + TopologicalSortUtils.cpp AliasAnalysis/LocalAliasAnalysis.cpp diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp similarity index 97% rename from mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp rename to mlir/lib/Analysis/TopologicalSortUtils.cpp index f3a9d217f2c98..4281beacee89e 100644 --- a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp +++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp @@ -1,4 +1,4 @@ -//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===// +//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/TopologicalSortUtils.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/OpDefinition.h" using namespace mlir; diff --git a/mlir/lib/Transforms/TopologicalSort.cpp b/mlir/lib/Transforms/TopologicalSort.cpp index 1219968fb3692..528f6ef676020 100644 --- a/mlir/lib/Transforms/TopologicalSort.cpp +++ b/mlir/lib/Transforms/TopologicalSort.cpp @@ -8,8 +8,8 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/RegionKindInterface.h" -#include "mlir/Transforms/TopologicalSortUtils.h" namespace mlir { #define GEN_PASS_DEF_TOPOLOGICALSORT diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt index d6aac0e2da4f5..b5788c679edc4 100644 --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -10,7 +10,6 @@ add_mlir_library(MLIRTransformUtils LoopInvariantCodeMotionUtils.cpp OneToNTypeConversion.cpp RegionUtils.cpp - TopologicalSortUtils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 192f59b353295..b6a6dea5fe9a0 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/RegionUtils.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/Block.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" @@ -15,7 +16,6 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Transforms/TopologicalSortUtils.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" diff --git 
a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index c2eb2b893cea4..b3c0a06c96fea 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -8,12 +8,12 @@ #include "mlir/Transforms/ViewOpGraph.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/IndentedOstream.h" -#include "mlir/Transforms/TopologicalSortUtils.h" #include "llvm/Support/Format.h" #include "llvm/Support/GraphWriter.h" #include diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Analysis/test-toposort.mlir similarity index 100% rename from mlir/test/Transforms/test-toposort.mlir rename to mlir/test/Analysis/test-toposort.mlir diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt index d168888c1e71e..7c6b31ae8b73e 100644 --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_library(MLIRTestAnalysis TestMemRefDependenceCheck.cpp TestMemRefStrideCalculation.cpp TestSlice.cpp + TestTopologicalSort.cpp DataFlow/TestDeadCodeAnalysis.cpp DataFlow/TestDenseBackwardDataFlowAnalysis.cpp diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Analysis/TestTopologicalSort.cpp similarity index 98% rename from mlir/test/lib/Transforms/TestTopologicalSort.cpp rename to mlir/test/lib/Analysis/TestTopologicalSort.cpp index 3b110c7126200..c7e0206b2a4d7 100644 --- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp +++ b/mlir/test/lib/Analysis/TestTopologicalSort.cpp @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/TopologicalSortUtils.h" using namespace mlir; diff --git 
a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index a849b7ebd29e2..975a41ac3d5fe 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -26,7 +26,6 @@ add_mlir_library(MLIRTestTransforms TestInlining.cpp TestIntRangeInference.cpp TestMakeIsolatedFromAbove.cpp - TestTopologicalSort.cpp ${MLIRTestTransformsPDLSrc} EXCLUDE_FROM_LIBMLIR diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index fc449e9010ae4..971c851a5f89f 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7597,7 +7597,6 @@ cc_library( "include/mlir/Transforms/LoopInvariantCodeMotionUtils.h", "include/mlir/Transforms/OneToNTypeConversion.h", "include/mlir/Transforms/RegionUtils.h", - "include/mlir/Transforms/TopologicalSortUtils.h", ], includes = ["include"], deps = [ From 277ed0ac9f4e2b8963defea5c54e2f0868103a37 Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Fri, 17 May 2024 12:02:50 +0000 Subject: [PATCH 3/7] migrate block topo sort --- .../mlir/Analysis/TopologicalSortUtils.h | 4 ++++ mlir/include/mlir/Transforms/RegionUtils.h | 4 ---- mlir/lib/Analysis/SliceAnalysis.cpp | 23 ++----------------- mlir/lib/Analysis/TopologicalSortUtils.cpp | 21 +++++++++++++++++ .../ArmSME/Transforms/TileAllocation.cpp | 1 + .../OpenACC/OpenACCToLLVMIRTranslation.cpp | 2 +- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 1 + mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 +- mlir/lib/Transforms/Mem2Reg.cpp | 2 +- mlir/lib/Transforms/Utils/RegionUtils.cpp | 17 -------------- 10 files changed, 32 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h index fb9441db119fd..c2bc15ad3143f 100644 --- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h +++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h @@ -104,6 
+104,10 @@ bool computeTopologicalSorting( MutableArrayRef ops, function_ref isOperandReady = nullptr); +/// Get a list of blocks that is sorted according to dominance. This sort is +/// stable. +SetVector getBlocksSortedByDominance(Region &region); + } // end namespace mlir #endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index f65d0d44eef42..06eebff201d1b 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -87,10 +87,6 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter, LogicalResult runRegionDCE(RewriterBase &rewriter, MutableArrayRef regions); -/// Get a list of blocks that is sorted according to dominance. This sort is -/// stable. -SetVector getBlocksSortedByDominance(Region &region); - } // namespace mlir #endif // MLIR_TRANSFORMS_REGIONUTILS_H_ diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index f93183749dfd2..37dae769dbc6d 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -11,13 +11,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/RegionGraphTraits.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" -#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -167,23 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, return topologicalSort(slice); } -/// TODO: deduplicate -static SetVector getTopologicallySortedBlocks(Region &region) { - // For each block that has not been visited yet (i.e. that has no - // predecessors), add it to the list as well as its successors.
- SetVector blocks; - for (Block &b : region) { - if (blocks.count(&b) == 0) { - llvm::ReversePostOrderTraversal traversal(&b); - blocks.insert(traversal.begin(), traversal.end()); - } - } - assert(blocks.size() == region.getBlocks().size() && - "some blocks are not sorted"); - - return blocks; -} - /// Computes the common ancestor region of all operations in `ops`. Remembers /// all the traversed regions in `traversedRegions`. static Region *findCommonParentRegion(const SetVector &ops, @@ -220,7 +201,7 @@ static void topoSortRegion(Region &region, const DenseSet &relevantRegions, const SetVector &toSort, SetVector &result) { - SetVector sortedBlocks = getTopologicallySortedBlocks(region); + SetVector sortedBlocks = getBlocksSortedByDominance(region); for (Block *block : sortedBlocks) { for (Operation &op : *block) { if (toSort.contains(&op)) diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp index 4281beacee89e..f5e16e9a91fe5 100644 --- a/mlir/lib/Analysis/TopologicalSortUtils.cpp +++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp @@ -7,7 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/IR/Block.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/RegionGraphTraits.h" + +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h" using namespace mlir; @@ -146,3 +151,19 @@ bool mlir::computeTopologicalSorting( return allOpsScheduled; } + +SetVector mlir::getBlocksSortedByDominance(Region &region) { + // For each block that has not been visited yet (i.e. that has no + // predecessors), add it to the list as well as its successors.
+ SetVector blocks; + for (Block &b : region) { + if (blocks.count(&b) == 0) { + llvm::ReversePostOrderTraversal traversal(&b); + blocks.insert(traversal.begin(), traversal.end()); + } + } + assert(blocks.size() == region.getBlocks().size() && + "some blocks are not sorted"); + + return blocks; +} diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index acbbbe9932e19..733e758b43907 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -46,6 +46,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Liveness.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp index eeda245ce969f..d9cf85e4aecab 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/IR/BuiltinOps.h" @@ -19,7 +20,6 @@ #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 34b6903f8da07..9d125b7f11809 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index cf3257c8b9b87..1ec0736ec08bf 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -16,6 +16,7 @@ #include "AttrKindDetail.h" #include "DebugTranslation.h" #include "LoopAnnotationTranslation.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" @@ -33,7 +34,6 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index e2e240ad865ce..a452cc3fae8ac 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -9,6 +9,7 @@ #include "mlir/Transforms/Mem2Reg.h" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" @@ -16,7 +17,6 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include
"mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/GenericIteratedDominanceFrontier.h" diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index b6a6dea5fe9a0..b5e641d39fc0a 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -19,7 +19,6 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/SmallSet.h" #include @@ -836,19 +835,3 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, return success(eliminatedBlocks || eliminatedOpsOrArgs || mergedIdenticalBlocks); } - -SetVector mlir::getBlocksSortedByDominance(Region &region) { - // For each block that has not been visited yet (i.e. that has no - // predecessors), add it to the list as well as its successors. - SetVector blocks; - for (Block &b : region) { - if (blocks.count(&b) == 0) { - llvm::ReversePostOrderTraversal traversal(&b); - blocks.insert(traversal.begin(), traversal.end()); - } - } - assert(blocks.size() == region.getBlocks().size() && - "some blocks are not sorted"); - - return blocks; -} From a269b086c64d009d2f3af2891b02c5daca7457a4 Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Fri, 17 May 2024 13:56:09 +0000 Subject: [PATCH 4/7] move all topo sorts into the utils file --- mlir/include/mlir/Analysis/SliceAnalysis.h | 6 -- .../mlir/Analysis/TopologicalSortUtils.h | 4 + mlir/lib/Analysis/SliceAnalysis.cpp | 74 ------------------- mlir/lib/Analysis/TopologicalSortUtils.cpp | 74 +++++++++++++++++++ .../Conversion/VectorToGPU/VectorToGPU.cpp | 1 + .../Dialect/Affine/Utils/LoopFusionUtils.cpp | 1 + mlir/lib/Transforms/SROA.cpp | 1 + mlir/test/lib/Analysis/TestSlice.cpp | 2 +- 8 files changed, 82 insertions(+), 81 deletions(-) diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 19571fc1946be..99279fdfe427c 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -223,12 +223,6 @@ SetVector getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {}, const ForwardSliceOptions &forwardSliceOptions = {}); -/// Multi-root DAG topological sort. -/// Performs a topological sort of the Operation in the `toSort` SetVector. -/// Returns a topologically sorted SetVector. -/// Does not support multi-sets. -SetVector topologicalSort(const SetVector &toSort); - /// Utility to match a generic reduction given a list of iteration-carried /// arguments, `iterCarriedArgs` and the position of the potential reduction /// argument within the list, `redPos`. If a reduction is matched, returns the diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h index c2bc15ad3143f..7aabc5ee457c0 100644 --- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h +++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h @@ -108,6 +108,10 @@ bool computeTopologicalSorting( /// stable. SetVector getBlocksSortedByDominance(Region &region); +/// Sorts all operation in `toSort` topologically while also region semantics. +/// Does not support multi-sets. +SetVector topologicalSort(const SetVector &toSort); + } // end namespace mlir #endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 37dae769dbc6d..2b1cf411ceeee 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -165,80 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, return topologicalSort(slice); } -/// Computes the common ancestor region of all operations in `ops`. Remembers -/// all the traversed regions in `traversedRegions`.
-static Region *findCommonParentRegion(const SetVector &ops, - DenseSet &traversedRegions) { - // Map to count the number of times a region was encountered. - llvm::DenseMap regionCounts; - size_t expectedCount = ops.size(); - - // Walk the region tree for each operation towards the root and add to the - // region count. - Region *res = nullptr; - for (Operation *op : ops) { - Region *current = op->getParentRegion(); - while (current) { - // Insert or get the count. - auto it = regionCounts.try_emplace(current, 0).first; - size_t count = ++it->getSecond(); - if (count == expectedCount) { - res = current; - break; - } - current = current->getParentRegion(); - } - } - auto firstRange = llvm::make_first_range(regionCounts); - traversedRegions.insert(firstRange.begin(), firstRange.end()); - return res; -} - -/// Topologically traverses `region` and insers all encountered operations in -/// `toSort` into the result. Recursively traverses regions when they are -/// present in `relevantRegions`. -static void topoSortRegion(Region &region, - const DenseSet &relevantRegions, - const SetVector &toSort, - SetVector &result) { - SetVector sortedBlocks = getBlocksSortedByDominance(region); - for (Block *block : sortedBlocks) { - for (Operation &op : *block) { - if (toSort.contains(&op)) - result.insert(&op); - for (Region &subRegion : op.getRegions()) { - // Skip regions that do not contain operations from `toSort`. - if (!relevantRegions.contains(&region)) - continue; - topoSortRegion(subRegion, relevantRegions, toSort, result); - } - } - } -} - -SetVector -mlir::topologicalSort(const SetVector &toSort) { - if (toSort.size() <= 1) - return toSort; - - assert(llvm::all_of(toSort, - [&](Operation *op) { return toSort.count(op) == 1; }) && - "expected only unique set entries"); - - // First, find the root region to start the recursive traversal through the - // IR.
- DenseSet relevantRegions; - Region *rootRegion = findCommonParentRegion(toSort, relevantRegions); - assert(rootRegion && "expected all ops to have a common ancestor"); - - // Sort all element in `toSort` by recursively traversing the IR. - SetVector result; - topoSortRegion(*rootRegion, relevantRegions, toSort, result); - assert(result.size() == toSort.size() && - "expected all operations to be present in the result"); - return result; -} - /// Returns true if `value` (transitively) depends on iteration-carried values /// of the given `ancestorOp`. static bool dependsOnCarriedVals(Value value, diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp index f5e16e9a91fe5..2fc1bd582ef47 100644 --- a/mlir/lib/Analysis/TopologicalSortUtils.cpp +++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp @@ -167,3 +167,77 @@ SetVector mlir::getBlocksSortedByDominance(Region &region) { return blocks; } + +/// Computes the common ancestor region of all operations in `ops`. Remembers +/// all the traversed regions in `traversedRegions`. +static Region *findCommonParentRegion(const SetVector &ops, + DenseSet &traversedRegions) { + // Map to count the number of times a region was encountered. + llvm::DenseMap regionCounts; + size_t expectedCount = ops.size(); + + // Walk the region tree for each operation towards the root and add to the + // region count. + Region *res = nullptr; + for (Operation *op : ops) { + Region *current = op->getParentRegion(); + while (current) { + // Insert or get the count. + auto it = regionCounts.try_emplace(current, 0).first; + size_t count = ++it->getSecond(); + if (count == expectedCount) { + res = current; + break; + } + current = current->getParentRegion(); + } + } + auto firstRange = llvm::make_first_range(regionCounts); + traversedRegions.insert(firstRange.begin(), firstRange.end()); + return res; +} + +/// Topologically traverses `region` and insers all encountered operations in +/// `toSort` into the result.
Recursively traverses regions when they are +/// present in `relevantRegions`. +static void topoSortRegion(Region &region, + const DenseSet &relevantRegions, + const SetVector &toSort, + SetVector &result) { + SetVector sortedBlocks = getBlocksSortedByDominance(region); + for (Block *block : sortedBlocks) { + for (Operation &op : *block) { + if (toSort.contains(&op)) + result.insert(&op); + for (Region &subRegion : op.getRegions()) { + // Skip regions that do not contain operations from `toSort`. + if (!relevantRegions.contains(&region)) + continue; + topoSortRegion(subRegion, relevantRegions, toSort, result); + } + } + } +} + +SetVector +mlir::topologicalSort(const SetVector &toSort) { + if (toSort.size() <= 1) + return toSort; + + assert(llvm::all_of(toSort, + [&](Operation *op) { return toSort.count(op) == 1; }) && + "expected only unique set entries"); + + // First, find the root region to start the recursive traversal through the + // IR. + DenseSet relevantRegions; + Region *rootRegion = findCommonParentRegion(toSort, relevantRegions); + assert(rootRegion && "expected all ops to have a common ancestor"); + + // Sort all element in `toSort` by recursively traversing the IR.
+ SetVector result; + topoSortRegion(*rootRegion, relevantRegions, toSort, result); + assert(result.size() == toSort.size() && + "expected all operations to be present in the result"); + return result; +} diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 332f0a2eecfcf..4496c2bc5fe8b 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -15,6 +15,7 @@ #include #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp index 84ae4b52dcf4e..7f3e43d0b4cd3 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/LoopFusionUtils.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Affine/Analysis/Utils.h" diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 67cbade07bc94..39f7256fb789d 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -9,6 +9,7 @@ #include "mlir/Transforms/SROA.h" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp index 06c41d8c4a110..fc367c07ad863 100644 --- a/mlir/test/lib/Analysis/TestSlice.cpp +++ b/mlir/test/lib/Analysis/TestSlice.cpp @@ -6,7 +6,7 @@ // 
//===----------------------------------------------------------------------===// -#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" From 9ea5fb929db8c4566abca69f5940a30b13de148d Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Tue, 21 May 2024 09:01:26 +0000 Subject: [PATCH 5/7] address nit comments --- .../mlir/Analysis/TopologicalSortUtils.h | 6 ++-- mlir/lib/Analysis/TopologicalSortUtils.cpp | 34 ++++++++----------- mlir/lib/Transforms/SROA.cpp | 9 +++-- mlir/test/lib/Analysis/TestSlice.cpp | 3 +- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h index 7aabc5ee457c0..ee98cd8cb380e 100644 --- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h +++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h @@ -104,12 +104,12 @@ bool computeTopologicalSorting( MutableArrayRef ops, function_ref isOperandReady = nullptr); -/// Get a list of blocks that is sorted according to dominance. This sort is +/// Gets a list of blocks that is sorted according to dominance. This sort is /// stable. SetVector getBlocksSortedByDominance(Region &region); -/// Sorts all operation in `toSort` topologically while also region semantics. -/// Does not support multi-sets. +/// Sorts all operations in `toSort` topologically while also considering region +/// semantics. Does not support multi-sets.
SetVector topologicalSort(const SetVector &toSort); } // end namespace mlir diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp index 2fc1bd582ef47..94c403c07385b 100644 --- a/mlir/lib/Analysis/TopologicalSortUtils.cpp +++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp @@ -168,12 +168,12 @@ SetVector mlir::getBlocksSortedByDominance(Region &region) { return blocks; } -/// Computes the common ancestor region of all operations in `ops`. Remembers -/// all the traversed regions in `traversedRegions`. -static Region *findCommonParentRegion(const SetVector &ops, - DenseSet &traversedRegions) { +/// Computes the closest common ancestor region of all operations in `ops`. +/// Remembers all the traversed regions in `traversedRegions`. +static Region *findCommonAncestorRegion(const SetVector &ops, + DenseSet &traversedRegions) { // Map to count the number of times a region was encountered. - llvm::DenseMap regionCounts; + DenseMap regionCounts; size_t expectedCount = ops.size(); // Walk the region tree for each operation towards the root and add to the @@ -182,10 +182,8 @@ static Region *findCommonParentRegion(const SetVector &ops, for (Operation *op : ops) { Region *current = op->getParentRegion(); while (current) { - // Insert or get the count. - auto it = regionCounts.try_emplace(current, 0).first; - size_t count = ++it->getSecond(); - if (count == expectedCount) { + // Insert or update the count and compare it. + if (++regionCounts[current] == expectedCount) { res = current; break; } @@ -197,11 +195,11 @@ static Region *findCommonParentRegion(const SetVector &ops, return res; } -/// Topologically traverses `region` and insers all encountered operations in +/// Topologically traverses `region` and inserts all encountered operations in /// `toSort` into the result. Recursively traverses regions when they are /// present in `relevantRegions`.
static void topoSortRegion(Region ®ion, - const DenseSet &relevantRegions, + const DenseSet &ancestorRegions, const SetVector &toSort, SetVector &result) { SetVector sortedBlocks = getBlocksSortedByDominance(region); @@ -211,9 +209,9 @@ static void topoSortRegion(Region ®ion, result.insert(&op); for (Region &subRegion : op.getRegions()) { // Skip regions that do not contain operations from `toSort`. - if (!relevantRegions.contains(®ion)) + if (!ancestorRegions.contains(®ion)) continue; - topoSortRegion(subRegion, relevantRegions, toSort, result); + topoSortRegion(subRegion, ancestorRegions, toSort, result); } } } @@ -224,19 +222,15 @@ mlir::topologicalSort(const SetVector &toSort) { if (toSort.size() <= 1) return toSort; - assert(llvm::all_of(toSort, - [&](Operation *op) { return toSort.count(op) == 1; }) && - "expected only unique set entries"); - // First, find the root region to start the recursive traversal through the // IR. - DenseSet relevantRegions; - Region *rootRegion = findCommonParentRegion(toSort, relevantRegions); + DenseSet ancestorRegions; + Region *rootRegion = findCommonAncestorRegion(toSort, ancestorRegions); assert(rootRegion && "expected all ops to have a common ancestor"); // Sort all element in `toSort` by recursively traversing the IR. SetVector result; - topoSortRegion(*rootRegion, relevantRegions, toSort, result); + topoSortRegion(*rootRegion, ancestorRegions, toSort, result); assert(result.size() == toSort.size() && "expected all operations to be present in the result"); return result; diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 39f7256fb789d..9a7f4db2afe00 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -108,14 +108,19 @@ computeDestructuringInfo(DestructurableMemorySlot &slot, // An operation that has blocking uses must be promoted. If it is not // promotable, destructuring must fail. 
- if (!promotable) + if (!promotable) { + // user->emitError() << "not promotable"; return {}; + } SmallVector newBlockingUses; // If the operation decides it cannot deal with removing the blocking uses, // destructuring must fail. - if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout)) + if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, + dataLayout)) { + // promotable->emitError() << "not removable"; return {}; + } // Then, register any new blocking uses for coming operations. for (OpOperand *blockingUse : newBlockingUses) { diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp index fc367c07ad863..7e8320dbf3ec3 100644 --- a/mlir/test/lib/Analysis/TestSlice.cpp +++ b/mlir/test/lib/Analysis/TestSlice.cpp @@ -25,7 +25,8 @@ struct TestTopologicalSortPass StringRef getArgument() const final { return "test-print-topological-sort"; } StringRef getDescription() const final { - return "Print operations in topological order"; + return "Sorts operations topologically and attaches attributes with their " + "corresponding index in the ordering to them"; } void runOnOperation() override { SetVector toSort; From a6a8f71f0372c8e0c9fa6c6a68234b35d0a7e273 Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Tue, 21 May 2024 15:36:22 +0000 Subject: [PATCH 6/7] rewrite into iterative algorithm --- mlir/lib/Analysis/TopologicalSortUtils.cpp | 160 +++++++++++++-------- 1 file changed, 104 insertions(+), 56 deletions(-) diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp index 94c403c07385b..7e640bedee642 100644 --- a/mlir/lib/Analysis/TopologicalSortUtils.cpp +++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp @@ -168,70 +168,118 @@ SetVector mlir::getBlocksSortedByDominance(Region ®ion) { return blocks; } -/// Computes the closest common ancestor region of all operations in `ops`. -/// Remembers all the traversed regions in `traversedRegions`. 
-static Region *findCommonAncestorRegion(const SetVector &ops, - DenseSet &traversedRegions) { - // Map to count the number of times a region was encountered. - DenseMap regionCounts; - size_t expectedCount = ops.size(); - - // Walk the region tree for each operation towards the root and add to the - // region count. - Region *res = nullptr; - for (Operation *op : ops) { - Region *current = op->getParentRegion(); - while (current) { - // Insert or update the count and compare it. - if (++regionCounts[current] == expectedCount) { - res = current; - break; +namespace { +class TopoSortHelper { +public: + explicit TopoSortHelper(const SetVector &toSort) + : toSort(toSort) {} + + /// Executes the topological sort of the operations this instance was + /// constructed with. This function will destroy the internal state of the + /// instance. + SetVector sort() { + if (toSort.size() <= 1) + // Note: Creates a copy on purpose. + return toSort; + + // First, find the root region to start the traversal through the IR. This + // additionally enriches the internal caches with all relevant ancestor + // regions and blocks. + Region *rootRegion = findCommonAncestorRegion(); + assert(rootRegion && "expected all ops to have a common ancestor"); + + // Sort all element in `toSort` by traversing the IR in the appropriate + // order. + SetVector result = topoSortRegion(*rootRegion); + assert(result.size() == toSort.size() && + "expected all operations to be present in the result"); + return result; + } + +private: + /// Computes the closest common ancestor region of all operations in `toSort`. + /// Remembers all the traversed regions in `ancestorRegions`. + Region *findCommonAncestorRegion() { + // Map to count the number of times a region was encountered. + DenseMap regionCounts; + size_t expectedCount = toSort.size(); + + // Walk the region tree for each operation towards the root and add to the + // region count. 
+ Region *res = nullptr; + for (Operation *op : toSort) { + Region *current = op->getParentRegion(); + // Store the block as an ancestor block. + ancestorBlocks.insert(op->getBlock()); + while (current) { + + // Insert or update the count and compare it. + if (++regionCounts[current] == expectedCount) { + res = current; + break; + } + ancestorBlocks.insert(current->getParentOp()->getBlock()); + current = current->getParentRegion(); } - current = current->getParentRegion(); } + auto firstRange = llvm::make_first_range(regionCounts); + ancestorRegions.insert(firstRange.begin(), firstRange.end()); + return res; } - auto firstRange = llvm::make_first_range(regionCounts); - traversedRegions.insert(firstRange.begin(), firstRange.end()); - return res; -} -/// Topologically traverses `region` and inserts all encountered operations in -/// `toSort` into the result. Recursively traverses regions when they are -/// present in `relevantRegions`. -static void topoSortRegion(Region ®ion, - const DenseSet &ancestorRegions, - const SetVector &toSort, - SetVector &result) { - SetVector sortedBlocks = getBlocksSortedByDominance(region); - for (Block *block : sortedBlocks) { - for (Operation &op : *block) { - if (toSort.contains(&op)) - result.insert(&op); - for (Region &subRegion : op.getRegions()) { - // Skip regions that do not contain operations from `toSort`. - if (!ancestorRegions.contains(®ion)) - continue; - topoSortRegion(subRegion, ancestorRegions, toSort, result); + /// Performs the dominance respecting IR walk to collect the topological order + /// of the operation to sort. + SetVector topoSortRegion(Region &rootRegion) { + using StackT = PointerUnion; + + SetVector result; + // Stack that stores the different IR constructs to traverse. + SmallVector stack; + stack.push_back(&rootRegion); + + // Traverse the IR in a dominance respecting pre-order walk. 
+ while (!stack.empty()) { + StackT current = stack.pop_back_val(); + if (auto *region = dyn_cast(current)) { + // A region's blocks need to be traversed in dominance order. + SetVector sortedBlocks = getBlocksSortedByDominance(*region); + for (Block *block : llvm::reverse(sortedBlocks)) + // Only add blocks to the stack that are ancestors of the operations + // to sort. + if (ancestorBlocks.contains(block)) + stack.push_back(block); + continue; } + + if (auto *block = dyn_cast(current)) { + // Add all of the blocks operations to the stack. + for (Operation &op : llvm::reverse(*block)) + stack.push_back(&op); + continue; + } + + auto *op = cast(current); + if (toSort.contains(op)) + result.insert(op); + + // Add all the subregions that are ancestors of the operations to sort. + for (Region &subRegion : op->getRegions()) + if (ancestorRegions.contains(&subRegion)) + stack.push_back(&subRegion); } + return result; } -} + + /// Operations to sort. + const SetVector &toSort; + /// Set containing all the ancestor regions of the operations to sort. + DenseSet ancestorRegions; + /// Set containing all the ancestor blocks of the operations to sort. + DenseSet ancestorBlocks; +}; +} // namespace SetVector mlir::topologicalSort(const SetVector &toSort) { - if (toSort.size() <= 1) - return toSort; - - // First, find the root region to start the recursive traversal through the - // IR. - DenseSet ancestorRegions; - Region *rootRegion = findCommonAncestorRegion(toSort, ancestorRegions); - assert(rootRegion && "expected all ops to have a common ancestor"); - - // Sort all element in `toSort` by recursively traversing the IR. 
- SetVector result; - topoSortRegion(*rootRegion, ancestorRegions, toSort, result); - assert(result.size() == toSort.size() && - "expected all operations to be present in the result"); - return result; + return TopoSortHelper(toSort).sort(); } From ffce4ce1501466a00736f182b939aa3d9a76456a Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Wed, 22 May 2024 05:41:22 +0000 Subject: [PATCH 7/7] address additional review comments --- mlir/lib/Analysis/TopologicalSortUtils.cpp | 10 +++++----- mlir/lib/Transforms/SROA.cpp | 9 ++------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp index 7e640bedee642..c406960fdecc3 100644 --- a/mlir/lib/Analysis/TopologicalSortUtils.cpp +++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp @@ -178,9 +178,10 @@ class TopoSortHelper { /// constructed with. This function will destroy the internal state of the /// instance. SetVector sort() { - if (toSort.size() <= 1) + if (toSort.size() <= 1) { // Note: Creates a copy on purpose. return toSort; + } // First, find the root region to start the traversal through the IR. This // additionally enriches the internal caches with all relevant ancestor @@ -188,7 +189,7 @@ class TopoSortHelper { Region *rootRegion = findCommonAncestorRegion(); assert(rootRegion && "expected all ops to have a common ancestor"); - // Sort all element in `toSort` by traversing the IR in the appropriate + // Sort all elements in `toSort` by traversing the IR in the appropriate // order. SetVector result = topoSortRegion(*rootRegion); assert(result.size() == toSort.size() && @@ -198,7 +199,6 @@ class TopoSortHelper { private: /// Computes the closest common ancestor region of all operations in `toSort`. - /// Remembers all the traversed regions in `ancestorRegions`. Region *findCommonAncestorRegion() { // Map to count the number of times a region was encountered. 
DenseMap regionCounts; @@ -212,7 +212,6 @@ class TopoSortHelper { // Store the block as an ancestor block. ancestorBlocks.insert(op->getBlock()); while (current) { - // Insert or update the count and compare it. if (++regionCounts[current] == expectedCount) { res = current; @@ -243,11 +242,12 @@ class TopoSortHelper { if (auto *region = dyn_cast(current)) { // A region's blocks need to be traversed in dominance order. SetVector sortedBlocks = getBlocksSortedByDominance(*region); - for (Block *block : llvm::reverse(sortedBlocks)) + for (Block *block : llvm::reverse(sortedBlocks)) { // Only add blocks to the stack that are ancestors of the operations // to sort. if (ancestorBlocks.contains(block)) stack.push_back(block); + } continue; } diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 9a7f4db2afe00..39f7256fb789d 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -108,19 +108,14 @@ computeDestructuringInfo(DestructurableMemorySlot &slot, // An operation that has blocking uses must be promoted. If it is not // promotable, destructuring must fail. - if (!promotable) { - // user->emitError() << "not promotable"; + if (!promotable) return {}; - } SmallVector newBlockingUses; // If the operation decides it cannot deal with removing the blocking uses, // destructuring must fail. - if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, - dataLayout)) { - // promotable->emitError() << "not removable"; + if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout)) return {}; - } // Then, register any new blocking uses for coming operations. for (OpOperand *blockingUse : newBlockingUses) {