@@ -422,11 +422,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
                                         phEventWaitList, phEvent);
 }
 
-UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
-    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
-    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
-    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
-    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
+static ur_result_t
+enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
+                    uint32_t workDim, const size_t *pGlobalWorkOffset,
+                    const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
+                    uint32_t numEventsInWaitList,
+                    const ur_event_handle_t *phEventWaitList,
+                    ur_event_handle_t *phEvent, size_t WorkGroupMemory) {
   // Preconditions
   UR_ASSERT(hQueue->getDevice() == hKernel->getProgram()->getDevice(),
             UR_RESULT_ERROR_INVALID_KERNEL);
@@ -444,6 +446,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   size_t ThreadsPerBlock[3] = {32u, 1u, 1u};
   size_t BlocksPerGrid[3] = {1u, 1u, 1u};
 
+  // Set work group memory so we can compute the whole memory requirement
+  if (WorkGroupMemory)
+    hKernel->setWorkGroupMemory(WorkGroupMemory);
   uint32_t LocalSize = hKernel->getLocalSize();
   CUfunction CuFunc = hKernel->get();
 
@@ -503,6 +508,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   return UR_RESULT_SUCCESS;
 }
 
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
+    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
+    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
+    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
+    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
+  return enqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
+                             pGlobalWorkSize, pLocalWorkSize,
+                             numEventsInWaitList, phEventWaitList, phEvent,
+                             /*WorkGroupMemory=*/0);
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -513,8 +529,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
     coop_prop.id = UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE;
     coop_prop.value.cooperative = 1;
     return urEnqueueKernelLaunchCustomExp(
-        hQueue, hKernel, workDim, pGlobalWorkSize, pLocalWorkSize, 1,
-        &coop_prop, numEventsInWaitList, phEventWaitList, phEvent);
+        hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+        pLocalWorkSize, 1, &coop_prop, numEventsInWaitList, phEventWaitList,
+        phEvent);
   }
   return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
                                pGlobalWorkSize, pLocalWorkSize,
@@ -523,16 +540,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
-    const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
-    uint32_t numPropsInLaunchPropList,
+    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
+    const size_t *pLocalWorkSize, uint32_t numPropsInLaunchPropList,
     const ur_exp_launch_property_t *launchPropList,
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
 
-  if (numPropsInLaunchPropList == 0) {
-    urEnqueueKernelLaunch(hQueue, hKernel, workDim, nullptr, pGlobalWorkSize,
-                          pLocalWorkSize, numEventsInWaitList, phEventWaitList,
-                          phEvent);
+  size_t WorkGroupMemory = [&]() -> size_t {
+    const ur_exp_launch_property_t *WorkGroupMemoryProp = std::find_if(
+        launchPropList, launchPropList + numPropsInLaunchPropList,
+        [](const ur_exp_launch_property_t &Prop) {
+          return Prop.id == UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY;
+        });
+    if (WorkGroupMemoryProp != launchPropList + numPropsInLaunchPropList)
+      return WorkGroupMemoryProp->value.workgroup_mem_size;
+    return 0;
+  }();
+
+  if (numPropsInLaunchPropList == 0 ||
+      (WorkGroupMemory && numPropsInLaunchPropList == 1)) {
+    return enqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
+                               pGlobalWorkSize, pLocalWorkSize,
+                               numEventsInWaitList, phEventWaitList, phEvent,
+                               WorkGroupMemory);
   }
 #if CUDA_VERSION >= 11080
   // Preconditions
@@ -545,7 +575,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
     return UR_RESULT_ERROR_INVALID_NULL_POINTER;
   }
 
-  std::vector<CUlaunchAttribute> launch_attribute(numPropsInLaunchPropList);
+  std::vector<CUlaunchAttribute> launch_attribute;
+  launch_attribute.reserve(numPropsInLaunchPropList);
 
   // Early exit for zero size kernel
   if (*pGlobalWorkSize == 0) {
@@ -558,40 +589,35 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
   size_t ThreadsPerBlock[3] = {32u, 1u, 1u};
   size_t BlocksPerGrid[3] = {1u, 1u, 1u};
 
+  // Set work group memory so we can compute the whole memory requirement
+  if (WorkGroupMemory)
+    hKernel->setWorkGroupMemory(WorkGroupMemory);
   uint32_t LocalSize = hKernel->getLocalSize();
   CUfunction CuFunc = hKernel->get();
 
   for (uint32_t i = 0; i < numPropsInLaunchPropList; i++) {
     switch (launchPropList[i].id) {
     case UR_EXP_LAUNCH_PROPERTY_ID_IGNORE: {
-      launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_IGNORE;
+      auto &attr = launch_attribute.emplace_back();
+      attr.id = CU_LAUNCH_ATTRIBUTE_IGNORE;
       break;
     }
     case UR_EXP_LAUNCH_PROPERTY_ID_CLUSTER_DIMENSION: {
-
-      launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+      auto &attr = launch_attribute.emplace_back();
+      attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
       // Note that cuda orders from right to left wrt SYCL dimensional order.
       if (workDim == 3) {
-        launch_attribute[i].value.clusterDim.x =
-            launchPropList[i].value.clusterDim[2];
-        launch_attribute[i].value.clusterDim.y =
-            launchPropList[i].value.clusterDim[1];
-        launch_attribute[i].value.clusterDim.z =
-            launchPropList[i].value.clusterDim[0];
+        attr.value.clusterDim.x = launchPropList[i].value.clusterDim[2];
+        attr.value.clusterDim.y = launchPropList[i].value.clusterDim[1];
+        attr.value.clusterDim.z = launchPropList[i].value.clusterDim[0];
       } else if (workDim == 2) {
-        launch_attribute[i].value.clusterDim.x =
-            launchPropList[i].value.clusterDim[1];
-        launch_attribute[i].value.clusterDim.y =
-            launchPropList[i].value.clusterDim[0];
-        launch_attribute[i].value.clusterDim.z =
-            launchPropList[i].value.clusterDim[2];
+        attr.value.clusterDim.x = launchPropList[i].value.clusterDim[1];
+        attr.value.clusterDim.y = launchPropList[i].value.clusterDim[0];
+        attr.value.clusterDim.z = launchPropList[i].value.clusterDim[2];
       } else {
-        launch_attribute[i].value.clusterDim.x =
-            launchPropList[i].value.clusterDim[0];
-        launch_attribute[i].value.clusterDim.y =
-            launchPropList[i].value.clusterDim[1];
-        launch_attribute[i].value.clusterDim.z =
-            launchPropList[i].value.clusterDim[2];
+        attr.value.clusterDim.x = launchPropList[i].value.clusterDim[0];
+        attr.value.clusterDim.y = launchPropList[i].value.clusterDim[1];
+        attr.value.clusterDim.z = launchPropList[i].value.clusterDim[2];
       }
 
       UR_CHECK_ERROR(cuFuncSetAttribute(
@@ -600,9 +626,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
       break;
     }
     case UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE: {
-      launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
-      launch_attribute[i].value.cooperative =
-          launchPropList[i].value.cooperative;
+      auto &attr = launch_attribute.emplace_back();
+      attr.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
+      attr.value.cooperative = launchPropList[i].value.cooperative;
+      break;
+    }
+    case UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY: {
       break;
     }
     default: {
@@ -615,8 +644,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
   // using the standard UR_CHECK_ERROR
   if (ur_result_t Ret =
           setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
-                          nullptr, pGlobalWorkSize, pLocalWorkSize, hKernel,
-                          CuFunc, ThreadsPerBlock, BlocksPerGrid);
+                          pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
+                          hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
       Ret != UR_RESULT_SUCCESS)
     return Ret;
 
@@ -664,7 +693,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
   launch_config.sharedMemBytes = LocalSize;
   launch_config.hStream = CuStream;
   launch_config.attrs = &launch_attribute[0];
-  launch_config.numAttrs = numPropsInLaunchPropList;
+  launch_config.numAttrs = launch_attribute.size();
 
   UR_CHECK_ERROR(cuLaunchKernelEx(&launch_config, CuFunc,
                                   const_cast<void **>(ArgIndices.data()),
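For reference, here is a minimal, hypothetical caller-side sketch of how the new UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY property reaches this adapter through urEnqueueKernelLaunchCustomExp. The launchWithWorkGroupMemory helper, the handle parameters, the 1 KiB size, and the ND-range values are illustrative placeholders rather than part of the patch; only the property id, the value.workgroup_mem_size field, and the entry-point signature are taken from the diff above.

// Hypothetical usage sketch (not part of this patch). Handles, sizes, and the
// helper name are placeholders; the property id and value.workgroup_mem_size
// come from the change above.
#include <ur_api.h>

static ur_result_t launchWithWorkGroupMemory(ur_queue_handle_t hQueue,
                                             ur_kernel_handle_t hKernel) {
  // Request 1 KiB of dynamic work-group (CUDA shared) memory for this launch.
  ur_exp_launch_property_t WorkGroupMemProp{};
  WorkGroupMemProp.id = UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY;
  WorkGroupMemProp.value.workgroup_mem_size = 1024;

  const size_t GlobalOffset[1] = {0};
  const size_t GlobalSize[1] = {1024};
  const size_t LocalSize[1] = {32};
  ur_event_handle_t Event = nullptr;

  // With a single WORK_GROUP_MEMORY property, the adapter takes the
  // enqueueKernelLaunch fast path added above and forwards the size to
  // hKernel->setWorkGroupMemory() before computing the launch configuration.
  return urEnqueueKernelLaunchCustomExp(hQueue, hKernel, /*workDim=*/1,
                                        GlobalOffset, GlobalSize, LocalSize,
                                        /*numPropsInLaunchPropList=*/1,
                                        &WorkGroupMemProp, 0, nullptr, &Event);
}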