@@ -337,6 +337,7 @@ namespace dpct
337
337
}
338
338
size_t get_global_mem_size() const { return _global_mem_size; }
339
339
size_t get_local_mem_size() const { return _local_mem_size; }
340
+ size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
340
341
/// Returns the maximum clock rate of device's global memory in kHz. If
341
342
/// compiler does not support this API then returns default value 3200000 kHz.
342
343
unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
@@ -398,6 +399,10 @@ namespace dpct
398
399
{
399
400
_local_mem_size = local_mem_size;
400
401
}
402
+ void set_max_mem_alloc_size(size_t max_mem_alloc_size)
403
+ {
404
+ _max_mem_alloc_size = max_mem_alloc_size;
405
+ }
401
406
void set_max_work_group_size(int max_work_group_size)
402
407
{
403
408
_max_work_group_size = max_work_group_size;
@@ -465,6 +470,7 @@ namespace dpct
465
470
int _max_register_size_per_work_group;
466
471
size_t _global_mem_size;
467
472
size_t _local_mem_size;
473
+ size_t _max_mem_alloc_size;
468
474
size_t _max_nd_range_size[3];
469
475
int _max_nd_range_size_i[3];
470
476
uint32_t _device_id;
@@ -516,6 +522,7 @@ namespace dpct
516
522
dev.get_info<sycl::info::device::max_work_group_size>());
517
523
prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
518
524
prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
525
+ prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());
519
526
520
527
#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
521
528
if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
@@ -644,6 +651,11 @@ namespace dpct
644
651
return get_device_info().get_global_mem_size();
645
652
}
646
653
654
+ size_t get_max_mem_alloc_size() const
655
+ {
656
+ return get_device_info().get_max_mem_alloc_size();
657
+ }
658
+
647
659
/// Get the number of bytes of free and total memory on the SYCL device.
648
660
/// \param [out] free_memory The number of bytes of free memory on the SYCL device.
649
661
/// \param [out] total_memory The number of bytes of total memory on the SYCL device.
@@ -11311,10 +11323,10 @@ void ggml_init_sycl() try {
11311
11323
GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
11312
11324
int64_t total_vram = 0;
11313
11325
11314
- #if defined(GGML_SYCL_FP16 )
11315
- fprintf(stderr, "%s: GGML_SYCL_FP16 : yes\n", __func__);
11326
+ #if defined(GGML_SYCL_F16 )
11327
+ fprintf(stderr, "%s: GGML_SYCL_F16 : yes\n", __func__);
11316
11328
#else
11317
- fprintf(stderr, "%s: GGML_SYCL_FP16 : no\n", __func__);
11329
+ fprintf(stderr, "%s: GGML_SYCL_F16 : no\n", __func__);
11318
11330
#endif
11319
11331
11320
11332
@@ -14788,6 +14800,12 @@ static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_ty
14788
14800
UNUSED(buft);
14789
14801
}
14790
14802
14803
+ static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
14804
+ return dpct::get_current_device().get_max_mem_alloc_size();
14805
+
14806
+ UNUSED(buft);
14807
+ }
14808
+
14791
14809
static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
14792
14810
int64_t row_low = 0;
14793
14811
int64_t row_high = ggml_nrows(tensor);
@@ -14818,7 +14836,7 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
14818
14836
/* .get_name = */ ggml_backend_sycl_buffer_type_name,
14819
14837
/* .alloc_buffer = */ ggml_backend_sycl_buffer_type_alloc_buffer,
14820
14838
/* .get_alignment = */ ggml_backend_sycl_buffer_type_get_alignment,
14821
- /* .get_max_size = */ NULL, // TODO: return device.maxBufferLength
14839
+ /* .get_max_size = */ ggml_backend_sycl_buffer_type_get_max_size,
14822
14840
/* .get_alloc_size = */ ggml_backend_sycl_buffer_type_get_alloc_size,
14823
14841
/* .supports_backend = */ ggml_backend_sycl_buffer_type_supports_backend,
14824
14842
/* .is_host = */ nullptr,
0 commit comments