Commit 21cef8a
[MLIR][NVVM] Add support for tcgen05.{ld, st} (#130728)
This commit adds support for tcgen05.{ld, st} to the NVVM Dialect, with tests under tcgen05-ld.mlir and tcgen05-st.mlir, respectively.
1 parent 7811075

5 files changed: +1018 −0 lines

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 202 additions & 0 deletions
@@ -2929,6 +2929,208 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
  }];
}

//===----------------------------------------------------------------------===//
// NVVM tcgen05 LdSt Shape Attr
//===----------------------------------------------------------------------===//

def Tcgen05LdStShape16x64b: I32EnumAttrCase<"SHAPE_16X64B", 0, "shape_16x64b">;
def Tcgen05LdStShape16x128b: I32EnumAttrCase<"SHAPE_16X128B", 1, "shape_16x128b">;
def Tcgen05LdStShape16x256b: I32EnumAttrCase<"SHAPE_16X256B", 2, "shape_16x256b">;
def Tcgen05LdStShape32x32b: I32EnumAttrCase<"SHAPE_32X32B", 3, "shape_32x32b">;
def Tcgen05LdStShape16x32bx2: I32EnumAttrCase<"SHAPE_16X32BX2", 4, "shape_16x32bx2">;

def Tcgen05LdStShape: I32EnumAttr<
  "Tcgen05LdStShape",
  "",
  [Tcgen05LdStShape16x64b, Tcgen05LdStShape16x128b, Tcgen05LdStShape16x256b,
   Tcgen05LdStShape32x32b, Tcgen05LdStShape16x32bx2]
> {
  let cppNamespace = "::mlir::NVVM";
  let genSpecializedAttr = 0;
}

def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst_shape"> {
  let assemblyFormat = "`<` $value `>`";
}
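As a quick illustration (a minimal sketch, not part of the diff): in the textual IR this attribute is spelled as the `tcgen05_ldst_shape` mnemonic followed by the enum case's string form in angle brackets, and it appears as the `shape` attribute on the ops defined below.

```mlir
// #nvvm.tcgen05_ldst_shape<shape_32x32b> selects the 32x32b base shape.
%res = nvvm.tcgen05.ld %tmemAddr { shape = #nvvm.tcgen05_ldst_shape<shape_32x32b> } : i32
```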
//===----------------------------------------------------------------------===//
// NVVM tcgen05.ld Op
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
  let summary = "tensor memory load instructions";
  let arguments = (ins
    // Attributes
    UnitAttr:$pack,
    Tcgen05LdStShapeAttr:$shape,
    // Arguments
    LLVM_PointerTensor:$tmemAddr,
    Optional<I64>:$offset
  );

  let results = (outs AnyTypeOf<[I32, VectorOfLengthAndType<
      [2, 4, 8, 16, 32, 64, 128], [I32]>]>:$res);

  let assemblyFormat = [{
    $tmemAddr (`,` $offset^)? (`pack` $pack^)? attr-dict `:` type($res)
  }];

  let description = [{
    Instruction `tcgen05.ld` asynchronously loads data from the Tensor Memory at
    the location specified by the 32-bit address operand `tmemAddr` into the
    destination register `res`, collectively across all threads of the warps.

    The `shape` and the `num` attribute together determine the total
    dimension of the data which is loaded from the Tensor Memory. The `shape`
    attribute indicates the base dimension of data to be accessed as described
    in the Data Movement Shape. The `num` attribute indicates the repeat
    factor on the base dimension resulting in the total dimension of the data
    that is accessed.

    The shape `16x32bx2` performs two accesses into Tensor Memory of the shape
    `16x32b`. The base address of the first access is specified by `tmemAddr`
    and the base address of the second access is specified by
    `tmemAddr + offset`, where `offset` is an immediate argument.

    The unit attribute `pack` can be used to pack two 16-bit
    elements from adjacent columns into a single 32-bit element during the load.

    The following table describes the size of the vector for various
    combinations of the `num` and `shape` attributes:

    |========================================================|
    | num/shape | 16x32bx2/16x64b/32x32b | 16x128b | 16x256b |
    |========================================================|
    | x1        | 1                      | 2       | 4       |
    | x2        | 2                      | 4       | 8       |
    | x4        | 4                      | 8       | 16      |
    | x8        | 8                      | 16      | 32      |
    | x16       | 16                     | 32      | 64      |
    | x32       | 32                     | 64      | 128     |
    | x64       | 64                     | 128     | NA      |
    | x128      | 128                    | NA      | NA      |
    |========================================================|

    Example:
    ```mlir
    nvvm.tcgen05.ld %tmemAddr, %offset pack {
      shape = #nvvm.tcgen05_ldst_shape<shape_16x32bx2>
    } : vector<2xi32>
    ```

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
  }];

  let hasVerifier = 1;

  string llvmBuilder = [{
    llvm::LLVMContext &Context = moduleTranslation.getLLVMContext();
    auto Pack = llvm::ConstantInt::get(Context, llvm::APInt(1, $pack));

    unsigned num = $_resultType->isVectorTy()
                       ? llvm::cast<llvm::VectorType>($_resultType)
                             ->getElementCount()
                             .getFixedValue()
                       : 1;

    auto ID = getTcgen05LdIntrinsicID($shape, num);
    if (ID == llvm::Intrinsic::not_intrinsic)
      llvm::report_fatal_error("unknown intrinsic signature for tcgen05.ld");

    if ($offset)
      $res = createIntrinsicCall(builder, ID, {$tmemAddr, $offset, Pack});
    else
      $res = createIntrinsicCall(builder, ID, {$tmemAddr, Pack});
  }];
}
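To make the `num`/`shape` table concrete, here is a hedged sketch of a well-formed load (function and value names are illustrative; tensor memory is assumed to be modeled as an LLVM pointer in address space 6 behind `LLVM_PointerTensor`). A `vector<8xi32>` result with shape `16x256b` corresponds to `num = x2` in the table above.

```mlir
llvm.func @ld_16x256b_x2(%tmemAddr : !llvm.ptr<6>) {
  // shape 16x256b with an 8-element result selects the x2 variant.
  %res = nvvm.tcgen05.ld %tmemAddr { shape = #nvvm.tcgen05_ldst_shape<shape_16x256b> } : vector<8xi32>
  llvm.return
}
```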
//===----------------------------------------------------------------------===//
// NVVM tcgen05.st Op
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
  let summary = "tensor memory store instructions";
  let arguments = (ins
    // Attributes
    UnitAttr:$unpack,
    Tcgen05LdStShapeAttr:$shape,
    // Arguments
    LLVM_PointerTensor:$tmemAddr,
    AnyTypeOf<[I32, VectorOfLengthAndType<
        [2, 4, 8, 16, 32, 64, 128], [I32]>]>:$val,
    Optional<I64>:$offset
  );

  let assemblyFormat = [{
    $tmemAddr `,` $val (`,` $offset^)? (`unpack` $unpack^)? attr-dict `:` type($val)
  }];

  let description = [{
    Instruction `tcgen05.st` asynchronously stores data from the source register
    `val` into the Tensor Memory at the location specified by the 32-bit address
    operand `tmemAddr`, collectively across all threads of the warps.

    The `shape` and the `num` attribute together determine the total dimension of
    the data which is stored to the Tensor Memory. The `shape` indicates the base
    dimension of data to be accessed. The `num` attribute indicates the repeat
    factor on the base dimension resulting in the total dimension of the data that
    is accessed.

    The shape `16x32bx2` performs two accesses into Tensor Memory of the shape
    `16x32b`. The base address of the first access is specified by `tmemAddr`
    and the base address of the second access is specified by
    `tmemAddr + offset`, where `offset` is an immediate argument.

    The unit attribute `unpack` can be used to unpack a 32-bit element
    in the register into two 16-bit elements and store them in adjacent columns.

    The following table describes the size of the vector for various
    combinations of the `num` and `shape` attributes:

    |========================================================|
    | num/shape | 16x32bx2/16x64b/32x32b | 16x128b | 16x256b |
    |========================================================|
    | x1        | 1                      | 2       | 4       |
    | x2        | 2                      | 4       | 8       |
    | x4        | 4                      | 8       | 16      |
    | x8        | 8                      | 16      | 32      |
    | x16       | 16                     | 32      | 64      |
    | x32       | 32                     | 64      | 128     |
    | x64       | 64                     | 128     | NA      |
    | x128      | 128                    | NA      | NA      |
    |========================================================|

    Example:
    ```mlir
    nvvm.tcgen05.st %tmemAddr, %val, %offset unpack {
      shape = #nvvm.tcgen05_ldst_shape<shape_16x32bx2>
    } : vector<2xi32>
    ```

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st)
  }];

  string llvmBuilder = [{
    llvm::LLVMContext &Context = moduleTranslation.getLLVMContext();
    auto Unpack = llvm::ConstantInt::get(Context, llvm::APInt(1, $unpack));

    auto valTy = $val->getType();
    uint32_t num = valTy->isVectorTy() ? llvm::cast<llvm::VectorType>(valTy)
                                             ->getElementCount()
                                             .getFixedValue()
                                       : 1;

    auto ID = getTcgen05StIntrinsicID($shape, num);
    if (ID == llvm::Intrinsic::not_intrinsic)
      llvm::report_fatal_error("unknown intrinsic signature for tcgen05.st");

    if ($offset)
      createIntrinsicCall(builder, ID, {$tmemAddr, $offset, $val, Unpack});
    else
      createIntrinsicCall(builder, ID, {$tmemAddr, $val, Unpack});
  }];

  let hasVerifier = 1;
}
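A corresponding hedged store sketch (illustrative names, same address-space assumption as above): a `vector<2xi32>` source with shape `32x32b` corresponds to `num = x2`.

```mlir
llvm.func @st_32x32b_x2(%tmemAddr : !llvm.ptr<6>, %val : vector<2xi32>) {
  // shape 32x32b with a 2-element source selects the x2 variant.
  nvvm.tcgen05.st %tmemAddr, %val { shape = #nvvm.tcgen05_ldst_shape<shape_32x32b> } : vector<2xi32>
  llvm.return
}
```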
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 46 additions & 0 deletions
@@ -35,6 +35,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
@@ -1387,6 +1388,51 @@ llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}

// Returns true if the given vector length is valid for the given shape; this
// models the table in the tcgen05.{ld, st} Op descriptions.
static bool isValidVectorLength(NVVM::Tcgen05LdStShape Shape,
                                unsigned VecLen) {
  if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    return VecLen >= 2;
  if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    return VecLen >= 4;
  return true;
}

LogicalResult Tcgen05LdOp::verify() {
  LogicalResult Result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    Result = emitError("shape 16x32bx2 requires offset argument");

  auto ResTy = getRes().getType();
  unsigned ResLen = isa<VectorType>(ResTy)
                        ? llvm::cast<VectorType>(ResTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), ResLen))
    Result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     ResLen, stringifyEnum(getShape())));

  return Result;
}

LogicalResult Tcgen05StOp::verify() {
  LogicalResult Result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    Result = emitError("shape 16x32bx2 requires offset argument");

  auto ValTy = getVal().getType();
  unsigned ValLen = isa<VectorType>(ValTy)
                        ? llvm::cast<VectorType>(ValTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), ValLen))
    Result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     ValLen, stringifyEnum(getShape())));

  return Result;
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
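For illustration, here is a hypothetical op that the first check in `Tcgen05LdOp::verify` rejects (function and value names are made up; the diagnostic text is the one emitted above):

```mlir
llvm.func @invalid_ld(%tmemAddr : !llvm.ptr<6>) {
  // error: shape 16x32bx2 requires offset argument
  %res = nvvm.tcgen05.ld %tmemAddr { shape = #nvvm.tcgen05_ldst_shape<shape_16x32bx2> } : vector<2xi32>
  llvm.return
}
```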

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 106 additions & 0 deletions
@@ -170,6 +170,112 @@ static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
  llvm_unreachable("Unsupported proxy kinds");
}

#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM

static llvm::Intrinsic::ID
getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
  llvm::Intrinsic::ID Shape16x64b[] = {
      TCGEN05LD(16x64b, x1),  TCGEN05LD(16x64b, x2),  TCGEN05LD(16x64b, x4),
      TCGEN05LD(16x64b, x8),  TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32),
      TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128),
  };

  llvm::Intrinsic::ID Shape16x128b[] = {
      TCGEN05LD(16x128b, x1),  TCGEN05LD(16x128b, x2),  TCGEN05LD(16x128b, x4),
      TCGEN05LD(16x128b, x8),  TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32),
      TCGEN05LD(16x128b, x64),
  };

  llvm::Intrinsic::ID Shape16x256b[] = {
      TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2),  TCGEN05LD(16x256b, x4),
      TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32),
  };

  llvm::Intrinsic::ID Shape16x32bx2[] = {
      TCGEN05LD(16x32bx2, x1),  TCGEN05LD(16x32bx2, x2),
      TCGEN05LD(16x32bx2, x4),  TCGEN05LD(16x32bx2, x8),
      TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32),
      TCGEN05LD(16x32bx2, x64), TCGEN05LD(16x32bx2, x128),
  };

  llvm::Intrinsic::ID Shape32x32b[] = {
      TCGEN05LD(32x32b, x1),  TCGEN05LD(32x32b, x2),  TCGEN05LD(32x32b, x4),
      TCGEN05LD(32x32b, x8),  TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32),
      TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128),
  };

  // `num` is the vector length; log2(num) gives the index into the shape
  // array. For 16x128b and 16x256b the x1 variants already use vector
  // lengths of 2 and 4, so the index is offset by 1 and 2 respectively.
  unsigned Idx = std::log2(num);

  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
    return Shape16x64b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    return Shape16x128b[Idx - 1];
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    return Shape16x256b[Idx - 2];
  case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
    return Shape32x32b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
    return Shape16x32bx2[Idx];
  }
  llvm_unreachable("unhandled tcgen05.ld lowering");
}

#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM

static llvm::Intrinsic::ID
getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) {
  llvm::Intrinsic::ID Shape16x64b[] = {
      TCGEN05ST(16x64b, x1),  TCGEN05ST(16x64b, x2),  TCGEN05ST(16x64b, x4),
      TCGEN05ST(16x64b, x8),  TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32),
      TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128),
  };

  llvm::Intrinsic::ID Shape16x128b[] = {
      TCGEN05ST(16x128b, x1),  TCGEN05ST(16x128b, x2),  TCGEN05ST(16x128b, x4),
      TCGEN05ST(16x128b, x8),  TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32),
      TCGEN05ST(16x128b, x64),
  };

  llvm::Intrinsic::ID Shape16x256b[] = {
      TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2),  TCGEN05ST(16x256b, x4),
      TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32),
  };

  llvm::Intrinsic::ID Shape16x32bx2[] = {
      TCGEN05ST(16x32bx2, x1),  TCGEN05ST(16x32bx2, x2),
      TCGEN05ST(16x32bx2, x4),  TCGEN05ST(16x32bx2, x8),
      TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32),
      TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128),
  };

  llvm::Intrinsic::ID Shape32x32b[] = {
      TCGEN05ST(32x32b, x1),  TCGEN05ST(32x32b, x2),  TCGEN05ST(32x32b, x4),
      TCGEN05ST(32x32b, x8),  TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32),
      TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128),
  };

  // `num` is the vector length; log2(num) gives the index into the shape
  // array, with the same index offsets for 16x128b and 16x256b as above.
  unsigned Idx = std::log2(num);

  switch (shape) {
  case NVVM::Tcgen05LdStShape::SHAPE_16X64B:
    return Shape16x64b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X128B:
    return Shape16x128b[Idx - 1];
  case NVVM::Tcgen05LdStShape::SHAPE_16X256B:
    return Shape16x256b[Idx - 2];
  case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
    return Shape32x32b[Idx];
  case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
    return Shape16x32bx2[Idx];
  }
  llvm_unreachable("unhandled tcgen05.st lowering");
}

namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the NVVM dialect to LLVM IR.
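As a worked example of the indexing above (hypothetical IR; the intrinsic name follows from the `TCGEN05LD` macro): a `shape_16x128b` load producing `vector<4xi32>` has `num = 4`, so `Idx = log2(4) = 2` and `Shape16x128b[Idx - 1]` selects the `x2` entry, i.e. `llvm.nvvm.tcgen05.ld.16x128b.x2`, consistent with the table where `16x128b` with `num = x2` yields 4 registers.

```mlir
llvm.func @ld_16x128b_x2(%tmemAddr : !llvm.ptr<6>) {
  // Expected to lower to a call of @llvm.nvvm.tcgen05.ld.16x128b.x2.
  %res = nvvm.tcgen05.ld %tmemAddr { shape = #nvvm.tcgen05_ldst_shape<shape_16x128b> } : vector<4xi32>
  llvm.return
}
```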
