Commit 516d6ed

[mlir][gpu] Add optional attributes of kernelModule and kernelFunc for outlining kernels. (#118861)
Add optional attributes so that the names of the generated kernel function and kernel module can be specified.
1 parent: cd74eba

3 files changed: +157 -8 lines

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 6 additions & 1 deletion
@@ -803,7 +803,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [
                Optional<Index>:$clusterSizeX,
                Optional<Index>:$clusterSizeY,
                Optional<Index>:$clusterSizeZ,
-               Optional<I32>:$dynamicSharedMemorySize)>,
+               Optional<I32>:$dynamicSharedMemorySize,
+               OptionalAttr<SymbolRefAttr>:$kernelFunc,
+               OptionalAttr<SymbolRefAttr>:$kernelModule)>,
     Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
   let summary = "GPU kernel launch operation";

@@ -837,6 +839,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [
     - a variadic number of Workgroup memory attributions.
     - a variadic number of Private memory attributions.

+    The `kernelFunc` and `kernelModule` attributes are optional and specify
+    the kernel name and the module into which the kernel should be outlined.
+
     Syntax:

     ```
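As a usage illustration (a minimal sketch, not taken from the patch; the symbols `@my_module` and `@my_kernel` and the unregistered `"some_op"` are hypothetical placeholders), a launch that requests a specific destination for its outlined kernel would look like this:

```mlir
func.func @caller() {
  %c1 = arith.constant 1 : index
  // The optional attributes name the kernel and the module to outline it into.
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
    "some_op"(%tx) : (index) -> ()
    gpu.terminator
  } {kernelModule = @my_module, kernelFunc = @my_kernel}
  return
}

// Expected shape after kernel outlining (the tests added below check the same
// pattern with their own names):
//   gpu.module @my_module {
//     gpu.func @my_kernel(...) kernel { ... }
//   }
//   gpu.launch_func @my_module::@my_kernel blocks in (...) threads in (...)
```

As in the existing tests, the unregistered `"some_op"` merely stands in for the kernel body and needs `-allow-unregistered-dialect` to parse.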

mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp

Lines changed: 29 additions & 7 deletions
@@ -364,17 +364,23 @@ class GpuKernelOutliningPass
     Block::iterator insertPt(func->getNextNode());
     auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
       SetVector<Value> operands;
-      std::string kernelFnName =
-          Twine(op->getParentOfType<SymbolOpInterface>().getName(), "_kernel")
-              .str();
+      std::string kernelFnName;
+      if (op.getKernelFunc()) {
+        kernelFnName = op.getKernelFunc()->getRootReference().str();
+      } else {
+        kernelFnName =
+            Twine(op->getParentOfType<SymbolOpInterface>().getName(),
+                  "_kernel")
+                .str();
+      }

       gpu::GPUFuncOp outlinedFunc =
           outlineKernelFuncImpl(op, kernelFnName, operands);

       // Create nested module and insert outlinedFunc. The module will
       // originally get the same name as the function, but may be renamed on
       // insertion into the parent module.
-      auto kernelModule = createKernelModule(outlinedFunc, symbolTable);
+      auto kernelModule = createKernelModule(op, outlinedFunc, symbolTable);
       symbolTable.insert(kernelModule, insertPt);

       // Potentially changes signature, pulling in constants.
@@ -395,16 +401,32 @@ class GpuKernelOutliningPass

 private:
   /// Returns a gpu.module containing kernelFunc and all callees (recursive).
-  gpu::GPUModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc,
+  gpu::GPUModuleOp createKernelModule(gpu::LaunchOp gpuLaunchOp,
+                                      gpu::GPUFuncOp kernelFunc,
                                       const SymbolTable &parentSymbolTable) {
     // TODO: This code cannot use an OpBuilder because it must be inserted into
     // a SymbolTable by the caller. SymbolTable needs to be refactored to
     // prevent manual building of Ops with symbols in code using SymbolTables
     // and then this needs to use the OpBuilder.
     auto *context = getOperation().getContext();
     OpBuilder builder(context);
-    auto kernelModule = builder.create<gpu::GPUModuleOp>(kernelFunc.getLoc(),
-                                                         kernelFunc.getName());
+    std::string kernelModuleName;
+    gpu::GPUModuleOp kernelModule;
+    if (gpuLaunchOp.getKernelModule()) {
+      kernelModuleName =
+          gpuLaunchOp.getKernelModule()->getRootReference().str();
+      kernelModule =
+          parentSymbolTable.lookup<gpu::GPUModuleOp>(kernelModuleName);
+    } else {
+      kernelModuleName = kernelFunc.getName();
+    }
+
+    // Check if the module already exists in the symbol table.
+    if (!kernelModule) {
+      // If not found, create a new GPU module.
+      kernelModule = builder.create<gpu::GPUModuleOp>(kernelFunc.getLoc(),
+                                                      kernelModuleName);
+    }

     // If a valid data layout spec was provided, attach it to the kernel module.
     // Otherwise, the default data layout will be used.
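One consequence of the fallback above (`kernelModuleName = kernelFunc.getName()`): when only `kernelFunc` is set, the generated module takes the kernel's name, so the launch resolves to `@<name>::@<name>`. A minimal sketch with a hypothetical symbol `@my_kernel` (the `testKernelFuncOnly` case below checks the same behavior):

```mlir
func.func @caller() {
  %c1 = arith.constant 1 : index
  // Only the kernel name is requested; no module name is given.
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
    gpu.terminator
  } {kernelFunc = @my_kernel}
  return
}

// Expected shape after kernel outlining: the module name defaults to the
// kernel function name.
//   gpu.module @my_kernel {
//     gpu.func @my_kernel() kernel { ... }
//   }
//   gpu.launch_func @my_kernel::@my_kernel blocks in (...) threads in (...)
```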

mlir/test/Dialect/GPU/outlining.mlir

Lines changed: 122 additions & 0 deletions
@@ -508,3 +508,125 @@ func.func @launch_cluster() {
 // CHECK-NEXT: "some_op"(%[[CID]], %[[BID]], %[[BDIM]]) : (index, index, index) -> ()
 // CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>

+// -----
+// Tests the optional kernelModule and kernelFunc attributes for gpu.launch.
+// CHECK-LABEL: func.func @testKernelAttributes()
+// CHECK: gpu.launch_func @test_module::@test_kernel_func blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
+// CHECK: gpu.module @test_module
+// CHECK: gpu.func @test_kernel_func()
+func.func @testKernelAttributes() {
+  %gDimX = arith.constant 8 : index
+  %gDimY = arith.constant 12 : index
+  %gDimZ = arith.constant 16 : index
+  %bDimX = arith.constant 32 : index
+  %bDimY = arith.constant 16 : index
+  %bDimZ = arith.constant 8 : index
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
+             threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+    "some_op"(%bx, %tx) : (index, index) -> ()
+    gpu.terminator
+  } {kernelModule = @test_module, kernelFunc = @test_kernel_func}
+  return
+}
+
+// -----
+// Tests the optional kernelModule and kernelFunc attributes for gpu.launch when the kernel module already exists.
+
+// CHECK-LABEL: gpu.module @existing_module
+// CHECK: gpu.func @test_kernel_func()
+// CHECK: gpu.func @test_kernel_func_0()
+// CHECK-NOT: gpu.module @testExistingModule_kernel
+// CHECK-NOT: gpu.func @testExistingModule_kernel()
+// CHECK: func.func @testExistingModule()
+// CHECK: gpu.launch_func @existing_module::@test_kernel_func_0 blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
+
+gpu.module @existing_module {
+  gpu.func @test_kernel_func() {
+    gpu.return
+  }
+}
+
+func.func @testExistingModule() {
+  %gDimX = arith.constant 8 : index
+  %gDimY = arith.constant 12 : index
+  %gDimZ = arith.constant 16 : index
+  %bDimX = arith.constant 32 : index
+  %bDimY = arith.constant 16 : index
+  %bDimZ = arith.constant 8 : index
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
+             threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+    "some_op"(%bx, %tx) : (index, index) -> ()
+    gpu.terminator
+  } {kernelModule = @existing_module, kernelFunc = @test_kernel_func}
+  return
+}
+
+// -----
+// Tests the optional kernelModule attribute for gpu.launch.
+// CHECK-LABEL: func.func @testKernelModuleOnly()
+// CHECK: gpu.launch_func @test_module::@testKernelModuleOnly_kernel blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
+// CHECK: gpu.module @test_module
+// CHECK: gpu.func @testKernelModuleOnly_kernel()
+func.func @testKernelModuleOnly() {
+  %gDimX = arith.constant 8 : index
+  %gDimY = arith.constant 12 : index
+  %gDimZ = arith.constant 16 : index
+  %bDimX = arith.constant 32 : index
+  %bDimY = arith.constant 16 : index
+  %bDimZ = arith.constant 8 : index
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
+             threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+    "some_op"(%bx, %tx) : (index, index) -> ()
+    gpu.terminator
+  } {kernelModule = @test_module}
+  return
+}
+
+// -----
+// Tests the optional kernelFunc attribute for gpu.launch.
+// CHECK-LABEL: func.func @testKernelFuncOnly()
+// CHECK: gpu.launch_func @test_kernel_func::@test_kernel_func blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
+
+// CHECK: gpu.module @test_kernel_func
+// CHECK: gpu.func @test_kernel_func()
+func.func @testKernelFuncOnly() {
+  %gDimX = arith.constant 8 : index
+  %gDimY = arith.constant 12 : index
+  %gDimZ = arith.constant 16 : index
+  %bDimX = arith.constant 32 : index
+  %bDimY = arith.constant 16 : index
+  %bDimZ = arith.constant 8 : index
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
+             threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+    "some_op"(%bx, %tx) : (index, index) -> ()
+    gpu.terminator
+  } {kernelFunc = @test_kernel_func}
+  return
+}
+
+// -----
+// Tests gpu.launch when neither the kernelModule nor the kernelFunc attribute is specified.
+// CHECK-LABEL: func.func @testNoAttributes()
+// CHECK: gpu.launch_func @testNoAttributes_kernel::@testNoAttributes_kernel blocks in (%[[GRID_X:.*]], %[[GRID_Y:.*]], %[[GRID_Z:.*]]) threads in (%[[BLOCK_X:.*]], %[[BLOCK_Y:.*]], %[[BLOCK_Z:.*]])
+
+// CHECK: gpu.module @testNoAttributes_kernel
+// CHECK: gpu.func @testNoAttributes_kernel()
+func.func @testNoAttributes() {
+  %gDimX = arith.constant 8 : index
+  %gDimY = arith.constant 12 : index
+  %gDimZ = arith.constant 16 : index
+  %bDimX = arith.constant 32 : index
+  %bDimY = arith.constant 16 : index
+  %bDimZ = arith.constant 8 : index
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY, %grid_z = %gDimZ)
+             threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY, %block_z = %bDimZ) {
+    "some_op"(%bx, %tx) : (index, index) -> ()
+    gpu.terminator
+  }
+  return
+}
