@@ -422,13 +422,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
422
422
phEventWaitList, phEvent);
423
423
}
424
424
425
- static ur_result_t
426
- enqueueKernelLaunch (ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
427
- uint32_t workDim, const size_t *pGlobalWorkOffset,
428
- const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize,
429
- uint32_t numEventsInWaitList,
430
- const ur_event_handle_t *phEventWaitList,
431
- ur_event_handle_t *phEvent, size_t WorkGroupMemory) {
425
+ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
426
+ ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
427
+ const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
428
+ const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
429
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
432
430
// Preconditions
433
431
UR_ASSERT (hQueue->getDevice () == hKernel->getProgram ()->getDevice (),
434
432
UR_RESULT_ERROR_INVALID_KERNEL);
@@ -446,9 +444,6 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
446
444
size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
447
445
size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
448
446
449
- // Set work group memory so we can compute the whole memory requirement
450
- if (WorkGroupMemory)
451
- hKernel->setWorkGroupMemory (WorkGroupMemory);
452
447
uint32_t LocalSize = hKernel->getLocalSize ();
453
448
CUfunction CuFunc = hKernel->get ();
454
449
@@ -511,17 +506,6 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
511
506
return UR_RESULT_SUCCESS;
512
507
}
513
508
514
- UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch (
515
- ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
516
- const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
517
- const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
518
- const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
519
- return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
520
- pGlobalWorkSize, pLocalWorkSize,
521
- numEventsInWaitList, phEventWaitList, phEvent,
522
- /* WorkGroupMemory=*/ 0 );
523
- }
524
-
525
509
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp (
526
510
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
527
511
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -532,9 +516,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
532
516
coop_prop.id = UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE;
533
517
coop_prop.value .cooperative = 1 ;
534
518
return urEnqueueKernelLaunchCustomExp (
535
- hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
536
- pLocalWorkSize, 1 , &coop_prop, numEventsInWaitList, phEventWaitList,
537
- phEvent);
519
+ hQueue, hKernel, workDim, pGlobalWorkSize, pLocalWorkSize, 1 ,
520
+ &coop_prop, numEventsInWaitList, phEventWaitList, phEvent);
538
521
}
539
522
return urEnqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
540
523
pGlobalWorkSize, pLocalWorkSize,
@@ -543,29 +526,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
543
526
544
527
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp (
545
528
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
546
- const size_t *pGlobalWorkOffset , const size_t *pGlobalWorkSize ,
547
- const size_t *pLocalWorkSize, uint32_t numPropsInLaunchPropList,
529
+ const size_t *pGlobalWorkSize , const size_t *pLocalWorkSize ,
530
+ uint32_t numPropsInLaunchPropList,
548
531
const ur_exp_launch_property_t *launchPropList,
549
532
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
550
533
ur_event_handle_t *phEvent) {
551
534
552
- size_t WorkGroupMemory = [&]() -> size_t {
553
- const ur_exp_launch_property_t *WorkGroupMemoryProp = std::find_if (
554
- launchPropList, launchPropList + numPropsInLaunchPropList,
555
- [](const ur_exp_launch_property_t &Prop) {
556
- return Prop.id == UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY;
557
- });
558
- if (WorkGroupMemoryProp != launchPropList + numPropsInLaunchPropList)
559
- return WorkGroupMemoryProp->value .workgroup_mem_size ;
560
- return 0 ;
561
- }();
562
-
563
- if (numPropsInLaunchPropList == 0 ||
564
- (WorkGroupMemory && numPropsInLaunchPropList == 1 )) {
565
- return enqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
566
- pGlobalWorkSize, pLocalWorkSize,
567
- numEventsInWaitList, phEventWaitList, phEvent,
568
- WorkGroupMemory);
535
+ if (numPropsInLaunchPropList == 0 ) {
536
+ urEnqueueKernelLaunch (hQueue, hKernel, workDim, nullptr , pGlobalWorkSize,
537
+ pLocalWorkSize, numEventsInWaitList, phEventWaitList,
538
+ phEvent);
569
539
}
570
540
#if CUDA_VERSION >= 11080
571
541
// Preconditions
@@ -578,8 +548,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
578
548
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
579
549
}
580
550
581
- std::vector<CUlaunchAttribute> launch_attribute;
582
- launch_attribute.reserve (numPropsInLaunchPropList);
551
+ std::vector<CUlaunchAttribute> launch_attribute (numPropsInLaunchPropList);
583
552
584
553
// Early exit for zero size kernel
585
554
if (*pGlobalWorkSize == 0 ) {
@@ -592,35 +561,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
592
561
size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
593
562
size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
594
563
595
- // Set work group memory so we can compute the whole memory requirement
596
- if (WorkGroupMemory)
597
- hKernel->setWorkGroupMemory (WorkGroupMemory);
598
564
uint32_t LocalSize = hKernel->getLocalSize ();
599
565
CUfunction CuFunc = hKernel->get ();
600
566
601
567
for (uint32_t i = 0 ; i < numPropsInLaunchPropList; i++) {
602
568
switch (launchPropList[i].id ) {
603
569
case UR_EXP_LAUNCH_PROPERTY_ID_IGNORE: {
604
- auto &attr = launch_attribute.emplace_back ();
605
- attr.id = CU_LAUNCH_ATTRIBUTE_IGNORE;
570
+ launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_IGNORE;
606
571
break ;
607
572
}
608
573
case UR_EXP_LAUNCH_PROPERTY_ID_CLUSTER_DIMENSION: {
609
- auto &attr = launch_attribute. emplace_back ();
610
- attr .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
574
+
575
+ launch_attribute[i] .id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
611
576
// Note that cuda orders from right to left wrt SYCL dimensional order.
612
577
if (workDim == 3 ) {
613
- attr.value .clusterDim .x = launchPropList[i].value .clusterDim [2 ];
614
- attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
615
- attr.value .clusterDim .z = launchPropList[i].value .clusterDim [0 ];
578
+ launch_attribute[i].value .clusterDim .x =
579
+ launchPropList[i].value .clusterDim [2 ];
580
+ launch_attribute[i].value .clusterDim .y =
581
+ launchPropList[i].value .clusterDim [1 ];
582
+ launch_attribute[i].value .clusterDim .z =
583
+ launchPropList[i].value .clusterDim [0 ];
616
584
} else if (workDim == 2 ) {
617
- attr.value .clusterDim .x = launchPropList[i].value .clusterDim [1 ];
618
- attr.value .clusterDim .y = launchPropList[i].value .clusterDim [0 ];
619
- attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
585
+ launch_attribute[i].value .clusterDim .x =
586
+ launchPropList[i].value .clusterDim [1 ];
587
+ launch_attribute[i].value .clusterDim .y =
588
+ launchPropList[i].value .clusterDim [0 ];
589
+ launch_attribute[i].value .clusterDim .z =
590
+ launchPropList[i].value .clusterDim [2 ];
620
591
} else {
621
- attr.value .clusterDim .x = launchPropList[i].value .clusterDim [0 ];
622
- attr.value .clusterDim .y = launchPropList[i].value .clusterDim [1 ];
623
- attr.value .clusterDim .z = launchPropList[i].value .clusterDim [2 ];
592
+ launch_attribute[i].value .clusterDim .x =
593
+ launchPropList[i].value .clusterDim [0 ];
594
+ launch_attribute[i].value .clusterDim .y =
595
+ launchPropList[i].value .clusterDim [1 ];
596
+ launch_attribute[i].value .clusterDim .z =
597
+ launchPropList[i].value .clusterDim [2 ];
624
598
}
625
599
626
600
UR_CHECK_ERROR (cuFuncSetAttribute (
@@ -629,12 +603,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
629
603
break ;
630
604
}
631
605
case UR_EXP_LAUNCH_PROPERTY_ID_COOPERATIVE: {
632
- auto &attr = launch_attribute.emplace_back ();
633
- attr.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
634
- attr.value .cooperative = launchPropList[i].value .cooperative ;
635
- break ;
636
- }
637
- case UR_EXP_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY: {
606
+ launch_attribute[i].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
607
+ launch_attribute[i].value .cooperative =
608
+ launchPropList[i].value .cooperative ;
638
609
break ;
639
610
}
640
611
default : {
@@ -647,8 +618,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
647
618
// using the standard UR_CHECK_ERROR
648
619
if (ur_result_t Ret =
649
620
setKernelParams (hQueue->getContext (), hQueue->Device , workDim,
650
- pGlobalWorkOffset , pGlobalWorkSize, pLocalWorkSize,
651
- hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
621
+ nullptr , pGlobalWorkSize, pLocalWorkSize, hKernel ,
622
+ CuFunc, ThreadsPerBlock, BlocksPerGrid);
652
623
Ret != UR_RESULT_SUCCESS)
653
624
return Ret;
654
625
@@ -696,7 +667,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
696
667
launch_config.sharedMemBytes = LocalSize;
697
668
launch_config.hStream = CuStream;
698
669
launch_config.attrs = &launch_attribute[0 ];
699
- launch_config.numAttrs = launch_attribute. size () ;
670
+ launch_config.numAttrs = numPropsInLaunchPropList ;
700
671
701
672
UR_CHECK_ERROR (cuLaunchKernelEx (&launch_config, CuFunc,
702
673
const_cast <void **>(ArgIndices.data ()),
0 commit comments