Skip to content

Commit a16b19c

Browse files
committed
[MLIR][LLVM] Take the alignment attribute into account during inlining.
This is a subset of the full LLVM functionality to detect whether realignment is necessary, conservatively copying byval arguments whenever we cannot prove that the alignment requirement is met. Differential Revision: https://reviews.llvm.org/D147049
1 parent 48cd8b5 commit a16b19c

File tree

2 files changed

+142
-28
lines changed

2 files changed

+142
-28
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -95,44 +95,86 @@ static void moveConstantAllocasToEntryBlock(
9595
}
9696
}
9797

98+
/// Tries to find and return the alignment of the pointer `value` by looking for
99+
/// an alignment attribute on the defining allocation op or function argument.
100+
/// If no such attribute is found, returns 1 (i.e., assume that no alignment is
101+
/// guaranteed).
102+
static unsigned getAlignmentOf(Value value) {
103+
if (Operation *definingOp = value.getDefiningOp()) {
104+
if (auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
105+
return alloca.getAlignment().value_or(1);
106+
if (auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
107+
if (auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
108+
definingOp, addressOf.getGlobalNameAttr()))
109+
return global.getAlignment().value_or(1);
110+
// We don't currently handle this operation; assume no alignment.
111+
return 1;
112+
}
113+
// Since there is no defining op, this is a block argument. Probably this
114+
// comes directly from a function argument, so check that this is the case.
115+
Operation *parentOp = value.getParentBlock()->getParentOp();
116+
if (auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
117+
// Use the alignment attribute set for this argument in the parent
118+
// function if it has been set.
119+
auto blockArg = value.cast<BlockArgument>();
120+
if (Attribute alignAttr = func.getArgAttr(
121+
blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
122+
return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
123+
}
124+
// We didn't find anything useful; assume no alignment.
125+
return 1;
126+
}
127+
128+
/// Copies the data from a byval pointer argument into newly alloca'ed memory
129+
/// and returns the value of the alloca.
130+
static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
131+
Value argument, Type elementType,
132+
unsigned elementTypeSize,
133+
unsigned targetAlignment) {
134+
// Allocate the new value on the stack.
135+
Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
136+
builder.getI64IntegerAttr(1));
137+
Value allocaOp = builder.create<LLVM::AllocaOp>(
138+
loc, argument.getType(), elementType, one, targetAlignment);
139+
// Copy the pointee to the newly allocated value.
140+
Value copySize = builder.create<LLVM::ConstantOp>(
141+
loc, builder.getI64Type(), builder.getI64IntegerAttr(elementTypeSize));
142+
Value isVolatile = builder.create<LLVM::ConstantOp>(
143+
loc, builder.getI1Type(), builder.getBoolAttr(false));
144+
builder.create<LLVM::MemcpyOp>(loc, allocaOp, argument, copySize, isVolatile);
145+
return allocaOp;
146+
}
147+
148+
/// Handles a function argument marked with the byval attribute by introducing a
149+
/// memcpy if necessary, either due to the pointee being writeable in the
150+
/// callee, and/or due to an alignment mismatch. `requestedAlignment` specifies
151+
/// the alignment set in the "align" argument attribute (or 1 if no align
152+
/// attribute was set).
98153
static Value handleByValArgument(OpBuilder &builder, Operation *callable,
99-
Value argument,
100-
NamedAttribute byValAttribute) {
154+
Value argument, Type elementType,
155+
unsigned requestedAlignment) {
101156
auto func = cast<LLVM::LLVMFuncOp>(callable);
102157
LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryAttr();
103158
// If there is no memory effects attribute, assume that the function is
104159
// not read-only.
105160
bool isReadOnly = memoryEffects &&
106161
memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
107162
memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
108-
if (isReadOnly)
163+
// Check if there's an alignment mismatch requiring us to copy.
164+
DataLayout dataLayout(callable->getParentOfType<DataLayoutOpInterface>());
165+
unsigned minimumAlignment = dataLayout.getTypeABIAlignment(elementType);
166+
if (isReadOnly && (requestedAlignment <= minimumAlignment ||
167+
getAlignmentOf(argument) >= requestedAlignment))
109168
return argument;
110-
// Resolve the pointee type and its size.
111-
auto ptrType = cast<LLVM::LLVMPointerType>(argument.getType());
112-
Type elementType = cast<TypeAttr>(byValAttribute.getValue()).getValue();
113-
unsigned int typeSize =
114-
DataLayout(callable->getParentOfType<DataLayoutOpInterface>())
115-
.getTypeSize(elementType);
116-
// Allocate the new value on the stack.
117-
Value one = builder.create<LLVM::ConstantOp>(
118-
func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(1));
119-
Value allocaOp =
120-
builder.create<LLVM::AllocaOp>(func.getLoc(), ptrType, elementType, one);
121-
// Copy the pointee to the newly allocated value.
122-
Value copySize = builder.create<LLVM::ConstantOp>(
123-
func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(typeSize));
124-
Value isVolatile = builder.create<LLVM::ConstantOp>(
125-
func.getLoc(), builder.getI1Type(), builder.getBoolAttr(false));
126-
builder.create<LLVM::MemcpyOp>(func.getLoc(), allocaOp, argument, copySize,
127-
isVolatile);
128-
return allocaOp;
169+
unsigned targetAlignment = std::max(requestedAlignment, minimumAlignment);
170+
return handleByValArgumentInit(builder, func.getLoc(), argument, elementType,
171+
dataLayout.getTypeSize(elementType),
172+
targetAlignment);
129173
}
130174

131175
/// Returns true if the given argument or result attribute is supported by the
132176
/// inliner, false otherwise.
133177
static bool isArgOrResAttrSupported(NamedAttribute attr) {
134-
if (attr.getName() == LLVM::LLVMDialect::getAlignAttrName())
135-
return false;
136178
if (attr.getName() == LLVM::LLVMDialect::getInAllocaAttrName())
137179
return false;
138180
if (attr.getName() == LLVM::LLVMDialect::getNoAliasAttrName())
@@ -289,9 +331,19 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
289331
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
290332
Value argument, Type targetType,
291333
DictionaryAttr argumentAttrs) const final {
292-
if (auto attr =
293-
argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName()))
294-
return handleByValArgument(builder, callable, argument, *attr);
334+
if (std::optional<NamedAttribute> attr =
335+
argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
336+
Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
337+
unsigned requestedAlignment = 1;
338+
if (std::optional<NamedAttribute> alignAttr =
339+
argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
340+
requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
341+
.getValue()
342+
.getLimitedValue();
343+
}
344+
return handleByValArgument(builder, callable, argument, elementType,
345+
requestedAlignment);
346+
}
295347
return argument;
296348
}
297349

mlir/test/Dialect/LLVMIR/inlining.mlir

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,68 @@ llvm.func @test_byval_write_only(%ptr : !llvm.ptr) {
399399

400400
// -----
401401

402+
llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
403+
llvm.return
404+
}
405+
406+
// CHECK-LABEL: llvm.func @test_byval_input_aligned
407+
// CHECK-SAME: %[[UNALIGNED:[a-zA-Z0-9_]+]]: !llvm.ptr
408+
// CHECK-SAME: %[[ALIGNED:[a-zA-Z0-9_]+]]: !llvm.ptr
409+
llvm.func @test_byval_input_aligned(%unaligned : !llvm.ptr, %aligned : !llvm.ptr { llvm.align = 16 }) {
410+
// Make sure only the unaligned input triggers a memcpy.
411+
// CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x i16 {alignment = 16
412+
// CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[UNALIGNED]]
413+
llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
414+
// CHECK-NOT: memcpy
415+
llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
416+
llvm.return
417+
}
418+
419+
// -----
420+
421+
llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
422+
llvm.return
423+
}
424+
425+
// CHECK-LABEL: llvm.func @test_byval_alloca
426+
llvm.func @test_byval_alloca() {
427+
// Make sure only the unaligned alloca triggers a memcpy.
428+
%size = llvm.mlir.constant(1 : i64) : i64
429+
// CHECK: %[[ALLOCA:.+]] = llvm.alloca {{.+}}alignment = 1
430+
// CHECK: "llvm.intr.memcpy"(%{{.+}}, %[[ALLOCA]]
431+
%unaligned = llvm.alloca %size x i16 { alignment = 1 } : (i64) -> !llvm.ptr
432+
llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
433+
// CHECK-NOT: memcpy
434+
%aligned = llvm.alloca %size x i16 { alignment = 16 } : (i64) -> !llvm.ptr
435+
llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
436+
llvm.return
437+
}
438+
439+
// -----
440+
441+
llvm.mlir.global private @unaligned_global(42 : i64) : i64
442+
llvm.mlir.global private @aligned_global(42 : i64) { alignment = 64 } : i64
443+
444+
llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
445+
llvm.return
446+
}
447+
448+
// CHECK-LABEL: llvm.func @test_byval_global
449+
llvm.func @test_byval_global() {
450+
// Make sure only the unaligned global triggers a memcpy.
451+
// CHECK: %[[UNALIGNED:.+]] = llvm.mlir.addressof @unaligned_global
452+
// CHECK: %[[ALLOCA:.+]] = llvm.alloca
453+
// CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[UNALIGNED]]
454+
// CHECK-NOT: llvm.alloca
455+
%unaligned = llvm.mlir.addressof @unaligned_global : !llvm.ptr
456+
llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
457+
%aligned = llvm.mlir.addressof @aligned_global : !llvm.ptr
458+
llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
459+
llvm.return
460+
}
461+
462+
// -----
463+
402464
llvm.func @ignored_attrs(%ptr : !llvm.ptr { llvm.inreg, llvm.nocapture, llvm.nofree, llvm.preallocated = i32, llvm.returned, llvm.alignstack = 32 : i64, llvm.writeonly, llvm.noundef, llvm.nonnull }, %x : i32 { llvm.zeroext }) -> (!llvm.ptr { llvm.noundef, llvm.inreg, llvm.nonnull }) {
403465
llvm.return %ptr : !llvm.ptr
404466
}
@@ -413,7 +475,7 @@ llvm.func @test_ignored_attrs(%ptr : !llvm.ptr, %x : i32) {
413475

414476
// -----
415477

416-
llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.align = 16 : i32 }) {
478+
llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.noalias }) {
417479
llvm.return
418480
}
419481

0 commit comments

Comments
 (0)