@@ -422,11 +422,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
422
422
phEventWaitList, phEvent);
423
423
}
424
424
425
- UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
426
- ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
427
- const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
428
- const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
429
- const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
425
+ static ur_result_t
426
+ enqueueKernelLaunch (ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
427
+ uint32_t workDim, const size_t *pGlobalWorkOffset,
428
+ const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
429
+ uint32_t numEventsInWaitList,
430
+ const ur_event_handle_t *phEventWaitList,
431
+ ur_event_handle_t *phEvent, size_t WorkGroupMemory) {
430
432
// Preconditions
431
433
UR_ASSERT (hQueue->getDevice () == hKernel->getProgram ()->getDevice (),
432
434
UR_RESULT_ERROR_INVALID_KERNEL);
@@ -444,6 +446,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
444
446
size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
445
447
size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
446
448
449
+ // Set work group memory so we can compute the whole memory requirement
450
+ if (WorkGroupMemory)
451
+ hKernel->setWorkGroupMemory (WorkGroupMemory);
447
452
uint32_t LocalSize = hKernel->getLocalSize ();
448
453
CUfunction CuFunc = hKernel->get ();
449
454
@@ -506,6 +511,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
506
511
return UR_RESULT_SUCCESS;
507
512
}
508
513
514
+ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
515
+ ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
516
+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
517
+ const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
518
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
519
+ return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
520
+ pGlobalWorkSize, pLocalWorkSize,
521
+ numEventsInWaitList, phEventWaitList, phEvent,
522
+ /* WorkGroupMemory=*/ 0 );
523
+ }
524
+
509
525
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp (
510
526
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
511
527
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -516,8 +532,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
516
532
coop_prop.id = UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE;
517
533
coop_prop.value .cooperative = 1 ;
518
534
return urEnqueueKernelLaunchCustomExp (
519
- hQueue, hKernel, workDim, pGlobalWorkSize, pLocalWorkSize, 1 ,
520
- &coop_prop, numEventsInWaitList, phEventWaitList, phEvent);
535
+ hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
536
+ pLocalWorkSize, 1 , &coop_prop, numEventsInWaitList, phEventWaitList,
537
+ phEvent);
521
538
}
522
539
return urEnqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
523
540
pGlobalWorkSize, pLocalWorkSize,
@@ -526,16 +543,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
526
543
527
544
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp (
528
545
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
529
- const size_t *pGlobalWorkSize , const size_t *pLocalWorkSize ,
530
- uint32_t numPropsInLaunchPropList,
546
+ const size_t *pGlobalWorkOffset , const size_t *pGlobalWorkSize ,
547
+ const size_t *pLocalWorkSize, uint32_t numPropsInLaunchPropList,
531
548
const ur_exp_launch_property_t *launchPropList,
532
549
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
533
550
ur_event_handle_t *phEvent) {
534
551
535
- if (numPropsInLaunchPropList == 0 ) {
536
- urEnqueueKernelLaunch (hQueue, hKernel, workDim, nullptr , pGlobalWorkSize,
537
- pLocalWorkSize, numEventsInWaitList, phEventWaitList,
538
- phEvent);
552
+ size_t WorkGroupMemory = [&]() -> size_t {
553
+ const ur_exp_launch_property_t *WorkGroupMemoryProp = std::find_if (
554
+ launchPropList, launchPropList + numPropsInLaunchPropList,
555
+ [](const ur_exp_launch_property_t &Prop) {
556
+ return Prop.id == UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY;
557
+ });
558
+ if (WorkGroupMemoryProp != launchPropList + numPropsInLaunchPropList)
559
+ return WorkGroupMemoryProp->value .workgroup_mem_size ;
560
+ return 0 ;
561
+ }();
562
+
563
+ if (numPropsInLaunchPropList == 0 ||
564
+ (WorkGroupMemory && numPropsInLaunchPropList == 1 )) {
565
+ return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
566
+ pGlobalWorkSize, pLocalWorkSize,
567
+ numEventsInWaitList, phEventWaitList, phEvent,
568
+ WorkGroupMemory);
539
569
}
540
570
#if CUDA_VERSION >= 11080
541
571
// Preconditions
@@ -548,7 +578,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
548
578
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
549
579
}
550
580
551
- std::vector<CUlaunchAttribute> launch_attribute (numPropsInLaunchPropList);
581
+ std::vector<CUlaunchAttribute> launch_attribute;
582
+ launch_attribute.reserve (numPropsInLaunchPropList);
552
583
553
584
// Early exit for zero size kernel
554
585
if (*pGlobalWorkSize == 0 ) {
@@ -561,40 +592,35 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
561
592
size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
562
593
size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
563
594
595
+ // Set work group memory so we can compute the whole memory requirement
596
+ if (WorkGroupMemory)
597
+ hKernel->setWorkGroupMemory (WorkGroupMemory);
564
598
uint32_t LocalSize = hKernel->getLocalSize ();
565
599
CUfunction CuFunc = hKernel->get ();
566
600
567
601
for (uint32_t i = 0 ; i < numPropsInLaunchPropList; i++) {
568
602
switch (launchPropList[i].id ) {
569
603
case UR_EXP_LAUNCH_PROPERTY_ID_IGNORE: {
570
- launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_IGNORE;
604
+ auto &attr = launch_attribute.emplace_back ();
605
+ attr.id = CU_LAUNCH_ATTRIBUTE_IGNORE;
571
606
break ;
572
607
}
573
608
case UR_EXP_LAUNCH_PROPERTY_ID_CLUSTER_DIMENSION: {
574
-
575
- launch_attribute[i] .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
609
+ auto &attr = launch_attribute. emplace_back ();
610
+ attr .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
576
611
// Note that cuda orders from right to left wrt SYCL dimensional order.
577
612
if (workDim == 3 ) {
578
- launch_attribute[i].value .clusterDim .x =
579
- launchPropList[i].value .clusterDim [2 ];
580
- launch_attribute[i].value .clusterDim .y =
581
- launchPropList[i].value .clusterDim [1 ];
582
- launch_attribute[i].value .clusterDim .z =
583
- launchPropList[i].value .clusterDim [0 ];
613
+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [2 ];
614
+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
615
+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [0 ];
584
616
} else if (workDim == 2 ) {
585
- launch_attribute[i].value .clusterDim .x =
586
- launchPropList[i].value .clusterDim [1 ];
587
- launch_attribute[i].value .clusterDim .y =
588
- launchPropList[i].value .clusterDim [0 ];
589
- launch_attribute[i].value .clusterDim .z =
590
- launchPropList[i].value .clusterDim [2 ];
617
+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [1 ];
618
+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [0 ];
619
+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
591
620
} else {
592
- launch_attribute[i].value .clusterDim .x =
593
- launchPropList[i].value .clusterDim [0 ];
594
- launch_attribute[i].value .clusterDim .y =
595
- launchPropList[i].value .clusterDim [1 ];
596
- launch_attribute[i].value .clusterDim .z =
597
- launchPropList[i].value .clusterDim [2 ];
621
+ attr.value .clusterDim .x = launchPropList[i].value .clusterDim [0 ];
622
+ attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
623
+ attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
598
624
}
599
625
600
626
UR_CHECK_ERROR (cuFuncSetAttribute (
@@ -603,9 +629,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
603
629
break ;
604
630
}
605
631
case UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE: {
606
- launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
607
- launch_attribute[i].value .cooperative =
608
- launchPropList[i].value .cooperative ;
632
+ auto &attr = launch_attribute.emplace_back ();
633
+ attr.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
634
+ attr.value .cooperative = launchPropList[i].value .cooperative ;
635
+ break ;
636
+ }
637
+ case UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY: {
609
638
break ;
610
639
}
611
640
default : {
@@ -618,8 +647,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
618
647
// using the standard UR_CHECK_ERROR
619
648
if (ur_result_t Ret =
620
649
setKernelParams (hQueue->getContext (), hQueue->Device , workDim,
621
- nullptr , pGlobalWorkSize, pLocalWorkSize, hKernel ,
622
- CuFunc, ThreadsPerBlock, BlocksPerGrid);
650
+ pGlobalWorkOffset , pGlobalWorkSize, pLocalWorkSize,
651
+ hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
623
652
Ret != UR_RESULT_SUCCESS)
624
653
return Ret;
625
654
@@ -667,7 +696,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
667
696
launch_config.sharedMemBytes = LocalSize;
668
697
launch_config.hStream = CuStream;
669
698
launch_config.attrs = &launch_attribute[0 ];
670
- launch_config.numAttrs = numPropsInLaunchPropList ;
699
+ launch_config.numAttrs = launch_attribute. size () ;
671
700
672
701
UR_CHECK_ERROR (cuLaunchKernelEx (&launch_config, CuFunc,
673
702
const_cast <void **>(ArgIndices.data ()),
0 commit comments