@@ -18279,65 +18279,216 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
18279
18279
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
18280
18280
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18281
18281
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18282
- case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
18282
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18283
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18284
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18285
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18286
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18287
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18288
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18289
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18290
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18291
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18292
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18293
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18294
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18295
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18296
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18297
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18298
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18299
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18300
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18301
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18302
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18303
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18304
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18305
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18306
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18307
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18308
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18309
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18310
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18311
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18312
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18313
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18314
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18315
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18316
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18317
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18318
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18319
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18320
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18321
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18322
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18323
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18324
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18325
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18326
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
18283
18327
18284
18328
// These operations perform a matrix multiplication and accumulation of
18285
18329
// the form:
18286
18330
// D = A * B + C
18287
- // The return type always matches the type of matrix C.
18288
- unsigned ArgForMatchingRetType;
18331
+ // We need to specify one type for matrices AB and one for matrices CD.
18332
+ // Sparse matrix operations can have different types for A and B as well as
18333
+ // an additional type for sparsity index.
18334
+ // Destination type should be put before types used for source operands.
18335
+ SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
18336
+ // On GFX12, the intrinsics with 16-bit accumulator use a packed layout.
18337
+ // There is no need for the variable opsel argument, so always set it to
18338
+ // "false".
18339
+ bool AppendFalseForOpselArg = false;
18289
18340
unsigned BuiltinWMMAOp;
18290
18341
18291
18342
switch (BuiltinID) {
18292
18343
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
18293
18344
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
18294
- ArgForMatchingRetType = 2;
18345
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18346
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18347
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18295
18348
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
18296
18349
break;
18297
18350
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
18298
18351
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
18299
- ArgForMatchingRetType = 2;
18352
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18353
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18354
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18300
18355
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
18301
18356
break;
18357
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18358
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18359
+ AppendFalseForOpselArg = true;
18360
+ LLVM_FALLTHROUGH;
18302
18361
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
18303
18362
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
18304
- ArgForMatchingRetType = 2;
18363
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18305
18364
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
18306
18365
break;
18366
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18367
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18368
+ AppendFalseForOpselArg = true;
18369
+ LLVM_FALLTHROUGH;
18307
18370
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
18308
18371
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
18309
- ArgForMatchingRetType = 2;
18372
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18310
18373
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
18311
18374
break;
18312
18375
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
18313
18376
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
18314
- ArgForMatchingRetType = 2;
18377
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18315
18378
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
18316
18379
break;
18317
18380
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
18318
18381
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
18319
- ArgForMatchingRetType = 2;
18382
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18320
18383
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
18321
18384
break;
18322
18385
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18323
18386
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18324
- ArgForMatchingRetType = 4;
18387
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18388
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18389
+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18325
18390
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
18326
18391
break;
18327
18392
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
18328
18393
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18329
- ArgForMatchingRetType = 4;
18394
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18395
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18396
+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18330
18397
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
18331
18398
break;
18399
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18400
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18401
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18402
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
18403
+ break;
18404
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18405
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18406
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18407
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
18408
+ break;
18409
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18410
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18411
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18412
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
18413
+ break;
18414
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18415
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18416
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18417
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
18418
+ break;
18419
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18420
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18421
+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18422
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
18423
+ break;
18424
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18425
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18426
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18427
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
18428
+ break;
18429
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18430
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18431
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18432
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
18433
+ break;
18434
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18435
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18436
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18437
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
18438
+ break;
18439
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18440
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18441
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18442
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
18443
+ break;
18444
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18445
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18446
+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18447
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
18448
+ break;
18449
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18450
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18451
+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18452
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
18453
+ break;
18454
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18455
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18456
+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18457
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
18458
+ break;
18459
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18460
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18461
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18462
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
18463
+ break;
18464
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18465
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18466
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18467
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
18468
+ break;
18469
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18470
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18471
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18472
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
18473
+ break;
18474
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18475
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
18476
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18477
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
18478
+ break;
18332
18479
}
18333
18480
18334
18481
SmallVector<Value *, 6> Args;
18335
18482
for (int i = 0, e = E->getNumArgs(); i != e; ++i)
18336
18483
Args.push_back(EmitScalarExpr(E->getArg(i)));
18484
+ if (AppendFalseForOpselArg)
18485
+ Args.push_back(Builder.getFalse());
18337
18486
18338
- Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
18339
- {Args[ArgForMatchingRetType]->getType()});
18487
+ SmallVector<llvm::Type *, 6> ArgTypes;
18488
+ for (auto ArgIdx : ArgsForMatchingMatrixTypes)
18489
+ ArgTypes.push_back(Args[ArgIdx]->getType());
18340
18490
18491
+ Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
18341
18492
return Builder.CreateCall(F, Args);
18342
18493
}
18343
18494
0 commit comments