Skip to content

[MLIR][Affine] Make affine fusion MDG API const correct #125994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ struct MemRefDependenceGraph {

// Map from node id to Node.
DenseMap<unsigned, Node> nodes;
// Map from node id to list of input edges.
// Map from node id to list of input edges. The absence of an entry for a key
// is also equivalent to the absence of any edges.
DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
// Map from node id to list of output edges.
// Map from node id to list of output edges. The absence of an entry for a
// node is also equivalent to the absence of any edges.
DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
// Map from memref to a count on the dependence edges associated with that
// memref.
Expand All @@ -156,10 +158,21 @@ struct MemRefDependenceGraph {
bool init();

// Returns the graph node for 'id'.
Node *getNode(unsigned id);
const Node *getNode(unsigned id) const;
Node *getNode(unsigned id) {
return const_cast<Node *>(
static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
}

// Returns true if the graph has node with ID `id`.
bool hasNode(unsigned id) const { return nodes.contains(id); }

// Returns the graph node for 'forOp'.
Node *getForOpNode(AffineForOp forOp);
const Node *getForOpNode(AffineForOp forOp) const;
Node *getForOpNode(AffineForOp forOp) {
return const_cast<Node *>(
static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned addNode(Operation *op);
Expand All @@ -169,12 +182,12 @@ struct MemRefDependenceGraph {

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool writesToLiveInOrEscapingMemrefs(unsigned id);
bool writesToLiveInOrEscapingMemrefs(unsigned id) const;

// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void addEdge(unsigned srcId, unsigned dstId, Value value);
Expand All @@ -185,23 +198,25 @@ struct MemRefDependenceGraph {
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
bool hasDependencePath(unsigned srcId, unsigned dstId);
bool hasDependencePath(unsigned srcId, unsigned dstId) const;

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;

/// Return all nodes which define SSA values used in node 'id'.
void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
void gatherDefiningNodes(unsigned id,
DenseSet<unsigned> &definingNodes) const;

// Computes and returns an insertion point operation, before which the
// fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
unsigned dstId) const;

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
Expand Down
74 changes: 38 additions & 36 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {

/// Add `op` to the MDG, creating a new node and adding its memory accesses
/// (affine or non-affine) to the memrefAccesses (memref -> list of nodes with
/// accesses) map.
Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
static Node *
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
auto &nodes = mdg.nodes;
// Create graph node 'id' to represent top-level 'forOp' and record
// all load and store accesses it contains.
Expand Down Expand Up @@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
}

// Returns the graph node for 'id'.
Node *MemRefDependenceGraph::getNode(unsigned id) {
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
auto it = nodes.find(id);
assert(it != nodes.end());
return &it->second;
}

// Returns the graph node for 'forOp'.
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
for (auto &idAndNode : nodes)
if (idAndNode.second.op == forOp)
return &idAndNode.second;
Expand All @@ -389,7 +390,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
}
}
// Remove each edge in 'outEdges[id]'.
if (outEdges.count(id) > 0) {
if (outEdges.contains(id)) {
SmallVector<Edge, 2> oldOutEdges = outEdges[id];
for (auto &outEdge : oldOutEdges) {
removeEdge(id, outEdge.id, outEdge.value);
Expand All @@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
const Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
auto *op = memref.getDefiningOp();
Expand All @@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
Value value) {
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
Value value) const {
if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
return false;
}
bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
return edge.id == dstId && (!value || edge.value == value);
});
bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
return edge.id == srcId && (!value || edge.value == value);
});
return hasOutEdge && hasInEdge;
Expand Down Expand Up @@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
unsigned dstId) const {
// Worklist state is: <node-id, next-output-edge-index-to-visit>
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
worklist.push_back({srcId, 0});
Expand All @@ -489,13 +491,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
return true;
// Pop and continue if node has no out edges, or if all out edges have
// already been visited.
if (outEdges.count(idAndIndex.first) == 0 ||
idAndIndex.second == outEdges[idAndIndex.first].size()) {
if (!outEdges.contains(idAndIndex.first) ||
idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
worklist.pop_back();
continue;
}
// Get graph edge to traverse.
Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
// Increment next output edge index for 'idAndIndex'.
++idAndIndex.second;
// Add node at 'edge.id' to the worklist. We don't need to consider
Expand All @@ -511,34 +513,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
Value memref) {
Value memref) const {
unsigned inEdgeCount = 0;
if (inEdges.count(id) > 0)
for (auto &inEdge : inEdges[id])
if (inEdge.value == memref) {
Node *srcNode = getNode(inEdge.id);
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
if (srcNode->getStoreOpCount(memref) > 0)
++inEdgeCount;
}
for (const Edge &inEdge : inEdges.lookup(id)) {
if (inEdge.value == memref) {
const Node *srcNode = getNode(inEdge.id);
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
if (srcNode->getStoreOpCount(memref) > 0)
++inEdgeCount;
}
}
return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
Value memref) const {
unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0)
for (auto &outEdge : outEdges[id])
if (!memref || outEdge.value == memref)
++outEdgeCount;
for (const auto &outEdge : outEdges.lookup(id))
if (!memref || outEdge.value == memref)
++outEdgeCount;
return outEdgeCount;
}

/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
unsigned id, DenseSet<unsigned> &definingNodes) {
for (MemRefDependenceGraph::Edge edge : inEdges[id])
unsigned id, DenseSet<unsigned> &definingNodes) const {
for (const Edge &edge : inEdges.lookup(id))
// By definition of edge, if the edge value is a non-memref value,
// then the dependence is between a graph node which defines an SSA value
// and another graph node which uses the SSA value.
Expand All @@ -551,8 +553,8 @@ void MemRefDependenceGraph::gatherDefiningNodes(
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
unsigned dstId) {
if (outEdges.count(srcId) == 0)
unsigned dstId) const {
if (!outEdges.contains(srcId))
return getNode(dstId)->op;

// Skip if there is any defining node of 'dstId' that depends on 'srcId'.
Expand All @@ -568,13 +570,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,

// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
SmallPtrSet<Operation *, 2> srcDepInsts;
for (auto &outEdge : outEdges[srcId])
for (auto &outEdge : outEdges.lookup(srcId))
if (outEdge.id != dstId)
srcDepInsts.insert(getNode(outEdge.id)->op);

// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
SmallPtrSet<Operation *, 2> dstDepInsts;
for (auto &inEdge : inEdges[dstId])
for (auto &inEdge : inEdges.lookup(dstId))
if (inEdge.id != srcId)
dstDepInsts.insert(getNode(inEdge.id)->op);

Expand Down Expand Up @@ -634,7 +636,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
for (auto &inEdge : oldInEdges) {
// Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
if (privateMemRefs.count(inEdge.value) == 0)
if (!privateMemRefs.contains(inEdge.value))
addEdge(inEdge.id, dstId, inEdge.value);
}
}
Expand Down
39 changes: 20 additions & 19 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
static bool canRemoveSrcNodeAfterFusion(
unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
MemRefDependenceGraph *mdg) {
const MemRefDependenceGraph &mdg) {

Operation *dstNodeOp = mdg->getNode(dstId)->op;
Operation *dstNodeOp = mdg.getNode(dstId)->op;
bool hasOutDepsAfterFusion = false;

for (auto &outEdge : mdg->outEdges[srcId]) {
Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
Operation *depNodeOp = mdg.getNode(outEdge.id)->op;
// Skip dependence with dstOp since it will be removed after fusion.
if (depNodeOp == dstNodeOp)
continue;
Expand Down Expand Up @@ -134,22 +134,23 @@ static bool canRemoveSrcNodeAfterFusion(
/// held if the 'mdg' is reused from a previous fusion step or if the node
/// creation order changes in the future to support more advanced cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
static void getProducerCandidates(unsigned dstId,
const MemRefDependenceGraph &mdg,
SmallVectorImpl<unsigned> &srcIdCandidates) {
// Skip if no input edges along which to fuse.
if (mdg->inEdges.count(dstId) == 0)
if (mdg.inEdges.count(dstId) == 0)
return;

// Gather memrefs from loads in 'dstId'.
auto *dstNode = mdg->getNode(dstId);
auto *dstNode = mdg.getNode(dstId);
DenseSet<Value> consumedMemrefs;
for (Operation *load : dstNode->loads)
consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

// Traverse 'dstId' incoming edges and gather the nodes that contain a store
// to one of the consumed memrefs.
for (auto &srcEdge : mdg->inEdges[dstId]) {
auto *srcNode = mdg->getNode(srcEdge.id);
for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
const auto *srcNode = mdg.getNode(srcEdge.id);
// Skip if 'srcNode' is not a loop nest.
if (!isa<AffineForOp>(srcNode->op))
continue;
Expand All @@ -169,10 +170,10 @@ static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
/// producer-consumer dependence between 'srcId' and 'dstId'.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
MemRefDependenceGraph *mdg,
const MemRefDependenceGraph &mdg,
DenseSet<Value> &producerConsumerMemrefs) {
auto *dstNode = mdg->getNode(dstId);
auto *srcNode = mdg->getNode(srcId);
auto *dstNode = mdg.getNode(dstId);
auto *srcNode = mdg.getNode(srcId);
gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
producerConsumerMemrefs);
}
Expand Down Expand Up @@ -214,14 +215,14 @@ static bool isEscapingMemref(Value memref, Block *block) {

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the block or are accessed in a non-affine way.
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
DenseSet<Value> &escapingMemRefs) {
auto *node = mdg->getNode(id);
auto *node = mdg.getNode(id);
for (Operation *storeOp : node->stores) {
auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
if (escapingMemRefs.count(memref))
continue;
if (isEscapingMemref(memref, &mdg->block))
if (isEscapingMemref(memref, &mdg.block))
escapingMemRefs.insert(memref);
}
}
Expand Down Expand Up @@ -787,7 +788,7 @@ struct GreedyFusion {
// in 'srcIdCandidates'.
dstNodeChanged = false;
SmallVector<unsigned, 16> srcIdCandidates;
getProducerCandidates(dstId, mdg, srcIdCandidates);
getProducerCandidates(dstId, *mdg, srcIdCandidates);

for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
Expand All @@ -802,7 +803,7 @@ struct GreedyFusion {
continue;

DenseSet<Value> producerConsumerMemrefs;
gatherProducerConsumerMemrefs(srcId, dstId, mdg,
gatherProducerConsumerMemrefs(srcId, dstId, *mdg,
producerConsumerMemrefs);

// Skip if 'srcNode' out edge count on any memref is greater than
Expand All @@ -817,7 +818,7 @@ struct GreedyFusion {
// block (e.g., memref block arguments, returned memrefs,
// memrefs passed to function calls, etc.).
DenseSet<Value> srcEscapingMemRefs;
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs);

// Compute an operation list insertion point for the fused loop
// nest which preserves dependences.
Expand Down Expand Up @@ -911,7 +912,7 @@ struct GreedyFusion {
// insertion point.
bool removeSrcNode = canRemoveSrcNodeAfterFusion(
srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
mdg);
*mdg);

DenseSet<Value> privateMemrefs;
for (Value memref : producerConsumerMemrefs) {
Expand Down