Skip to content

[MLIR][Affine] Make affine fusion MDG API const correct #125994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 10, 2025

Conversation

bondhugula
Copy link
Contributor

Make affine fusion MDG API const correct. NFC changes otherwise.

@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2025

@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: Uday Bondhugula (bondhugula)

Changes

Make affine fusion MDG API const correct. NFC changes otherwise.


Full diff: https://github.com/llvm/llvm-project/pull/125994.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Analysis/Utils.h (+20-9)
  • (modified) mlir/lib/Dialect/Affine/Analysis/Utils.cpp (+31-28)
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..97cf29ce045ced6 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -156,10 +156,19 @@ struct MemRefDependenceGraph {
   bool init();
 
   // Returns the graph node for 'id'.
-  Node *getNode(unsigned id);
+  const Node *getNode(unsigned id) const;
+  Node *getNode(unsigned id) {
+    return const_cast<Node *>(
+        static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
+  }
+  bool hasNode(unsigned id) const { return nodes.contains(id); }
 
   // Returns the graph node for 'forOp'.
-  Node *getForOpNode(AffineForOp forOp);
+  const Node *getForOpNode(AffineForOp forOp) const;
+  Node *getForOpNode(AffineForOp forOp) {
+    return const_cast<Node *>(
+        static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
+  }
 
   // Adds a node with 'op' to the graph and returns its unique identifier.
   unsigned addNode(Operation *op);
@@ -169,12 +178,12 @@ struct MemRefDependenceGraph {
 
   // Returns true if node 'id' writes to any memref which escapes (or is an
   // argument to) the block. Returns false otherwise.
-  bool writesToLiveInOrEscapingMemrefs(unsigned id);
+  bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
 
   // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   // is for 'value' if non-null, or for any value otherwise. Returns false
   // otherwise.
-  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
 
   // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
   void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +194,25 @@ struct MemRefDependenceGraph {
   // Returns true if there is a path in the dependence graph from node 'srcId'
   // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
   // operations that the edges connected are expected to be from the same block.
-  bool hasDependencePath(unsigned srcId, unsigned dstId);
+  bool hasDependencePath(unsigned srcId, unsigned dstId) const;
 
   // Returns the input edge count for node 'id' and 'memref' from src nodes
   // which access 'memref' with a store operation.
-  unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
 
   // Returns the output edge count for node 'id' and 'memref' (if non-null),
   // otherwise returns the total output edge count from node 'id'.
-  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
 
   /// Return all nodes which define SSA values used in node 'id'.
-  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+  void gatherDefiningNodes(unsigned id,
+                           DenseSet<unsigned> &definingNodes) const;
 
   // Computes and returns an insertion point operation, before which the
   // the fused <srcId, dstId> loop nest can be inserted while preserving
   // dependences. Returns nullptr if no such insertion point is found.
-  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+  Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
+                                            unsigned dstId) const;
 
   // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
   // taking into account that:
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..54bb529041e0914 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
 
 /// Add `op` to MDG creating a new node and adding its memory accesses (affine
 /// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
-Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
-                   DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
+static Node *
+addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
+             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
   auto &nodes = mdg.nodes;
   // Create graph node 'id' to represent top-level 'forOp' and record
   // all loads and store accesses it contains.
@@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
 }
 
 // Returns the graph node for 'id'.
-Node *MemRefDependenceGraph::getNode(unsigned id) {
+const Node *MemRefDependenceGraph::getNode(unsigned id) const {
   auto it = nodes.find(id);
   assert(it != nodes.end());
   return &it->second;
 }
 
 // Returns the graph node for 'forOp'.
-Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
   for (auto &idAndNode : nodes)
     if (idAndNode.second.op == forOp)
       return &idAndNode.second;
@@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
 
 // Returns true if node 'id' writes to any memref which escapes (or is an
 // argument to) the block. Returns false otherwise.
-bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
-  Node *node = getNode(id);
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
+  const Node *node = getNode(id);
   for (auto *storeOpInst : node->stores) {
     auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
     auto *op = memref.getDefiningOp();
@@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
 // is for 'value' if non-null, or for any value otherwise. Returns false
 // otherwise.
 bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
-                                    Value value) {
+                                    Value value) const {
   if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
     return false;
   }
-  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+  bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](Edge &edge) {
     return edge.id == dstId && (!value || edge.value == value);
   });
-  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+  bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](Edge &edge) {
     return edge.id == srcId && (!value || edge.value == value);
   });
   return hasOutEdge && hasInEdge;
@@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
 // Returns true if there is a path in the dependence graph from node 'srcId'
 // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
 // operations that the edges connected are expected to be from the same block.
-bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
+                                              unsigned dstId) const {
   // Worklist state is: <node-id, next-output-edge-index-to-visit>
   SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
   worklist.push_back({srcId, 0});
@@ -490,12 +492,12 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
     // Pop and continue if node has no out edges, or if all out edges have
     // already been visited.
     if (outEdges.count(idAndIndex.first) == 0 ||
-        idAndIndex.second == outEdges[idAndIndex.first].size()) {
+        idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
       worklist.pop_back();
       continue;
     }
     // Get graph edge to traverse.
-    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+    const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
     // Increment next output edge index for 'idAndIndex'.
     ++idAndIndex.second;
     // Add node at 'edge.id' to the worklist. We don't need to consider
@@ -511,25 +513,26 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
 // Returns the input edge count for node 'id' and 'memref' from src nodes
 // which access 'memref' with a store operation.
 unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
-                                                          Value memref) {
+                                                          Value memref) const {
   unsigned inEdgeCount = 0;
-  if (inEdges.count(id) > 0)
-    for (auto &inEdge : inEdges[id])
-      if (inEdge.value == memref) {
-        Node *srcNode = getNode(inEdge.id);
-        // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
-        if (srcNode->getStoreOpCount(memref) > 0)
-          ++inEdgeCount;
-      }
+  for (const MemRefDependenceGraph::Edge &inEdge : inEdges.lookup(id)) {
+    if (inEdge.value == memref) {
+      const Node *srcNode = getNode(inEdge.id);
+      // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+      if (srcNode->getStoreOpCount(memref) > 0)
+        ++inEdgeCount;
+    }
+  }
   return inEdgeCount;
 }
 
 // Returns the output edge count for node 'id' and 'memref' (if non-null),
 // otherwise returns the total output edge count from node 'id'.
-unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
+                                                Value memref) const {
   unsigned outEdgeCount = 0;
   if (outEdges.count(id) > 0)
-    for (auto &outEdge : outEdges[id])
+    for (auto &outEdge : outEdges.lookup(id))
       if (!memref || outEdge.value == memref)
         ++outEdgeCount;
   return outEdgeCount;
@@ -537,8 +540,8 @@ unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
 
 /// Return all nodes which define SSA values used in node 'id'.
 void MemRefDependenceGraph::gatherDefiningNodes(
-    unsigned id, DenseSet<unsigned> &definingNodes) {
-  for (MemRefDependenceGraph::Edge edge : inEdges[id])
+    unsigned id, DenseSet<unsigned> &definingNodes) const {
+  for (MemRefDependenceGraph::Edge edge : inEdges.lookup(id))
     // By definition of edge, if the edge value is a non-memref value,
     // then the dependence is between a graph node which defines an SSA value
     // and another graph node which uses the SSA value.
@@ -551,7 +554,7 @@ void MemRefDependenceGraph::gatherDefiningNodes(
 // dependences. Returns nullptr if no such insertion point is found.
 Operation *
 MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
-                                                      unsigned dstId) {
+                                                      unsigned dstId) const {
   if (outEdges.count(srcId) == 0)
     return getNode(dstId)->op;
 
@@ -568,13 +571,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
 
   // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
   SmallPtrSet<Operation *, 2> srcDepInsts;
-  for (auto &outEdge : outEdges[srcId])
+  for (auto &outEdge : outEdges.lookup(srcId))
     if (outEdge.id != dstId)
       srcDepInsts.insert(getNode(outEdge.id)->op);
 
   // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
   SmallPtrSet<Operation *, 2> dstDepInsts;
-  for (auto &inEdge : inEdges[dstId])
+  for (auto &inEdge : inEdges.lookup(dstId))
     if (inEdge.id != srcId)
       dstDepInsts.insert(getNode(inEdge.id)->op);
 

@bondhugula bondhugula force-pushed the uday/mdg_const_correct branch from e4f004c to fa236b2 Compare February 9, 2025 23:38
Copy link
Contributor

@patel-vimal patel-vimal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change looks good to me.

@bondhugula bondhugula force-pushed the uday/mdg_const_correct branch from fa236b2 to 5650e0a Compare February 10, 2025 15:44
Make affine fusion MDG API const correct. NFC changes otherwise.
@bondhugula bondhugula force-pushed the uday/mdg_const_correct branch from 5650e0a to 191eb97 Compare February 10, 2025 15:48
@bondhugula bondhugula merged commit 001ba42 into llvm:main Feb 10, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Make affine fusion MDG API const correct. NFC changes otherwise.
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
Make affine fusion MDG API const correct. NFC changes otherwise.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
Make affine fusion MDG API const correct. NFC changes otherwise.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants