Skip to content

Commit aad27bf

Browse files
authored
Add non-temporal support for LLVM masked loads (#104598)
This PR is adding non-temporal support to masked load as a `UnitAttr` attribute. Non temporal load is quite an important feature for masked loads to make the intrinsic usable from high level compilers
1 parent fc6300a commit aad27bf

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -842,22 +842,27 @@ def LLVM_GetActiveLaneMaskOp
842842
/// Create a call to Masked Load intrinsic.
843843
def LLVM_MaskedLoadOp : LLVM_OneResultIntrOp<"masked.load"> {
844844
let arguments = (ins LLVM_AnyPointer:$data, LLVM_VectorOf<I1>:$mask,
845-
Variadic<LLVM_AnyVector>:$pass_thru, I32Attr:$alignment);
845+
Variadic<LLVM_AnyVector>:$pass_thru, I32Attr:$alignment,
846+
UnitAttr:$nontemporal);
846847
let results = (outs LLVM_AnyVector:$res);
847848
let assemblyFormat =
848849
"operands attr-dict `:` functional-type(operands, results)";
849850

850851
string llvmBuilder = [{
851-
$res = $pass_thru.empty() ? builder.CreateMaskedLoad(
852+
auto *inst = $pass_thru.empty() ? builder.CreateMaskedLoad(
852853
$_resultType, $data, llvm::Align($alignment), $mask) :
853854
builder.CreateMaskedLoad(
854855
$_resultType, $data, llvm::Align($alignment), $mask, $pass_thru[0]);
855-
}];
856+
$res = inst;
857+
}] #setNonTemporalMetadataCode;
856858
string mlirBuilder = [{
859+
auto *intrinInst = dyn_cast<llvm::IntrinsicInst>(inst);
860+
bool nontemporal = intrinInst->hasMetadata(llvm::LLVMContext::MD_nontemporal);
857861
$res = $_builder.create<LLVM::MaskedLoadOp>($_location,
858-
$_resultType, $data, $mask, $pass_thru, $_int_attr($alignment));
862+
$_resultType, $data, $mask, $pass_thru, $_int_attr($alignment),
863+
nontemporal ? $_builder.getUnitAttr() : nullptr);
859864
}];
860-
list<int> llvmArgIndices = [0, 2, 3, 1];
865+
list<int> llvmArgIndices = [0, 2, 3, 1, -1];
861866
}
862867

863868
/// Create a call to Masked Store intrinsic.

mlir/test/Target/LLVMIR/Import/intrinsic.ll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,8 @@ define void @masked_load_store_intrinsics(ptr %vec, <7 x i1> %mask) {
409409
%1 = call <7 x float> @llvm.masked.load.v7f32.p0(ptr %vec, i32 1, <7 x i1> %mask, <7 x float> undef)
410410
; CHECK: %[[VAL2:.+]] = llvm.intr.masked.load %[[VEC]], %[[MASK]], %[[VAL1]] {alignment = 4 : i32}
411411
%2 = call <7 x float> @llvm.masked.load.v7f32.p0(ptr %vec, i32 4, <7 x i1> %mask, <7 x float> %1)
412+
; CHECK: %[[VAL3:.+]] = llvm.intr.masked.load %[[VEC]], %[[MASK]], %[[VAL1]] {alignment = 4 : i32, nontemporal}
413+
%3 = call <7 x float> @llvm.masked.load.v7f32.p0(ptr %vec, i32 4, <7 x i1> %mask, <7 x float> %1), !nontemporal !{i32 1}
412414
; CHECK: llvm.intr.masked.store %[[VAL2]], %[[VEC]], %[[MASK]] {alignment = 8 : i32}
413415
; CHECK-SAME: vector<7xf32>, vector<7xi1> into !llvm.ptr
414416
call void @llvm.masked.store.v7f32.p0(<7 x float> %2, ptr %vec, i32 8, <7 x i1> %mask)

mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,11 @@ llvm.func @masked_load_store_intrinsics(%A: !llvm.ptr, %mask: vector<7xi1>) {
417417
// CHECK: call <7 x float> @llvm.masked.load.v7f32.p0(ptr %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> poison)
418418
%a = llvm.intr.masked.load %A, %mask { alignment = 1: i32} :
419419
(!llvm.ptr, vector<7xi1>) -> vector<7xf32>
420+
// CHECK: call <7 x float> @llvm.masked.load.v7f32.p0(ptr %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> poison), !nontemporal !1
421+
%b = llvm.intr.masked.load %A, %mask { alignment = 1: i32, nontemporal} :
422+
(!llvm.ptr, vector<7xi1>) -> vector<7xf32>
420423
// CHECK: call <7 x float> @llvm.masked.load.v7f32.p0(ptr %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
421-
%b = llvm.intr.masked.load %A, %mask, %a { alignment = 1: i32} :
424+
%c = llvm.intr.masked.load %A, %mask, %a { alignment = 1: i32} :
422425
(!llvm.ptr, vector<7xi1>, vector<7xf32>) -> vector<7xf32>
423426
// CHECK: call void @llvm.masked.store.v7f32.p0(<7 x float> %{{.*}}, ptr %0, i32 {{.*}}, <7 x i1> %{{.*}})
424427
llvm.intr.masked.store %b, %A, %mask { alignment = 1: i32} :

0 commit comments

Comments
 (0)