Skip to content

Commit 220cdf9

Browse files
[mlir] Add requiresReplacedValues and visitReplacedValues to PromotableOpInterface (#86792)
Add `requiresReplacedValues` and `visitReplacedValues` methods to `PromotableOpInterface`. These methods allow `PromotableOpInterface` ops to transforms definitions mutated by a `store`. This change is necessary to correctly handle the promotion of `LLVM_DbgDeclareOp`. --------- Co-authored-by: Théo Degioanni <[email protected]>
1 parent b9ec4ab commit 220cdf9

File tree

5 files changed

+82
-10
lines changed

5 files changed

+82
-10
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,10 @@ class LLVM_DbgIntrOp<string name, string argName, list<Trait> traits = []>
562562
}];
563563
}
564564

565-
def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr",
566-
[DeclareOpInterfaceMethods<PromotableOpInterface>]> {
565+
def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr", [
566+
DeclareOpInterfaceMethods<PromotableOpInterface, [
567+
"requiresReplacedValues", "visitReplacedValues"
568+
]>]> {
567569
let summary = "Describes how the address relates to a source language variable.";
568570
let arguments = (ins
569571
LLVM_AnyPointer:$addr,

mlir/include/mlir/Interfaces/MemorySlotInterfaces.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,36 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
229229
(ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
230230
"::mlir::RewriterBase &":$rewriter)
231231
>,
232+
InterfaceMethod<[{
233+
This method allows the promoted operation to visit the SSA values used
234+
in place of the memory slot once the promotion process of the memory
235+
slot is complete.
236+
237+
If this method returns true, the `visitReplacedValues` method on this
238+
operation will be called after the main mutation stage finishes
239+
(i.e., after all ops have been processed with `removeBlockingUses`).
240+
241+
Operations should only the replaced values if the intended
242+
transformation applies to all the replaced values. Furthermore, replaced
243+
values must not be deleted.
244+
}], "bool", "requiresReplacedValues", (ins), [{}],
245+
[{ return false; }]
246+
>,
247+
InterfaceMethod<[{
248+
Transforms the IR using the SSA values that replaced the memory slot.
249+
250+
This method will only be called after all blocking uses have been
251+
scheduled for removal and if `requiresReplacedValues` returned
252+
true.
253+
254+
The rewriter is located after the promotable operation on call. All IR
255+
mutations must happen through the rewriter. During the transformation,
256+
*no operation should be deleted*.
257+
}],
258+
"void", "visitReplacedValues",
259+
(ins "::llvm::ArrayRef<std::pair<::mlir::Operation*, ::mlir::Value>>":$mutatedDefs,
260+
"::mlir::RewriterBase &":$rewriter), [{}], [{ return; }]
261+
>,
232262
];
233263
}
234264

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,6 @@ bool LLVM::StoreOp::canUsesBeRemoved(
168168
DeletionKind LLVM::StoreOp::removeBlockingUses(
169169
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
170170
RewriterBase &rewriter, Value reachingDefinition) {
171-
// `canUsesBeRemoved` checked this blocking use must be the stored slot
172-
// pointer.
173-
for (Operation *user : slot.ptr.getUsers())
174-
if (auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
175-
rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
176-
declareOp.getVarInfo(),
177-
declareOp.getLocationExpr());
178171
return DeletionKind::Delete;
179172
}
180173

@@ -407,6 +400,18 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses(
407400
return DeletionKind::Keep;
408401
}
409402

403+
bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
404+
405+
void LLVM::DbgDeclareOp::visitReplacedValues(
406+
ArrayRef<std::pair<Operation *, Value>> definitions,
407+
RewriterBase &rewriter) {
408+
for (auto [op, value] : definitions) {
409+
rewriter.setInsertionPointAfter(op);
410+
rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
411+
getLocationExpr());
412+
}
413+
}
414+
410415
//===----------------------------------------------------------------------===//
411416
// Interfaces for GEPOp
412417
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ class MemorySlotPromoter {
202202
/// Contains the reaching definition at this operation. Reaching definitions
203203
/// are only computed for promotable memory operations with blocking uses.
204204
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
205+
DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
205206
DominanceInfo &dominance;
206207
MemorySlotPromotionInfo info;
207208
const Mem2RegStatistics &statistics;
@@ -438,6 +439,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
438439
assert(stored && "a memory operation storing to a slot must provide a "
439440
"new definition of the slot");
440441
reachingDef = stored;
442+
replacedValuesMap[memOp] = stored;
441443
}
442444
}
443445
}
@@ -552,6 +554,10 @@ void MemorySlotPromoter::removeBlockingUses() {
552554
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
553555

554556
llvm::SmallVector<Operation *> toErase;
557+
// List of all replaced values in the slot.
558+
llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
559+
// Ops to visit with the `visitReplacedValues` method.
560+
llvm::SmallVector<PromotableOpInterface> toVisit;
555561
for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
556562
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
557563
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
@@ -565,7 +571,9 @@ void MemorySlotPromoter::removeBlockingUses() {
565571
slot, info.userToBlockingUses[toPromote], rewriter,
566572
reachingDef) == DeletionKind::Delete)
567573
toErase.push_back(toPromote);
568-
574+
if (toPromoteMemOp.storesTo(slot))
575+
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
576+
replacedValuesList.push_back({toPromoteMemOp, replacedValue});
569577
continue;
570578
}
571579

@@ -574,6 +582,12 @@ void MemorySlotPromoter::removeBlockingUses() {
574582
if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
575583
rewriter) == DeletionKind::Delete)
576584
toErase.push_back(toPromote);
585+
if (toPromoteBasic.requiresReplacedValues())
586+
toVisit.push_back(toPromoteBasic);
587+
}
588+
for (PromotableOpInterface op : toVisit) {
589+
rewriter.setInsertionPointAfter(op);
590+
op.visitReplacedValues(replacedValuesList, rewriter);
577591
}
578592

579593
for (Operation *toEraseOp : toErase)

mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,27 @@ llvm.func @basic_store_load(%arg0: i64) -> i64 {
2929
llvm.return %2 : i64
3030
}
3131

32+
// CHECK-LABEL: llvm.func @multiple_store_load
33+
llvm.func @multiple_store_load(%arg0: i64) -> i64 {
34+
%0 = llvm.mlir.constant(1 : i32) : i32
35+
// CHECK-NOT: = llvm.alloca
36+
%1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
37+
// CHECK-NOT: llvm.intr.dbg.declare
38+
llvm.intr.dbg.declare #di_local_variable = %1 : !llvm.ptr
39+
// CHECK-NOT: llvm.store
40+
llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
41+
// CHECK-NOT: llvm.intr.dbg.declare
42+
llvm.intr.dbg.declare #di_local_variable = %1 : !llvm.ptr
43+
// CHECK: llvm.intr.dbg.value #[[$VAR]] = %[[LOADED:.*]] : i64
44+
// CHECK: llvm.intr.dbg.value #[[$VAR]] = %[[LOADED]] : i64
45+
// CHECK-NOT: llvm.intr.dbg.value
46+
// CHECK-NOT: llvm.intr.dbg.declare
47+
// CHECK-NOT: llvm.store
48+
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
49+
// CHECK: llvm.return %[[LOADED]] : i64
50+
llvm.return %2 : i64
51+
}
52+
3253
// CHECK-LABEL: llvm.func @block_argument_value
3354
// CHECK-SAME: (%[[ARG0:.*]]: i64, {{.*}})
3455
llvm.func @block_argument_value(%arg0: i64, %arg1: i1) -> i64 {

0 commit comments

Comments
 (0)