@@ -528,43 +528,57 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
528
528
pi_uint32 *num_platforms) {
529
529
530
530
try {
531
- static constexpr pi_uint32 numPlatforms = 1 ;
531
+ static std::once_flag initFlag;
532
+ static pi_uint32 numPlatforms = 1 ;
533
+ static _pi_platform platformId;
532
534
533
- if (num_platforms != nullptr ) {
534
- *num_platforms = numPlatforms;
535
+ if (num_entries == 0 and platforms != nullptr ) {
536
+ return PI_INVALID_VALUE;
537
+ }
538
+ if (platforms == nullptr and num_platforms == nullptr ) {
539
+ return PI_INVALID_VALUE;
535
540
}
536
541
537
542
pi_result err = PI_SUCCESS;
538
543
539
- if (platforms != nullptr ) {
540
-
541
- assert (num_entries != 0 );
542
-
543
- static std::once_flag initFlag;
544
- static _pi_platform platformId;
545
- std::call_once (
546
- initFlag,
547
- [](pi_result &err) {
548
- err = PI_CHECK_ERROR (cuInit (0 ));
549
-
550
- int numDevices = 0 ;
551
- err = PI_CHECK_ERROR (cuDeviceGetCount (&numDevices));
544
+ std::call_once (
545
+ initFlag,
546
+ [](pi_result &err) {
547
+ if (cuInit (0 ) != CUDA_SUCCESS) {
548
+ numPlatforms = 0 ;
549
+ return ;
550
+ }
551
+ int numDevices = 0 ;
552
+ err = PI_CHECK_ERROR (cuDeviceGetCount (&numDevices));
553
+ if (numDevices == 0 ) {
554
+ numPlatforms = 0 ;
555
+ return ;
556
+ }
557
+ try {
552
558
platformId.devices_ .reserve (numDevices);
553
- try {
554
- for (int i = 0 ; i < numDevices; ++i) {
555
- CUdevice device;
556
- err = PI_CHECK_ERROR (cuDeviceGet (&device, i));
557
- platformId.devices_ .emplace_back (
558
- new _pi_device{device, &platformId});
559
- }
560
- } catch (...) {
561
- // Clear and rethrow to allow retry
562
- platformId.devices_ .clear ();
563
- throw ;
559
+ for (int i = 0 ; i < numDevices; ++i) {
560
+ CUdevice device;
561
+ err = PI_CHECK_ERROR (cuDeviceGet (&device, i));
562
+ platformId.devices_ .emplace_back (
563
+ new _pi_device{device, &platformId});
564
564
}
565
- },
566
- err);
565
+ } catch (const std::bad_alloc &) {
566
+ // Signal out-of-memory situation
567
+ platformId.devices_ .clear ();
568
+ err = PI_OUT_OF_HOST_MEMORY;
569
+ } catch (...) {
570
+ // Clear and rethrow to allow retry
571
+ platformId.devices_ .clear ();
572
+ throw ;
573
+ }
574
+ },
575
+ err);
567
576
577
+ if (num_platforms != nullptr ) {
578
+ *num_platforms = numPlatforms;
579
+ }
580
+
581
+ if (platforms != nullptr ) {
568
582
*platforms = &platformId;
569
583
}
570
584
0 commit comments