-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][Affine] Make affine fusion MDG API const correct #125994
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir Author: Uday Bondhugula (bondhugula) ChangesMake affine fusion MDG API const correct. NFC changes otherwise. Full diff: https://github.com/llvm/llvm-project/pull/125994.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..97cf29ce045ced6 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -156,10 +156,19 @@ struct MemRefDependenceGraph {
bool init();
// Returns the graph node for 'id'.
- Node *getNode(unsigned id);
+ const Node *getNode(unsigned id) const;
+ Node *getNode(unsigned id) {
+ return const_cast<Node *>(
+ static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
+ }
+ bool hasNode(unsigned id) const { return nodes.contains(id); }
// Returns the graph node for 'forOp'.
- Node *getForOpNode(AffineForOp forOp);
+ const Node *getForOpNode(AffineForOp forOp) const;
+ Node *getForOpNode(AffineForOp forOp) {
+ return const_cast<Node *>(
+ static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
+ }
// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned addNode(Operation *op);
@@ -169,12 +178,12 @@ struct MemRefDependenceGraph {
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
- bool writesToLiveInOrEscapingMemrefs(unsigned id);
+ bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
- bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+ bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +194,25 @@ struct MemRefDependenceGraph {
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
- bool hasDependencePath(unsigned srcId, unsigned dstId);
+ bool hasDependencePath(unsigned srcId, unsigned dstId) const;
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
- unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+ unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
- unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+ unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
/// Return all nodes which define SSA values used in node 'id'.
- void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+ void gatherDefiningNodes(unsigned id,
+ DenseSet<unsigned> &definingNodes) const;
// Computes and returns an insertion point operation, before which the
// the fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
- Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+ Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
+ unsigned dstId) const;
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..54bb529041e0914 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
/// Add `op` to MDG creating a new node and adding its memory accesses (affine
/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
-Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
- DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
+static Node *
+addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
+ DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
auto &nodes = mdg.nodes;
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
@@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
}
// Returns the graph node for 'id'.
-Node *MemRefDependenceGraph::getNode(unsigned id) {
+const Node *MemRefDependenceGraph::getNode(unsigned id) const {
auto it = nodes.find(id);
assert(it != nodes.end());
return &it->second;
}
// Returns the graph node for 'forOp'.
-Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
for (auto &idAndNode : nodes)
if (idAndNode.second.op == forOp)
return &idAndNode.second;
@@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
-bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
- Node *node = getNode(id);
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
+ const Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
auto *op = memref.getDefiningOp();
@@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
- Value value) {
+ Value value) const {
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
return false;
}
- bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
+ bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](Edge &edge) {
return edge.id == dstId && (!value || edge.value == value);
});
- bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
+ bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](Edge &edge) {
return edge.id == srcId && (!value || edge.value == value);
});
return hasOutEdge && hasInEdge;
@@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
-bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
+ unsigned dstId) const {
// Worklist state is: <node-id, next-output-edge-index-to-visit>
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
worklist.push_back({srcId, 0});
@@ -490,12 +492,12 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
// Pop and continue if node has no out edges, or if all out edges have
// already been visited.
if (outEdges.count(idAndIndex.first) == 0 ||
- idAndIndex.second == outEdges[idAndIndex.first].size()) {
+ idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
worklist.pop_back();
continue;
}
// Get graph edge to traverse.
- Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
+ const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
// Increment next output edge index for 'idAndIndex'.
++idAndIndex.second;
// Add node at 'edge.id' to the worklist. We don't need to consider
@@ -511,25 +513,26 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
- Value memref) {
+ Value memref) const {
unsigned inEdgeCount = 0;
- if (inEdges.count(id) > 0)
- for (auto &inEdge : inEdges[id])
- if (inEdge.value == memref) {
- Node *srcNode = getNode(inEdge.id);
- // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
- if (srcNode->getStoreOpCount(memref) > 0)
- ++inEdgeCount;
- }
+ for (const MemRefDependenceGraph::Edge &inEdge : inEdges.lookup(id)) {
+ if (inEdge.value == memref) {
+ const Node *srcNode = getNode(inEdge.id);
+ // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
+ if (srcNode->getStoreOpCount(memref) > 0)
+ ++inEdgeCount;
+ }
+ }
return inEdgeCount;
}
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
-unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
+ Value memref) const {
unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0)
- for (auto &outEdge : outEdges[id])
+ for (auto &outEdge : outEdges.lookup(id))
if (!memref || outEdge.value == memref)
++outEdgeCount;
return outEdgeCount;
@@ -537,8 +540,8 @@ unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
- unsigned id, DenseSet<unsigned> &definingNodes) {
- for (MemRefDependenceGraph::Edge edge : inEdges[id])
+ unsigned id, DenseSet<unsigned> &definingNodes) const {
+ for (MemRefDependenceGraph::Edge edge : inEdges.lookup(id))
// By definition of edge, if the edge value is a non-memref value,
// then the dependence is between a graph node which defines an SSA value
// and another graph node which uses the SSA value.
@@ -551,7 +554,7 @@ void MemRefDependenceGraph::gatherDefiningNodes(
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
- unsigned dstId) {
+ unsigned dstId) const {
if (outEdges.count(srcId) == 0)
return getNode(dstId)->op;
@@ -568,13 +571,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
SmallPtrSet<Operation *, 2> srcDepInsts;
- for (auto &outEdge : outEdges[srcId])
+ for (auto &outEdge : outEdges.lookup(srcId))
if (outEdge.id != dstId)
srcDepInsts.insert(getNode(outEdge.id)->op);
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
SmallPtrSet<Operation *, 2> dstDepInsts;
- for (auto &inEdge : inEdges[dstId])
+ for (auto &inEdge : inEdges.lookup(dstId))
if (inEdge.id != srcId)
dstDepInsts.insert(getNode(inEdge.id)->op);
|
e4f004c
to
fa236b2
Compare
patel-vimal
reviewed
Feb 10, 2025
patel-vimal
reviewed
Feb 10, 2025
patel-vimal
approved these changes
Feb 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change looks good to me.
fa236b2
to
5650e0a
Compare
Make affine fusion MDG API const correct. NFC changes otherwise.
5650e0a
to
191eb97
Compare
Icohedron
pushed a commit
to Icohedron/llvm-project
that referenced
this pull request
Feb 11, 2025
Make affine fusion MDG API const correct. NFC changes otherwise.
joaosaffran
pushed a commit
to joaosaffran/llvm-project
that referenced
this pull request
Feb 14, 2025
Make affine fusion MDG API const correct. NFC changes otherwise.
sivan-shani
pushed a commit
to sivan-shani/llvm-project
that referenced
this pull request
Feb 24, 2025
Make affine fusion MDG API const correct. NFC changes otherwise.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Make affine fusion MDG API const correct. NFC changes otherwise.