Skip to content

Commit 9606655

Browse files
[mlir][Transforms] Fix use-after-free when accessing replaced block args (#83646)
This commit fixes a bug in a dialect conversion. Currently, when a block is replaced via a signature conversion, the block is erased during the "commit" phase. This is problematic because the block arguments may still be referenced internal data structures of the dialect conversion (`mapping`). Blocks should be treated same as ops: they should be erased during the "cleanup" phase. Note: The test case fails without this fix when running with ASAN, but may pass when running without ASAN.
1 parent 61c283d commit 9606655

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,13 @@ class IRRewrite {
207207
/// Roll back the rewrite. Operations may be erased during rollback.
208208
virtual void rollback() = 0;
209209

210-
/// Commit the rewrite. Operations may be unlinked from their blocks during
211-
/// the commit phase, but they must not be erased yet. This is because
212-
/// internal dialect conversion state (such as `mapping`) may still be using
213-
/// them. Operations must be erased during cleanup.
210+
/// Commit the rewrite. Operations/blocks may be unlinked during the commit
211+
/// phase, but they must not be erased yet. This is because internal dialect
212+
/// conversion state (such as `mapping`) may still be using them. Operations/
213+
/// blocks must be erased during cleanup.
214214
virtual void commit() {}
215215

216-
/// Cleanup operations. Cleanup is called after commit.
216+
/// Cleanup operations/blocks. Cleanup is called after commit.
217217
virtual void cleanup() {}
218218

219219
Kind getKind() const { return kind; }
@@ -280,9 +280,9 @@ class CreateBlockRewrite : public BlockRewrite {
280280
};
281281

282282
/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
283-
/// blocks are immediately unlinked, but only erased when the rewrite is
284-
/// committed. This makes it easier to rollback a block erasure: the block is
285-
/// simply inserted into its original location.
283+
/// blocks are immediately unlinked, but only erased during cleanup. This makes
284+
/// it easier to rollback a block erasure: the block is simply inserted into its
285+
/// original location.
286286
class EraseBlockRewrite : public BlockRewrite {
287287
public:
288288
EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
@@ -295,7 +295,8 @@ class EraseBlockRewrite : public BlockRewrite {
295295
}
296296

297297
~EraseBlockRewrite() override {
298-
assert(!block && "rewrite was neither rolled back nor committed");
298+
assert(!block &&
299+
"rewrite was neither rolled back nor committed/cleaned up");
299300
}
300301

301302
void rollback() override {
@@ -310,7 +311,7 @@ class EraseBlockRewrite : public BlockRewrite {
310311
block = nullptr;
311312
}
312313

313-
void commit() override {
314+
void cleanup() override {
314315
// Erase the block.
315316
assert(block && "expected block");
316317
assert(block->empty() && "expected empty block");
@@ -438,6 +439,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
438439

439440
void commit() override;
440441

442+
void cleanup() override;
443+
441444
void rollback() override;
442445

443446
private:
@@ -985,7 +988,9 @@ void BlockTypeConversionRewrite::commit() {
985988
rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
986989
}
987990
}
991+
}
988992

993+
void BlockTypeConversionRewrite::cleanup() {
989994
assert(origBlock->empty() && "expected empty block");
990995
origBlock->dropAllDefinedValueUses();
991996
delete origBlock;
@@ -1483,6 +1488,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
14831488
Block *block, Region *previous, Region::iterator previousIt) {
14841489
assert(!wasOpReplaced(block->getParentOp()) &&
14851490
"attempting to insert into a region within a replaced/erased op");
1491+
LLVM_DEBUG({
1492+
logger.startLine() << "** Insert Block into : '"
1493+
<< block->getParentOp()->getName() << "'("
1494+
<< block->getParentOp() << ")\n";
1495+
});
14861496

14871497
if (!previous) {
14881498
// This is a newly created block.

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,15 @@ func.func @test_properties_rollback() {
346346
{modify_inplace}
347347
"test.return"() : () -> ()
348348
}
349+
350+
// -----
351+
352+
// CHECK: func.func @use_of_replaced_bbarg(
353+
// CHECK-SAME: %[[arg0:.*]]: f64)
354+
// CHECK: "test.valid"(%[[arg0]])
355+
func.func @use_of_replaced_bbarg(%arg0: i64) {
356+
%0 = "test.op_with_region_fold"(%arg0) ({
357+
"foo.op_with_region_terminator"() : () -> ()
358+
}) : (i64) -> (i64)
359+
"test.invalid"(%0) : (i64) -> ()
360+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,8 +1270,8 @@ def TestOpWithRegionFoldNoMemoryEffect : TEST_Op<
12701270

12711271
// Op for testing folding of outer op with inner ops.
12721272
def TestOpWithRegionFold : TEST_Op<"op_with_region_fold"> {
1273-
let arguments = (ins I32:$operand);
1274-
let results = (outs I32);
1273+
let arguments = (ins AnyType:$operand);
1274+
let results = (outs AnyType);
12751275
let regions = (region SizedRegion<1>:$region);
12761276
let hasFolder = 1;
12771277
}

0 commit comments

Comments
 (0)