From a35b0535ce7a9002d4738d3a07d21fdd7179f25a Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Wed, 5 Jun 2024 15:14:43 +0000 Subject: [PATCH 1/4] [ROCDL] Add the global.atomic.fadd intrinsic in ROCDL --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 17 +++++++++++++++-- mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 20 ++++++++++++++++++++ mlir/test/Target/LLVMIR/rocdl.mlir | 9 +++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 1dabf5d7979b7..c8d4e4c03486e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -165,7 +165,7 @@ def ROCDL_BallotOp : let summary = "Vote across thread group"; let description = [{ - Ballot provides a bit mask containing the 1-bit predicate value from each lane. + Ballot provides a bit mask containing the 1-bit predicate value from each lane. The nth bit of the result contains the 1 bit contributed by the nth warp lane. }]; @@ -516,7 +516,7 @@ def ROCDL_RawBufferAtomicCmpSwap : } //===---------------------------------------------------------------------===// -// MI-100 and MI-200 buffer atomic floating point add intrinsic +// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsic def ROCDL_RawBufferAtomicFAddOp : ROCDL_Op<"raw.buffer.atomic.fadd">, @@ -534,6 +534,19 @@ def ROCDL_RawBufferAtomicFAddOp : let hasCustomAssemblyFormat = 1; } +def ROCDL_GlobalAtomicFAddOp : + ROCDL_Op<"global.atomic.fadd">, + Arguments<(ins LLVM_Type:$ptr, + LLVM_Type:$vdata)>{ + string llvmBuilder = [{ + auto vdataType = moduleTranslation.convertType(op.getVdata().getType()); + auto ptrType = moduleTranslation.convertType(op.getPtr().getType()); + createIntrinsicCall(builder, + llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType}); + }]; + let hasCustomAssemblyFormat = 1; +} + //===---------------------------------------------------------------------===// // Buffer atomic floating point max intrinsic. GFX9 does not support fp32. diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 65b770ae32610..34ebdb2ffd3d0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -157,6 +157,26 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) { p << " " << getOperands() << " : " << getVdata().getType(); } +// ::= +// `llvm.amdgcn.global.atomic.fadd.* %vdata, %ptr +ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector ops; + Type type; + if (parser.parseOperandList(ops, 2) || parser.parseColonType(type)) + return failure(); + + auto ptrType = LLVM::LLVMPointerType::get(parser.getContext()); + if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(), + result.operands)) + return failure(); + return success(); +} + +void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) { + p << " " << getOperands() << " : " << getVdata().getType(); +} + // ::= // `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset, // %soffset, %aux : result_type` diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index ce6b56d48437a..9d22b80748e14 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -494,6 +494,15 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>, llvm.return } +// CHECK-LABEL: rocdl.global.atomic +llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) { + // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}} + rocdl.global.atomic.fadd %ptr, %vdata0: f32 + // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}}) + rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16> + llvm.return +} + llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>, %offset : i32, %soffset : i32, %vdata1 : i32) { From 9d9cc3d4961a00e418c6fe3640b16ad473181727 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Wed, 5 Jun 2024 16:00:51 +0000 Subject: [PATCH 2/4] Address review feedback --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 7 ++++--- mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 20 -------------------- mlir/test/Target/LLVMIR/rocdl.mlir | 6 +++--- 3 files changed, 7 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index c8d4e4c03486e..deadd6caeb7e2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -342,6 +342,7 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1] //===---------------------------------------------------------------------===// def ROCDLBufferRsrc : LLVM_PointerInAddressSpace<8>; +def ROCDLGlobalPtr: LLVM_PointerInAddressSpace<1>; def ROCDL_MakeBufferRsrcOp : ROCDL_IntrOp<"make.buffer.rsrc", [], [0], [Pure], 1>, @@ -516,7 +517,7 @@ def ROCDL_RawBufferAtomicCmpSwap : } //===---------------------------------------------------------------------===// -// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsic +// gfx9x global/buffer atomic floating point add intrinsics def ROCDL_RawBufferAtomicFAddOp : ROCDL_Op<"raw.buffer.atomic.fadd">, @@ -536,7 +537,7 @@ def ROCDL_RawBufferAtomicFAddOp : def ROCDL_GlobalAtomicFAddOp : ROCDL_Op<"global.atomic.fadd">, - Arguments<(ins LLVM_Type:$ptr, + Arguments<(ins ROCDLGlobalPtr:$ptr, LLVM_Type:$vdata)>{ string llvmBuilder = [{ auto vdataType = moduleTranslation.convertType(op.getVdata().getType()); @@ -544,7 +545,7 @@ def ROCDL_GlobalAtomicFAddOp : createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType}); }]; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "operands attr-dict `:` type($vdata)"; } //===---------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 34ebdb2ffd3d0..65b770ae32610 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -157,26 +157,6 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) { p << " " << getOperands() << " : " << getVdata().getType(); } -// ::= -// `llvm.amdgcn.global.atomic.fadd.* %vdata, %ptr -ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector ops; - Type type; - if (parser.parseOperandList(ops, 2) || parser.parseColonType(type)) - return failure(); - - auto ptrType = LLVM::LLVMPointerType::get(parser.getContext()); - if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(), - result.operands)) - return failure(); - return success(); -} - -void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) { - p << " " << getOperands() << " : " << getVdata().getType(); -} - // ::= // `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset, // %soffset, %aux : result_type` diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 9d22b80748e14..c940d01a0a614 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -495,10 +495,10 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>, } // CHECK-LABEL: rocdl.global.atomic -llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) { - // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}} +llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr<1>) { + // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p1.f32(ptr addrspace(1) %{{.*}}, float %{{.*}} rocdl.global.atomic.fadd %ptr, %vdata0: f32 - // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}}) + // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p1.v2f16(ptr addrspace(1) %{{.*}}, <2 x half> %{{.*}}) rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16> llvm.return } From d984a8367a3bc2814bfe3adbcb46bd600601949c Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 6 Jun 2024 11:32:03 +0000 Subject: [PATCH 3/4] Change the intrinsic tblgen to return the old value from fadd --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index deadd6caeb7e2..dd8c2f1a27090 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -535,17 +535,12 @@ def ROCDL_RawBufferAtomicFAddOp : let hasCustomAssemblyFormat = 1; } -def ROCDL_GlobalAtomicFAddOp : - ROCDL_Op<"global.atomic.fadd">, - Arguments<(ins ROCDLGlobalPtr:$ptr, - LLVM_Type:$vdata)>{ - string llvmBuilder = [{ - auto vdataType = moduleTranslation.convertType(op.getVdata().getType()); - auto ptrType = moduleTranslation.convertType(op.getPtr().getType()); - createIntrinsicCall(builder, - llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType}); - }]; - let assemblyFormat = "operands attr-dict `:` type($vdata)"; +def ROCDL_GlobalAtomicFAddOp: + ROCDL_IntrOp<"global.atomic.fadd", + [0], [0, 1], [AllTypesMatch<["res", "vdata"]>], 1>, + Arguments<(ins ROCDLGlobalPtr:$ptr, + LLVM_Type:$vdata)>{ + let assemblyFormat = "operands attr-dict `:` type($res)"; } //===---------------------------------------------------------------------===// From 6eec7787bcc7bd2679a27f3090b8de6c2e7c572c Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Fri, 7 Jun 2024 08:45:32 +0000 Subject: [PATCH 4/4] Address review feedback - 2 --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index dd8c2f1a27090..e656ce8f62313 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -165,7 +165,7 @@ def ROCDL_BallotOp : let summary = "Vote across thread group"; let description = [{ - Ballot provides a bit mask containing the 1-bit predicate value from each lane. + Ballot provides a bit mask containing the 1-bit predicate value from each lane. The nth bit of the result contains the 1 bit contributed by the nth warp lane. }];