diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h index b1fbf4477428c..7355509807b4d 100644 --- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h @@ -139,9 +139,11 @@ struct MemRefDependenceGraph { // Map from node id to Node. DenseMap nodes; - // Map from node id to list of input edges. + // Map from node id to list of input edges. The absence of an entry for a key + // is also equivalent to the absence of any edges. DenseMap> inEdges; - // Map from node id to list of output edges. + // Map from node id to list of output edges. The absence of an entry for a + // node is also equivalent to the absence of any edges. DenseMap> outEdges; // Map from memref to a count on the dependence edges associated with that // memref. @@ -156,10 +158,21 @@ struct MemRefDependenceGraph { bool init(); // Returns the graph node for 'id'. - Node *getNode(unsigned id); + const Node *getNode(unsigned id) const; + Node *getNode(unsigned id) { + return const_cast( + static_cast(this)->getNode(id)); + } + + // Returns true if the graph has node with ID `id`. + bool hasNode(unsigned id) const { return nodes.contains(id); } // Returns the graph node for 'forOp'. - Node *getForOpNode(AffineForOp forOp); + const Node *getForOpNode(AffineForOp forOp) const; + Node *getForOpNode(AffineForOp forOp) { + return const_cast( + static_cast(this)->getForOpNode(forOp)); + } // Adds a node with 'op' to the graph and returns its unique identifier. unsigned addNode(Operation *op); @@ -169,12 +182,12 @@ struct MemRefDependenceGraph { // Returns true if node 'id' writes to any memref which escapes (or is an // argument to) the block. Returns false otherwise. - bool writesToLiveInOrEscapingMemrefs(unsigned id); + bool writesToLiveInOrEscapingMemrefs(unsigned id) const; // Returns true iff there is an edge from node 'srcId' to node 'dstId' which // is for 'value' if non-null, or for any value otherwise. Returns false // otherwise. - bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr); + bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const; // Adds an edge from node 'srcId' to node 'dstId' for 'value'. void addEdge(unsigned srcId, unsigned dstId, Value value); @@ -185,23 +198,25 @@ struct MemRefDependenceGraph { // Returns true if there is a path in the dependence graph from node 'srcId' // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the // operations that the edges connected are expected to be from the same block. - bool hasDependencePath(unsigned srcId, unsigned dstId); + bool hasDependencePath(unsigned srcId, unsigned dstId) const; // Returns the input edge count for node 'id' and 'memref' from src nodes // which access 'memref' with a store operation. - unsigned getIncomingMemRefAccesses(unsigned id, Value memref); + unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const; // Returns the output edge count for node 'id' and 'memref' (if non-null), // otherwise returns the total output edge count from node 'id'. - unsigned getOutEdgeCount(unsigned id, Value memref = nullptr); + unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const; /// Return all nodes which define SSA values used in node 'id'. - void gatherDefiningNodes(unsigned id, DenseSet &definingNodes); + void gatherDefiningNodes(unsigned id, + DenseSet &definingNodes) const; // Computes and returns an insertion point operation, before which the // the fused loop nest can be inserted while preserving // dependences. Returns nullptr if no such insertion point is found. - Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId); + Operation *getFusedLoopNestInsertionPoint(unsigned srcId, + unsigned dstId) const; // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them, // taking into account that: diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index 9c0b5dbf52d29..92e7667ff2c72 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl &values) { /// Add `op` to MDG creating a new node and adding its memory accesses (affine /// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map. -Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg, - DenseMap> &memrefAccesses) { +static Node * +addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg, + DenseMap> &memrefAccesses) { auto &nodes = mdg.nodes; // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. @@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() { } // Returns the graph node for 'id'. -Node *MemRefDependenceGraph::getNode(unsigned id) { +const Node *MemRefDependenceGraph::getNode(unsigned id) const { auto it = nodes.find(id); assert(it != nodes.end()); return &it->second; } // Returns the graph node for 'forOp'. -Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) { +const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const { for (auto &idAndNode : nodes) if (idAndNode.second.op == forOp) return &idAndNode.second; @@ -389,7 +390,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) { } } // Remove each edge in 'outEdges[id]'. - if (outEdges.count(id) > 0) { + if (outEdges.contains(id)) { SmallVector oldOutEdges = outEdges[id]; for (auto &outEdge : oldOutEdges) { removeEdge(id, outEdge.id, outEdge.value); @@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) { // Returns true if node 'id' writes to any memref which escapes (or is an // argument to) the block. Returns false otherwise. -bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) { - Node *node = getNode(id); +bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const { + const Node *node = getNode(id); for (auto *storeOpInst : node->stores) { auto memref = cast(storeOpInst).getMemRef(); auto *op = memref.getDefiningOp(); @@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) { // is for 'value' if non-null, or for any value otherwise. Returns false // otherwise. bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId, - Value value) { - if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { + Value value) const { + if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) { return false; } - bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) { + bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) { return edge.id == dstId && (!value || edge.value == value); }); - bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) { + bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) { return edge.id == srcId && (!value || edge.value == value); }); return hasOutEdge && hasInEdge; @@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId, // Returns true if there is a path in the dependence graph from node 'srcId' // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the // operations that the edges connected are expected to be from the same block. -bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) { +bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, + unsigned dstId) const { // Worklist state is: SmallVector, 4> worklist; worklist.push_back({srcId, 0}); @@ -489,13 +491,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) { return true; // Pop and continue if node has no out edges, or if all out edges have // already been visited. - if (outEdges.count(idAndIndex.first) == 0 || - idAndIndex.second == outEdges[idAndIndex.first].size()) { + if (!outEdges.contains(idAndIndex.first) || + idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) { worklist.pop_back(); continue; } // Get graph edge to traverse. - Edge edge = outEdges[idAndIndex.first][idAndIndex.second]; + const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second]; // Increment next output edge index for 'idAndIndex'. ++idAndIndex.second; // Add node at 'edge.id' to the worklist. We don't need to consider @@ -511,34 +513,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) { // Returns the input edge count for node 'id' and 'memref' from src nodes // which access 'memref' with a store operation. unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id, - Value memref) { + Value memref) const { unsigned inEdgeCount = 0; - if (inEdges.count(id) > 0) - for (auto &inEdge : inEdges[id]) - if (inEdge.value == memref) { - Node *srcNode = getNode(inEdge.id); - // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref' - if (srcNode->getStoreOpCount(memref) > 0) - ++inEdgeCount; - } + for (const Edge &inEdge : inEdges.lookup(id)) { + if (inEdge.value == memref) { + const Node *srcNode = getNode(inEdge.id); + // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref' + if (srcNode->getStoreOpCount(memref) > 0) + ++inEdgeCount; + } + } return inEdgeCount; } // Returns the output edge count for node 'id' and 'memref' (if non-null), // otherwise returns the total output edge count from node 'id'. -unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) { +unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, + Value memref) const { unsigned outEdgeCount = 0; - if (outEdges.count(id) > 0) - for (auto &outEdge : outEdges[id]) - if (!memref || outEdge.value == memref) - ++outEdgeCount; + for (const auto &outEdge : outEdges.lookup(id)) + if (!memref || outEdge.value == memref) + ++outEdgeCount; return outEdgeCount; } /// Return all nodes which define SSA values used in node 'id'. void MemRefDependenceGraph::gatherDefiningNodes( - unsigned id, DenseSet &definingNodes) { - for (MemRefDependenceGraph::Edge edge : inEdges[id]) + unsigned id, DenseSet &definingNodes) const { + for (const Edge &edge : inEdges.lookup(id)) // By definition of edge, if the edge value is a non-memref value, // then the dependence is between a graph node which defines an SSA value // and another graph node which uses the SSA value. @@ -551,8 +553,8 @@ void MemRefDependenceGraph::gatherDefiningNodes( // dependences. Returns nullptr if no such insertion point is found. Operation * MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId, - unsigned dstId) { - if (outEdges.count(srcId) == 0) + unsigned dstId) const { + if (!outEdges.contains(srcId)) return getNode(dstId)->op; // Skip if there is any defining node of 'dstId' that depends on 'srcId'. @@ -568,13 +570,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId, // Build set of insts in range (srcId, dstId) which depend on 'srcId'. SmallPtrSet srcDepInsts; - for (auto &outEdge : outEdges[srcId]) + for (auto &outEdge : outEdges.lookup(srcId)) if (outEdge.id != dstId) srcDepInsts.insert(getNode(outEdge.id)->op); // Build set of insts in range (srcId, dstId) on which 'dstId' depends. SmallPtrSet dstDepInsts; - for (auto &inEdge : inEdges[dstId]) + for (auto &inEdge : inEdges.lookup(dstId)) if (inEdge.id != srcId) dstDepInsts.insert(getNode(inEdge.id)->op); @@ -634,7 +636,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId, SmallVector oldInEdges = inEdges[srcId]; for (auto &inEdge : oldInEdges) { // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref. - if (privateMemRefs.count(inEdge.value) == 0) + if (!privateMemRefs.contains(inEdge.value)) addEdge(inEdge.id, dstId, inEdge.value); } } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index c22ec213be95c..13915aee83f98 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -78,13 +78,13 @@ struct LoopFusion : public affine::impl::AffineLoopFusionBase { static bool canRemoveSrcNodeAfterFusion( unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, Operation *fusedLoopInsPoint, const DenseSet &escapingMemRefs, - MemRefDependenceGraph *mdg) { + const MemRefDependenceGraph &mdg) { - Operation *dstNodeOp = mdg->getNode(dstId)->op; + Operation *dstNodeOp = mdg.getNode(dstId)->op; bool hasOutDepsAfterFusion = false; - for (auto &outEdge : mdg->outEdges[srcId]) { - Operation *depNodeOp = mdg->getNode(outEdge.id)->op; + for (auto &outEdge : mdg.outEdges.lookup(srcId)) { + Operation *depNodeOp = mdg.getNode(outEdge.id)->op; // Skip dependence with dstOp since it will be removed after fusion. if (depNodeOp == dstNodeOp) continue; @@ -134,22 +134,23 @@ static bool canRemoveSrcNodeAfterFusion( /// held if the 'mdg' is reused from a previous fusion step or if the node /// creation order changes in the future to support more advance cases. // TODO: Move this to a loop fusion utility once 'mdg' is also moved. -static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, +static void getProducerCandidates(unsigned dstId, + const MemRefDependenceGraph &mdg, SmallVectorImpl &srcIdCandidates) { // Skip if no input edges along which to fuse. - if (mdg->inEdges.count(dstId) == 0) + if (mdg.inEdges.count(dstId) == 0) return; // Gather memrefs from loads in 'dstId'. - auto *dstNode = mdg->getNode(dstId); + auto *dstNode = mdg.getNode(dstId); DenseSet consumedMemrefs; for (Operation *load : dstNode->loads) consumedMemrefs.insert(cast(load).getMemRef()); // Traverse 'dstId' incoming edges and gather the nodes that contain a store // to one of the consumed memrefs. - for (auto &srcEdge : mdg->inEdges[dstId]) { - auto *srcNode = mdg->getNode(srcEdge.id); + for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) { + const auto *srcNode = mdg.getNode(srcEdge.id); // Skip if 'srcNode' is not a loop nest. if (!isa(srcNode->op)) continue; @@ -169,10 +170,10 @@ static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, /// producer-consumer dependence between 'srcId' and 'dstId'. static void gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId, - MemRefDependenceGraph *mdg, + const MemRefDependenceGraph &mdg, DenseSet &producerConsumerMemrefs) { - auto *dstNode = mdg->getNode(dstId); - auto *srcNode = mdg->getNode(srcId); + auto *dstNode = mdg.getNode(dstId); + auto *srcNode = mdg.getNode(srcId); gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads, producerConsumerMemrefs); } @@ -214,14 +215,14 @@ static bool isEscapingMemref(Value memref, Block *block) { /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' /// that escape the block or are accessed in a non-affine way. -static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, +static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg, DenseSet &escapingMemRefs) { - auto *node = mdg->getNode(id); + auto *node = mdg.getNode(id); for (Operation *storeOp : node->stores) { auto memref = cast(storeOp).getMemRef(); if (escapingMemRefs.count(memref)) continue; - if (isEscapingMemref(memref, &mdg->block)) + if (isEscapingMemref(memref, &mdg.block)) escapingMemRefs.insert(memref); } } @@ -787,7 +788,7 @@ struct GreedyFusion { // in 'srcIdCandidates'. dstNodeChanged = false; SmallVector srcIdCandidates; - getProducerCandidates(dstId, mdg, srcIdCandidates); + getProducerCandidates(dstId, *mdg, srcIdCandidates); for (unsigned srcId : llvm::reverse(srcIdCandidates)) { // Get 'srcNode' from which to attempt fusion into 'dstNode'. @@ -802,7 +803,7 @@ struct GreedyFusion { continue; DenseSet producerConsumerMemrefs; - gatherProducerConsumerMemrefs(srcId, dstId, mdg, + gatherProducerConsumerMemrefs(srcId, dstId, *mdg, producerConsumerMemrefs); // Skip if 'srcNode' out edge count on any memref is greater than @@ -817,7 +818,7 @@ struct GreedyFusion { // block (e.g., memref block arguments, returned memrefs, // memrefs passed to function calls, etc.). DenseSet srcEscapingMemRefs; - gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); + gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs); // Compute an operation list insertion point for the fused loop // nest which preserves dependences. @@ -911,7 +912,7 @@ struct GreedyFusion { // insertion point. bool removeSrcNode = canRemoveSrcNodeAfterFusion( srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, - mdg); + *mdg); DenseSet privateMemrefs; for (Value memref : producerConsumerMemrefs) {