diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td index e2409cbec5fde..6f023f0c5263f 100644 --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -298,14 +298,20 @@ def DestructurableAllocationOpInterface "destructure", (ins "const ::mlir::DestructurableMemorySlot &":$slot, "const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices, - "::mlir::OpBuilder &":$builder) + "::mlir::OpBuilder &":$builder, + "::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface> &": + $newAllocators) >, InterfaceMethod<[{ Hook triggered once the destructuring of a slot is complete, meaning the original slot is no longer being refered to and could be deleted. This will only be called for slots declared by this operation. + + Must return a new destructurable allocation op if this hook creates + a new destructurable op, nullopt otherwise. }], - "void", "handleDestructuringComplete", + "::std::optional<::mlir::DestructurableAllocationOpInterface>", + "handleDestructuringComplete", (ins "const ::mlir::DestructurableMemorySlot &":$slot, "::mlir::OpBuilder &":$builder) >, diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h index fa84fb1eae73a..d48f809d30766 100644 --- a/mlir/include/mlir/Transforms/SROA.h +++ b/mlir/include/mlir/Transforms/SROA.h @@ -27,8 +27,10 @@ struct SROAStatistics { llvm::Statistic *maxSubelementAmount = nullptr; }; -/// Attempts to destructure the slots of destructurable allocators. Returns -/// failure if no slot was destructured. +/// Attempts to destructure the slots of destructurable allocators. Iteratively +/// retries the destructuring of all slots as destructuring one slot might +/// enable subsequent destructuring. Returns failure if no slot was +/// destructured. LogicalResult tryToDestructureMemorySlots( ArrayRef allocators, OpBuilder &builder, const DataLayout &dataLayout, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 4fdf847a559ce..3f1e5b1773bf7 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -77,10 +77,10 @@ SmallVector LLVM::AllocaOp::getDestructurableSlots() { *destructuredType}}; } -DenseMap -LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot, - const SmallPtrSetImpl &usedIndices, - OpBuilder &builder) { +DenseMap LLVM::AllocaOp::destructure( + const DestructurableMemorySlot &slot, + const SmallPtrSetImpl &usedIndices, OpBuilder &builder, + SmallVectorImpl &newAllocators) { assert(slot.ptr == getResult()); builder.setInsertionPointAfter(*this); @@ -92,16 +92,19 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot, auto subAlloca = builder.create( getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType, getArraySize()); + newAllocators.push_back(subAlloca); slotMap.try_emplace(index, {subAlloca.getResult(), elemType}); } return slotMap; } -void LLVM::AllocaOp::handleDestructuringComplete( +std::optional +LLVM::AllocaOp::handleDestructuringComplete( const DestructurableMemorySlot &slot, OpBuilder &builder) { assert(slot.ptr == getResult()); this->erase(); + return std::nullopt; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index e30598e6878f4..631dee2d40538 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -126,10 +126,10 @@ memref::AllocaOp::getDestructurableSlots() { DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}}; } -DenseMap -memref::AllocaOp::destructure(const DestructurableMemorySlot &slot, - const SmallPtrSetImpl &usedIndices, - OpBuilder &builder) { +DenseMap memref::AllocaOp::destructure( + const DestructurableMemorySlot &slot, + const SmallPtrSetImpl &usedIndices, OpBuilder &builder, + SmallVectorImpl &newAllocators) { builder.setInsertionPointAfter(*this); DenseMap slotMap; @@ -139,6 +139,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot, Type elemType = memrefType.getTypeAtIndex(usedIndex); MemRefType elemPtr = MemRefType::get({}, elemType); auto subAlloca = builder.create(getLoc(), elemPtr); + newAllocators.push_back(subAlloca); slotMap.try_emplace(usedIndex, {subAlloca.getResult(), elemType}); } @@ -146,10 +147,12 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot, return slotMap; } -void memref::AllocaOp::handleDestructuringComplete( +std::optional +memref::AllocaOp::handleDestructuringComplete( const DestructurableMemorySlot &slot, OpBuilder &builder) { assert(slot.ptr == getResult()); this->erase(); + return std::nullopt; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 4e28fa687ffd4..67cbade07bc94 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -132,16 +132,17 @@ computeDestructuringInfo(DestructurableMemorySlot &slot, /// Performs the destructuring of a destructible slot given associated /// destructuring information. The provided slot will be destructured in /// subslots as specified by its allocator. -static void destructureSlot(DestructurableMemorySlot &slot, - DestructurableAllocationOpInterface allocator, - OpBuilder &builder, const DataLayout &dataLayout, - MemorySlotDestructuringInfo &info, - const SROAStatistics &statistics) { +static void destructureSlot( + DestructurableMemorySlot &slot, + DestructurableAllocationOpInterface allocator, OpBuilder &builder, + const DataLayout &dataLayout, MemorySlotDestructuringInfo &info, + SmallVectorImpl &newAllocators, + const SROAStatistics &statistics) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(slot.ptr.getParentBlock()); DenseMap subslots = - allocator.destructure(slot, info.usedIndices, builder); + allocator.destructure(slot, info.usedIndices, builder, newAllocators); if (statistics.slotsWithMemoryBenefit && slot.elementPtrs.size() != info.usedIndices.size()) @@ -185,7 +186,11 @@ static void destructureSlot(DestructurableMemorySlot &slot, if (statistics.destructuredAmount) (*statistics.destructuredAmount)++; - allocator.handleDestructuringComplete(slot, builder); + std::optional newAllocator = + allocator.handleDestructuringComplete(slot, builder); + // Add newly created allocators to the worklist for further processing. + if (newAllocator) + newAllocators.push_back(*newAllocator); } LogicalResult mlir::tryToDestructureMemorySlots( @@ -194,16 +199,44 @@ LogicalResult mlir::tryToDestructureMemorySlots( SROAStatistics statistics) { bool destructuredAny = false; - for (DestructurableAllocationOpInterface allocator : allocators) { - for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) { - std::optional info = - computeDestructuringInfo(slot, dataLayout); - if (!info) - continue; + SmallVector workList(allocators.begin(), + allocators.end()); + SmallVector newWorkList; + newWorkList.reserve(allocators.size()); + // Destructuring a slot can allow for further destructuring of other + // slots, destructuring is tried until no destructuring succeeds. + while (true) { + bool changesInThisRound = false; + + for (DestructurableAllocationOpInterface allocator : workList) { + bool destructuredAnySlot = false; + for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) { + std::optional info = + computeDestructuringInfo(slot, dataLayout); + if (!info) + continue; - destructureSlot(slot, allocator, builder, dataLayout, *info, statistics); - destructuredAny = true; + destructureSlot(slot, allocator, builder, dataLayout, *info, + newWorkList, statistics); + destructuredAnySlot = true; + + // A break is required, since destructuring a slot may invalidate the + // remaning slots of an allocator. + break; + } + if (!destructuredAnySlot) + newWorkList.push_back(allocator); + changesInThisRound |= destructuredAnySlot; } + + if (!changesInThisRound) + break; + destructuredAny |= changesInThisRound; + + // Swap the vector's backing memory and clear the entries in newWorkList + // afterwards. This ensures that additional heap allocations can be avoided. + workList.swap(newWorkList); + newWorkList.clear(); } return success(destructuredAny); @@ -230,23 +263,16 @@ struct SROA : public impl::SROABase { OpBuilder builder(®ion.front(), region.front().begin()); - // Destructuring a slot can allow for further destructuring of other - // slots, destructuring is tried until no destructuring succeeds. - while (true) { - SmallVector allocators; - // Build a list of allocators to attempt to destructure the slots of. - // TODO: Update list on the fly to avoid repeated visiting of the same - // allocators. - region.walk([&](DestructurableAllocationOpInterface allocator) { - allocators.emplace_back(allocator); - }); - - if (failed(tryToDestructureMemorySlots(allocators, builder, dataLayout, - statistics))) - break; + SmallVector allocators; + // Build a list of allocators to attempt to destructure the slots of. + region.walk([&](DestructurableAllocationOpInterface allocator) { + allocators.emplace_back(allocator); + }); + // Attempt to destructure as many slots as possible. + if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout, + statistics))) changed = true; - } } if (!changed) markAllAnalysesPreserved(); diff --git a/mlir/test/Transforms/sroa.mlir b/mlir/test/Transforms/sroa.mlir new file mode 100644 index 0000000000000..c9e80a6cf8dd1 --- /dev/null +++ b/mlir/test/Transforms/sroa.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(sroa))' --split-input-file | FileCheck %s + +// Verifies that allocators with mutliple slots are handled properly. + +// CHECK-LABEL: func.func @multi_slot_alloca +func.func @multi_slot_alloca() -> (i32, i32) { + %0 = arith.constant 0 : index + %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>) + // CHECK-COUNT-2: test.multi_slot_alloca : () -> memref + %3 = memref.load %1[%0] {first}: memref<2xi32> + %4 = memref.load %2[%0] {second} : memref<4xi32> + return %3, %4 : i32, i32 +} + +// ----- + +// Verifies that a multi slot allocator can be partially destructured. + +func.func private @consumer(memref<2xi32>) + +// CHECK-LABEL: func.func @multi_slot_alloca_only_second +func.func @multi_slot_alloca_only_second() -> (i32, i32) { + %0 = arith.constant 0 : index + // CHECK: test.multi_slot_alloca : () -> memref<2xi32> + // CHECK: test.multi_slot_alloca : () -> memref + %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>) + func.call @consumer(%1) : (memref<2xi32>) -> () + %3 = memref.load %1[%0] : memref<2xi32> + %4 = memref.load %2[%0] : memref<4xi32> + return %3, %4 : i32, i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index d22d48b139a04..0b676db18af41 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -1199,22 +1199,20 @@ void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot, // Not relevant for testing. } -std::optional -TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot, - Value defaultValue, - OpBuilder &builder) { - if (defaultValue && defaultValue.use_empty()) - defaultValue.getDefiningOp()->erase(); +/// Creates a new TestMultiSlotAlloca operation, just without the `slot`. +static std::optional +createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder, + TestMultiSlotAlloca oldOp) { - if (getNumResults() == 1) { - erase(); + if (oldOp.getNumResults() == 1) { + oldOp.erase(); return std::nullopt; } SmallVector newTypes; SmallVector remainingValues; - for (Value oldResult : getResults()) { + for (Value oldResult : oldOp.getResults()) { if (oldResult == slot.ptr) continue; remainingValues.push_back(oldResult); @@ -1222,12 +1220,68 @@ TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot, } OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(*this); - auto replacement = builder.create(getLoc(), newTypes); + builder.setInsertionPoint(oldOp); + auto replacement = + builder.create(oldOp->getLoc(), newTypes); for (auto [oldResult, newResult] : llvm::zip_equal(remainingValues, replacement.getResults())) oldResult.replaceAllUsesWith(newResult); - erase(); + oldOp.erase(); return replacement; } + +std::optional +TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot, + Value defaultValue, + OpBuilder &builder) { + if (defaultValue && defaultValue.use_empty()) + defaultValue.getDefiningOp()->erase(); + return createNewMultiAllocaWithoutSlot(slot, builder, *this); +} + +SmallVector +TestMultiSlotAlloca::getDestructurableSlots() { + SmallVector slots; + for (Value result : getResults()) { + auto memrefType = cast(result.getType()); + auto destructurable = dyn_cast(memrefType); + if (!destructurable) + continue; + + std::optional> destructuredType = + destructurable.getSubelementIndexMap(); + if (!destructuredType) + continue; + slots.emplace_back( + DestructurableMemorySlot{{result, memrefType}, *destructuredType}); + } + return slots; +} + +DenseMap TestMultiSlotAlloca::destructure( + const DestructurableMemorySlot &slot, + const SmallPtrSetImpl &usedIndices, OpBuilder &builder, + SmallVectorImpl &newAllocators) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(*this); + + DenseMap slotMap; + + for (Attribute usedIndex : usedIndices) { + Type elemType = slot.elementPtrs.lookup(usedIndex); + MemRefType elemPtr = MemRefType::get({}, elemType); + auto subAlloca = builder.create(getLoc(), elemPtr); + newAllocators.push_back(subAlloca); + slotMap.try_emplace(usedIndex, + {subAlloca.getResult(0), elemType}); + } + + return slotMap; +} + +std::optional +TestMultiSlotAlloca::handleDestructuringComplete( + const DestructurableMemorySlot &slot, OpBuilder &builder) { + return createNewMultiAllocaWithoutSlot(slot, builder, *this); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index e16ea2407314e..7fc3d22d18958 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3169,11 +3169,12 @@ def TestOpOptionallyImplementingInterface } //===----------------------------------------------------------------------===// -// Test Mem2Reg +// Test Mem2Reg & SROA //===----------------------------------------------------------------------===// def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca", - [DeclareOpInterfaceMethods]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let results = (outs Variadic>:$results); let assemblyFormat = "attr-dict `:` functional-type(operands, results)"; }