Skip to content

Commit ed48280

Browse files
mbrkusanin, petar-avramovic, and piotrAMD
authored and committed
[AMDGPU] Add GFX12 WMMA and SWMMAC instructions (#77795)
Co-authored-by: Petar Avramovic <[email protected]>
Co-authored-by: Piotr Sobczak <[email protected]>
1 parent aa4cb0e commit ed48280

File tree

65 files changed

+17708
-111
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

65 files changed

+17708
-111
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

+62
Original file line number | Diff line number | Diff line change
@@ -436,5 +436,67 @@ TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_i32, "ii*1", "nc", "gfx12-insts,w
436436
TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_v4i16, "V4sV4s*1", "nc", "gfx12-insts,wavefrontsize64")
437437
TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_v4f16, "V4hV4h*1", "nc", "gfx12-insts,wavefrontsize64")
438438

439+
//===----------------------------------------------------------------------===//
440+
// WMMA builtins.
441+
// Postfix w32 indicates the builtin requires wavefront size of 32.
442+
// Postfix w64 indicates the builtin requires wavefront size of 64.
443+
//
444+
// Some of these are very similar to their GFX11 counterparts, but they don't
445+
// require replication of the A,B matrices, so they use fewer vector elements.
446+
// Therefore, we add an "_gfx12" suffix to distinguish them from the existing
447+
// builtins.
448+
//===----------------------------------------------------------------------===//
449+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12, "V8fV8hV8hV8f", "nc", "gfx12-insts,wavefrontsize32")
450+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12, "V8fV8sV8sV8f", "nc", "gfx12-insts,wavefrontsize32")
451+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12, "V8hV8hV8hV8h", "nc", "gfx12-insts,wavefrontsize32")
452+
TARGET_BUILTIN(__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12, "V8sV8sV8sV8s", "nc", "gfx12-insts,wavefrontsize32")
453+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12, "V8iIbV2iIbV2iV8iIb", "nc", "gfx12-insts,wavefrontsize32")
454+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12, "V8iIbiIbiV8iIb", "nc", "gfx12-insts,wavefrontsize32")
455+
// These are gfx12-only, but for consistency with the other WMMA variants we're
456+
// keeping the "_gfx12" suffix.
457+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
458+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
459+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
460+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
461+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12, "V8iIbV2iIbV2iV8iIb", "nc", "gfx12-insts,wavefrontsize32")
462+
463+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12, "V4fV4hV4hV4f", "nc", "gfx12-insts,wavefrontsize64")
464+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12, "V4fV4sV4sV4f", "nc", "gfx12-insts,wavefrontsize64")
465+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12, "V4hV4hV4hV4h", "nc", "gfx12-insts,wavefrontsize64")
466+
TARGET_BUILTIN(__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12, "V4sV4sV4sV4s", "nc", "gfx12-insts,wavefrontsize64")
467+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
468+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
469+
// These are gfx12-only, but for consistency with the other WMMA variants we're
470+
// keeping the "_gfx12" suffix.
471+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
472+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
473+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
474+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
475+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
476+
477+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32, "V8fV8hV16hV8fs", "nc", "gfx12-insts,wavefrontsize32")
478+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32, "V8fV8sV16sV8fs", "nc", "gfx12-insts,wavefrontsize32")
479+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32, "V8hV8hV16hV8hs", "nc", "gfx12-insts,wavefrontsize32")
480+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32, "V8sV8sV16sV8ss", "nc", "gfx12-insts,wavefrontsize32")
481+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32, "V8iIbV2iIbV4iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
482+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32, "V8iIbiIbV2iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
483+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32, "V8iIbV2iIbV4iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
484+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
485+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
486+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
487+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
488+
489+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64, "V4fV4hV8hV4fs", "nc", "gfx12-insts,wavefrontsize64")
490+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64, "V4fV4sV8sV4fs", "nc", "gfx12-insts,wavefrontsize64")
491+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64, "V4hV4hV8hV4hs", "nc", "gfx12-insts,wavefrontsize64")
492+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64, "V4sV4sV8sV4ss", "nc", "gfx12-insts,wavefrontsize64")
493+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64, "V4iIbiIbV2iV4isIb", "nc", "gfx12-insts,wavefrontsize64")
494+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64, "V4iIbiIbiV4isIb", "nc", "gfx12-insts,wavefrontsize64")
495+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64, "V4iIbiIbV2iV4isIb", "nc", "gfx12-insts,wavefrontsize64")
496+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
497+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
498+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
499+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
500+
439501
#undef BUILTIN
440502
#undef TARGET_BUILTIN

clang/lib/CodeGen/CGBuiltin.cpp

+164-13
Original file line number | Diff line number | Diff line change
@@ -18279,65 +18279,216 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1827918279
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
1828018280
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
1828118281
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18282-
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
18282+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18283+
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18284+
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18285+
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18286+
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18287+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18288+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18289+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18290+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18291+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18292+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18293+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18294+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18295+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18296+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18297+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18298+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18299+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18300+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18301+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18302+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18303+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18304+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18305+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18306+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18307+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18308+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18309+
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18310+
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18311+
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18312+
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18313+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18314+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18315+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18316+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18317+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18318+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18319+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18320+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18321+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18322+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18323+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18324+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18325+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18326+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
1828318327

1828418328
// These operations perform a matrix multiplication and accumulation of
1828518329
// the form:
1828618330
// D = A * B + C
18287-
// The return type always matches the type of matrix C.
18288-
unsigned ArgForMatchingRetType;
18331+
// We need to specify one type for matrices AB and one for matrices CD.
18332+
// Sparse matrix operations can have different types for A and B as well as
18333+
// an additional type for sparsity index.
18334+
// Destination type should be put before types used for source operands.
18335+
SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
18336+
// On GFX12, the intrinsics with 16-bit accumulator use a packed layout.
18337+
// There is no need for the variable opsel argument, so always set it to
18338+
// "false".
18339+
bool AppendFalseForOpselArg = false;
1828918340
unsigned BuiltinWMMAOp;
1829018341

1829118342
switch (BuiltinID) {
1829218343
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
1829318344
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
18294-
ArgForMatchingRetType = 2;
18345+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18346+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18347+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1829518348
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
1829618349
break;
1829718350
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
1829818351
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
18299-
ArgForMatchingRetType = 2;
18352+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18353+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18354+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1830018355
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
1830118356
break;
18357+
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18358+
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18359+
AppendFalseForOpselArg = true;
18360+
LLVM_FALLTHROUGH;
1830218361
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
1830318362
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
18304-
ArgForMatchingRetType = 2;
18363+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1830518364
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
1830618365
break;
18366+
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18367+
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18368+
AppendFalseForOpselArg = true;
18369+
LLVM_FALLTHROUGH;
1830718370
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
1830818371
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
18309-
ArgForMatchingRetType = 2;
18372+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1831018373
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
1831118374
break;
1831218375
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
1831318376
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
18314-
ArgForMatchingRetType = 2;
18377+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1831518378
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
1831618379
break;
1831718380
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
1831818381
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
18319-
ArgForMatchingRetType = 2;
18382+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1832018383
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
1832118384
break;
1832218385
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
1832318386
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18324-
ArgForMatchingRetType = 4;
18387+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18388+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18389+
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
1832518390
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
1832618391
break;
1832718392
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
1832818393
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18329-
ArgForMatchingRetType = 4;
18394+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18395+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18396+
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
1833018397
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
1833118398
break;
18399+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18400+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18401+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18402+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
18403+
break;
18404+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18405+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18406+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18407+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
18408+
break;
18409+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18410+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18411+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18412+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
18413+
break;
18414+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18415+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18416+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18417+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
18418+
break;
18419+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18420+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18421+
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18422+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
18423+
break;
18424+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18425+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18426+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18427+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
18428+
break;
18429+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18430+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18431+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18432+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
18433+
break;
18434+
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18435+
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18436+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18437+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
18438+
break;
18439+
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18440+
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18441+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18442+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
18443+
break;
18444+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18445+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18446+
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18447+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
18448+
break;
18449+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18450+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18451+
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18452+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
18453+
break;
18454+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18455+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18456+
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18457+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
18458+
break;
18459+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18460+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18461+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18462+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
18463+
break;
18464+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18465+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18466+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18467+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
18468+
break;
18469+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18470+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18471+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18472+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
18473+
break;
18474+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18475+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
18476+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18477+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
18478+
break;
1833218479
}
1833318480

1833418481
SmallVector<Value *, 6> Args;
1833518482
for (int i = 0, e = E->getNumArgs(); i != e; ++i)
1833618483
Args.push_back(EmitScalarExpr(E->getArg(i)));
18484+
if (AppendFalseForOpselArg)
18485+
Args.push_back(Builder.getFalse());
1833718486

18338-
Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
18339-
{Args[ArgForMatchingRetType]->getType()});
18487+
SmallVector<llvm::Type *, 6> ArgTypes;
18488+
for (auto ArgIdx : ArgsForMatchingMatrixTypes)
18489+
ArgTypes.push_back(Args[ArgIdx]->getType());
1834018490

18491+
Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
1834118492
return Builder.CreateCall(F, Args);
1834218493
}
1834318494

0 commit comments

Comments
 (0)