Skip to content

Commit 3359806

Browse files
authored
[mlir][LLVM][MemRef] Lower assume_alignment with operand bundles (#117800)
Now that LLVM allows a operand bundle on assume calls to directly specify alignment assumptions, change the lowering of memref.assume_alignment to use that feature instead of the ptrtoint method. This makes LLVM's job easier and prevents issues when dealing with cases where ptrtoint isn't a desired operation (like those with poiner provenance)
1 parent 3a8b28f commit 3359806

File tree

3 files changed

+18
-34
lines changed

3 files changed

+18
-34
lines changed

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -192,22 +192,15 @@ struct AssumeAlignmentOpLowering
192192
Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
193193
rewriter);
194194

195-
// Emit llvm.assume(memref & (alignment - 1) == 0).
196-
//
197-
// This relies on LLVM's CSE optimization (potentially after SROA), since
198-
// after CSE all memref instances should get de-duplicated into the same
199-
// pointer SSA value.
200-
MemRefDescriptor memRefDescriptor(memref);
201-
auto intPtrType =
202-
getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
203-
Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
204-
Value mask =
205-
createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
206-
Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
207-
rewriter.create<LLVM::AssumeOp>(
208-
loc, rewriter.create<LLVM::ICmpOp>(
209-
loc, LLVM::ICmpPredicate::eq,
210-
rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
195+
// Emit llvm.assume(true) ["align"(memref, alignment)].
196+
// This is more direct than ptrtoint-based checks, is explicitly supported,
197+
// and works with non-integral address spaces.
198+
Value trueCond =
199+
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
200+
Value alignmentConst =
201+
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
202+
rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
203+
alignmentConst);
211204

212205
rewriter.eraseOp(op);
213206
return success();

mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -675,10 +675,7 @@ func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf
675675
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
676676
// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
677677
// CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
678-
// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
679-
// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}} : i64
680-
// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
681-
// CHECK: llvm.intr.assume %[[CMP]] : i1
678+
// CHECK: llvm.intr.assume %{{.*}} ["align"(%[[BUFF_ADDR]], %{{.*}} : !llvm.ptr, i64)] : i1
682679
// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
683680
// CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
684681
// CHECK: return %[[VAL]] : f32

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,9 @@ func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : in
155155
// CHECK-LABEL: func @assume_alignment(
156156
func.func @assume_alignment(%0 : memref<4x4xf16>) {
157157
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
158-
// CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
159-
// CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
160-
// CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm.ptr to i64
161-
// CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
162-
// CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
163-
// CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
158+
// CHECK-NEXT: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
159+
// CHECK-NEXT: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64
160+
// CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[PTR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1
164161
memref.assume_alignment %0, 16 : memref<4x4xf16>
165162
return
166163
}
@@ -172,12 +169,9 @@ func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset
172169
// CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
173170
// CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
174171
// CHECK-DAG: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
175-
// CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
176-
// CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
177-
// CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
178-
// CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
179-
// CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
180-
// CHECK-NEXT: llvm.intr.assume %[[CONDITION]] : i1
172+
// CHECK-DAG: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
173+
// CHECK-DAG: %[[ALIGN:.*]] = llvm.mlir.constant(16 : index) : i64
174+
// CHECK-NEXT: llvm.intr.assume %[[TRUE]] ["align"(%[[BUFF_ADDR]], %[[ALIGN]] : !llvm.ptr, i64)] : i1
181175
memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
182176
return
183177
}
@@ -410,7 +404,7 @@ func.func @atomic_rmw_with_offset(%I : memref<10xi32, strided<[1], offset: 5>>,
410404
// CHECK-SAME: %[[ARG2:.+]]: index
411405
// CHECK-DAG: %[[MEMREF_STRUCT:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<10xi32, strided<[1], offset: 5>> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
412406
// CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i64
413-
// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
407+
// CHECK: %[[BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_STRUCT]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
414408
// CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(5 : index) : i64
415409
// CHECK: %[[OFFSET_PTR:.+]] = llvm.getelementptr %[[BASE_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
416410
// CHECK: %[[PTR:.+]] = llvm.getelementptr %[[OFFSET_PTR]][%[[INDEX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
@@ -601,7 +595,7 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
601595
// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
602596
func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
603597
%0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
604-
// CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
598+
// CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
605599
// CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
606600
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
607601
// CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64

0 commit comments

Comments
 (0)