Skip to content

Commit fa236b2

Browse files
committed
[MLIR][Affine] Make affine fusion MDG API const correct
Make affine fusion MDG API const correct. NFC changes otherwise.
1 parent 44f638f commit fa236b2

File tree

2 files changed

+62
-47
lines changed
  • mlir

2 files changed

+62
-47
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,11 @@ struct MemRefDependenceGraph {
139139

140140
// Map from node id to Node.
141141
DenseMap<unsigned, Node> nodes;
142-
// Map from node id to list of input edges.
142+
// Map from node id to list of input edges. The absence of an entry for a key
143+
// is also equivalent to the absence of any edges.
143144
DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
144-
// Map from node id to list of output edges.
145+
// Map from node id to list of output edges. The absence of an entry for a
146+
// node is also equivalent to the absence of any edges.
145147
DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
146148
// Map from memref to a count on the dependence edges associated with that
147149
// memref.
@@ -156,10 +158,19 @@ struct MemRefDependenceGraph {
156158
bool init();
157159

158160
// Returns the graph node for 'id'.
159-
Node *getNode(unsigned id);
161+
const Node *getNode(unsigned id) const;
162+
Node *getNode(unsigned id) {
163+
return const_cast<Node *>(
164+
static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
165+
}
166+
bool hasNode(unsigned id) const { return nodes.contains(id); }
160167

161168
// Returns the graph node for 'forOp'.
162-
Node *getForOpNode(AffineForOp forOp);
169+
const Node *getForOpNode(AffineForOp forOp) const;
170+
Node *getForOpNode(AffineForOp forOp) {
171+
return const_cast<Node *>(
172+
static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
173+
}
163174

164175
// Adds a node with 'op' to the graph and returns its unique identifier.
165176
unsigned addNode(Operation *op);
@@ -169,12 +180,12 @@ struct MemRefDependenceGraph {
169180

170181
// Returns true if node 'id' writes to any memref which escapes (or is an
171182
// argument to) the block. Returns false otherwise.
172-
bool writesToLiveInOrEscapingMemrefs(unsigned id);
183+
bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
173184

174185
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
175186
// is for 'value' if non-null, or for any value otherwise. Returns false
176187
// otherwise.
177-
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
188+
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
178189

179190
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
180191
void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +196,25 @@ struct MemRefDependenceGraph {
185196
// Returns true if there is a path in the dependence graph from node 'srcId'
186197
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
187198
// operations that the edges connected are expected to be from the same block.
188-
bool hasDependencePath(unsigned srcId, unsigned dstId);
199+
bool hasDependencePath(unsigned srcId, unsigned dstId) const;
189200

190201
// Returns the input edge count for node 'id' and 'memref' from src nodes
191202
// which access 'memref' with a store operation.
192-
unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
203+
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
193204

194205
// Returns the output edge count for node 'id' and 'memref' (if non-null),
195206
// otherwise returns the total output edge count from node 'id'.
196-
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
207+
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
197208

198209
/// Return all nodes which define SSA values used in node 'id'.
199-
void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
210+
void gatherDefiningNodes(unsigned id,
211+
DenseSet<unsigned> &definingNodes) const;
200212

201213
// Computes and returns an insertion point operation, before which the
202214
// the fused <srcId, dstId> loop nest can be inserted while preserving
203215
// dependences. Returns nullptr if no such insertion point is found.
204-
Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
216+
Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
217+
unsigned dstId) const;
205218

206219
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
207220
// taking into account that:

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
187187

188188
/// Add `op` to MDG creating a new node and adding its memory accesses (affine
189189
/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
190-
Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
191-
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
190+
static Node *
191+
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
192+
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
192193
auto &nodes = mdg.nodes;
193194
// Create graph node 'id' to represent top-level 'forOp' and record
194195
// all loads and store accesses it contains.
@@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
358359
}
359360

360361
// Returns the graph node for 'id'.
361-
Node *MemRefDependenceGraph::getNode(unsigned id) {
362+
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
362363
auto it = nodes.find(id);
363364
assert(it != nodes.end());
364365
return &it->second;
365366
}
366367

367368
// Returns the graph node for 'forOp'.
368-
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
369+
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
369370
for (auto &idAndNode : nodes)
370371
if (idAndNode.second.op == forOp)
371372
return &idAndNode.second;
@@ -389,7 +390,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
389390
}
390391
}
391392
// Remove each edge in 'outEdges[id]'.
392-
if (outEdges.count(id) > 0) {
393+
if (outEdges.contains(id)) {
393394
SmallVector<Edge, 2> oldOutEdges = outEdges[id];
394395
for (auto &outEdge : oldOutEdges) {
395396
removeEdge(id, outEdge.id, outEdge.value);
@@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
403404

404405
// Returns true if node 'id' writes to any memref which escapes (or is an
405406
// argument to) the block. Returns false otherwise.
406-
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
407-
Node *node = getNode(id);
407+
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
408+
const Node *node = getNode(id);
408409
for (auto *storeOpInst : node->stores) {
409410
auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
410411
auto *op = memref.getDefiningOp();
@@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
424425
// is for 'value' if non-null, or for any value otherwise. Returns false
425426
// otherwise.
426427
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
427-
Value value) {
428-
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
428+
Value value) const {
429+
if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
429430
return false;
430431
}
431-
bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
432+
bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
432433
return edge.id == dstId && (!value || edge.value == value);
433434
});
434-
bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
435+
bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
435436
return edge.id == srcId && (!value || edge.value == value);
436437
});
437438
return hasOutEdge && hasInEdge;
@@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
476477
// Returns true if there is a path in the dependence graph from node 'srcId'
477478
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
478479
// operations that the edges connected are expected to be from the same block.
479-
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
480+
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
481+
unsigned dstId) const {
480482
// Worklist state is: <node-id, next-output-edge-index-to-visit>
481483
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
482484
worklist.push_back({srcId, 0});
@@ -489,13 +491,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
489491
return true;
490492
// Pop and continue if node has no out edges, or if all out edges have
491493
// already been visited.
492-
if (outEdges.count(idAndIndex.first) == 0 ||
493-
idAndIndex.second == outEdges[idAndIndex.first].size()) {
494+
if (!outEdges.contains(idAndIndex.first) ||
495+
idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
494496
worklist.pop_back();
495497
continue;
496498
}
497499
// Get graph edge to traverse.
498-
Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
500+
const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
499501
// Increment next output edge index for 'idAndIndex'.
500502
++idAndIndex.second;
501503
// Add node at 'edge.id' to the worklist. We don't need to consider
@@ -511,34 +513,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
511513
// Returns the input edge count for node 'id' and 'memref' from src nodes
512514
// which access 'memref' with a store operation.
513515
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
514-
Value memref) {
516+
Value memref) const {
515517
unsigned inEdgeCount = 0;
516-
if (inEdges.count(id) > 0)
517-
for (auto &inEdge : inEdges[id])
518-
if (inEdge.value == memref) {
519-
Node *srcNode = getNode(inEdge.id);
520-
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
521-
if (srcNode->getStoreOpCount(memref) > 0)
522-
++inEdgeCount;
523-
}
518+
for (const Edge &inEdge : inEdges.lookup(id)) {
519+
if (inEdge.value == memref) {
520+
const Node *srcNode = getNode(inEdge.id);
521+
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
522+
if (srcNode->getStoreOpCount(memref) > 0)
523+
++inEdgeCount;
524+
}
525+
}
524526
return inEdgeCount;
525527
}
526528

527529
// Returns the output edge count for node 'id' and 'memref' (if non-null),
528530
// otherwise returns the total output edge count from node 'id'.
529-
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
531+
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
532+
Value memref) const {
530533
unsigned outEdgeCount = 0;
531-
if (outEdges.count(id) > 0)
532-
for (auto &outEdge : outEdges[id])
533-
if (!memref || outEdge.value == memref)
534-
++outEdgeCount;
534+
for (const auto &outEdge : outEdges.lookup(id))
535+
if (!memref || outEdge.value == memref)
536+
++outEdgeCount;
535537
return outEdgeCount;
536538
}
537539

538540
/// Return all nodes which define SSA values used in node 'id'.
539541
void MemRefDependenceGraph::gatherDefiningNodes(
540-
unsigned id, DenseSet<unsigned> &definingNodes) {
541-
for (MemRefDependenceGraph::Edge edge : inEdges[id])
542+
unsigned id, DenseSet<unsigned> &definingNodes) const {
543+
for (const Edge &edge : inEdges.lookup(id))
542544
// By definition of edge, if the edge value is a non-memref value,
543545
// then the dependence is between a graph node which defines an SSA value
544546
// and another graph node which uses the SSA value.
@@ -551,8 +553,8 @@ void MemRefDependenceGraph::gatherDefiningNodes(
551553
// dependences. Returns nullptr if no such insertion point is found.
552554
Operation *
553555
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
554-
unsigned dstId) {
555-
if (outEdges.count(srcId) == 0)
556+
unsigned dstId) const {
557+
if (!outEdges.contains(srcId))
556558
return getNode(dstId)->op;
557559

558560
// Skip if there is any defining node of 'dstId' that depends on 'srcId'.
@@ -568,13 +570,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
568570

569571
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
570572
SmallPtrSet<Operation *, 2> srcDepInsts;
571-
for (auto &outEdge : outEdges[srcId])
573+
for (auto &outEdge : outEdges.lookup(srcId))
572574
if (outEdge.id != dstId)
573575
srcDepInsts.insert(getNode(outEdge.id)->op);
574576

575577
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
576578
SmallPtrSet<Operation *, 2> dstDepInsts;
577-
for (auto &inEdge : inEdges[dstId])
579+
for (auto &inEdge : inEdges.lookup(dstId))
578580
if (inEdge.id != srcId)
579581
dstDepInsts.insert(getNode(inEdge.id)->op);
580582

@@ -634,7 +636,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
634636
SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
635637
for (auto &inEdge : oldInEdges) {
636638
// Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
637-
if (privateMemRefs.count(inEdge.value) == 0)
639+
if (!privateMemRefs.contains(inEdge.value))
638640
addEdge(inEdge.id, dstId, inEdge.value);
639641
}
640642
}

0 commit comments

Comments
 (0)