@@ -2944,6 +2944,57 @@ namespace dpct
using shared_memory = detail::device_memory<T, shared, Dimension>;


+ template <typename T,
+ sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+ sycl::memory_scope memoryScope = sycl::memory_scope::device>
+ inline T atomic_fetch_add(T *addr, T operand) {
+ auto atm =
+ sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
+ return atm.fetch_add(operand);
+ }
+
+ template <sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+ sycl::memory_scope memoryScope = sycl::memory_scope::device,
+ typename T1, typename T2>
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
+ auto atm =
+ sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
+ return atm.fetch_add(operand);
+ }
+
+ template <typename T, sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space>
+ inline T atomic_fetch_add(T *addr, T operand,
+ sycl::memory_order memoryOrder) {
+ switch (memoryOrder) {
+ case sycl::memory_order::relaxed:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
+ sycl::memory_scope::device>(addr, operand);
+ case sycl::memory_order::acq_rel:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
+ sycl::memory_scope::device>(addr, operand);
+ case sycl::memory_order::seq_cst:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
+ sycl::memory_scope::device>(addr, operand);
+ default:
+ assert(false && "Invalid memory_order for atomics. Valid memory_order for "
+ "atomics are: sycl::memory_order::relaxed, "
+ "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+ }
+ }
+
+ template <sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ typename T1, typename T2>
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand,
+ sycl::memory_order memoryOrder) {
+ return atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
+ }
+
} // COPY from DPCT head files

#define GGML_COMMON_DECL_SYCL
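Not part of the patch: a minimal usage sketch of the atomic helpers added above, assuming they are in scope (they live in this file's dpct namespace, not a public header) and a SYCL 2020 toolchain; the queue, allocation, and sizes are invented for illustration. The first two overloads bake the memory order and scope into the sycl::atomic_ref type at compile time; the last two take a runtime sycl::memory_order and dispatch through the switch shown above.

```cpp
// Hedged sketch, not from the patch: 1024 work-items bump one device counter.
#include <sycl/sycl.hpp>
#include <cassert>

int main() {
    sycl::queue q;
    int *counter = sycl::malloc_device<int>(1, q);   // hypothetical allocation
    q.memset(counter, 0, sizeof(int)).wait();
    q.parallel_for(sycl::range<1>(1024), [=](sycl::id<1>) {
        // explicit address space; memory order/scope fall back to the
        // relaxed/device defaults of the template parameters above
        dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
            counter, 1);
    }).wait();
    int host = 0;
    q.memcpy(&host, counter, sizeof(int)).wait();
    assert(host == 1024);                            // every work-item added 1
    sycl::free(counter, q);
    return 0;
}
```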
@@ -3060,6 +3111,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
bool ggml_backend_is_sycl(ggml_backend_t backend);
int ggml_backend_sycl_get_device(ggml_backend_t backend);
int get_main_device();
+ static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
void print_ggml_tensor(const char*name, struct ggml_tensor *src);
void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);
@@ -15459,22 +15511,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
}
#endif

+ struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+ };
+
+ __dpct_inline__ static void k_copy_src1_to_contiguous(
+ const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
+ int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
+ const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+ int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
+ const sycl::nd_item<3> &item_ct1, int &src1_row) {
+ int32_t iid1 = item_ct1.get_group(2);
+ int32_t id = item_ct1.get_group(1);
+
+ const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+ if (row_id_i != i02) {
+ return;
+ }
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = iid1;
+
+ if (item_ct1.get_local_id(2) == 0) {
+ src1_row =
+ dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
+ cur_src1_row, 1);
+ row_mapping[src1_row] = {id, iid1};
+ }
+ /*
+ DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+ performance if there is no access to global memory.
+ */
+ item_ct1.barrier();
+
+ const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+ float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+ #pragma unroll
+ for (int i = item_ct1.get_local_id(2); i < ne10;
+ i += item_ct1.get_local_range(2)) {
+ src1_row_contiguous[i] = src1_row_original[i];
+ }
+ }
+
+ __dpct_inline__ static void k_copy_dst_from_contiguous(
+ char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
+ const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
+ size_t nb2, const sycl::nd_item<3> &item_ct1) {
+ int32_t i = item_ct1.get_group(2);
+
+ const int32_t i1 = row_mapping[i].i1;
+ const int32_t i2 = row_mapping[i].i2;
+
+ const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+ float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+ #pragma unroll
+ for (int j = item_ct1.get_local_id(2); j < ne0;
+ j += item_ct1.get_local_range(2)) {
+ dst_row_original[j] = dst_row_contiguous[j];
+ }
+ }
+
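Not part of the patch: the two kernels above implement a gather/compute/scatter pattern; the src1 rows routed to one expert are compacted into a contiguous buffer, each row's origin is recorded in an mmid_row_mapping, and the results are scattered back afterwards. A plain-CPU sketch of the same bookkeeping, with invented shapes and ids:

```cpp
// Hedged CPU sketch of the mmid row gather/scatter bookkeeping above.
#include <cstdint>
#include <cstdio>
#include <vector>

struct mmid_row_mapping { int32_t i1; int32_t i2; };

int main() {
    // ids[iid1][id]: expert chosen for token iid1, slot id (hypothetical data)
    const int n_tokens = 4, n_ids = 2, n_as = 3;
    int32_t ids[n_tokens][n_ids] = {{0, 2}, {1, 2}, {0, 1}, {2, 2}};

    for (int32_t i02 = 0; i02 < n_as; i02++) {        // one pass per expert
        std::vector<mmid_row_mapping> mapping;
        for (int32_t iid1 = 0; iid1 < n_tokens; iid1++)
            for (int32_t id = 0; id < n_ids; id++)
                if (ids[iid1][id] == i02)
                    mapping.push_back({id, iid1});    // remember the row origin
        // gather: copy mapping.size() matching src1 rows into a contiguous
        // buffer, run one batched matmul against expert i02, then scatter each
        // dst row back to (i1, i2) -- exactly what the two kernels above do.
        printf("expert %d: %zu rows\n", i02, mapping.size());
    }
}
```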
static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
- "mul_mat_id does not support split buffers");
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
+
const ggml_tensor *ids = dst->src[2];
- const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
+ GGML_TENSOR_BINARY_OP_LOCALS

- const size_t nb11 = src1->nb[1];
- const size_t nb1 = dst->nb[1];
+ const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];

- const int32_t id = ((int32_t *)dst->op_params)[0];
- const int32_t n_as = src0->ne[2];
+ const int64_t n_as = ne02;
+ const int64_t n_ids = ids->ne[0];

std::vector<char> ids_host(ggml_nbytes(ids));
- const char *ids_dev = (const char *)ids->data;
+ const char * ids_dev = (const char *) ids->data;

SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
@@ -15514,24 +15630,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
- src0_row.nb[3] = src0->nb[2];
-
- if (src1->ne[1] == 1) {
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
- const int32_t row_id =
- *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
- id * ids->nb[0]);
-
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
+ src0_row.nb[3] = nb02;
+
+ src1_row.ne[1] = 1;
+ src1_row.ne[2] = 1;
+ src1_row.ne[3] = 1;
+ src1_row.nb[2] = nb11;
+ src1_row.nb[3] = nb11;
+
+ dst_row.ne[1] = 1;
+ dst_row.ne[2] = 1;
+ dst_row.ne[3] = 1;
+ dst_row.nb[2] = nb1;
+ dst_row.nb[3] = nb1;
+ if (ne12 == 1) {
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = iid1;
+
+ const int64_t i1 = id;
+ const int64_t i2 = i12;

src0_row_extra.data_device[g_main_device] =
- src0_original + row_id * src0->nb[2];
+ src0_original + i02*nb02;
src1_row_extra.data_device[g_main_device] =
- src1_original + i01 * src1->nb[1];
+ src1_original + i11*nb11 + i12*nb12;
dst_row_extra.data_device[g_main_device] =
- dst_original + i01 * dst->nb[1];
+ dst_original + i1*nb1 + i2*nb2;

ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
+ }
}
} else {
sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
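Not part of the patch: in the ne12 == 1 path above, each (token iid1, slot id) pair reads expert index i02 from ids and maps directly to one src1 row and one dst row, with no compaction needed. A tiny sketch of that index math with invented shapes:

```cpp
// Hedged sketch of the ne12 == 1 index math above (hypothetical shapes).
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne11 = 1;    // src1 rows per batch entry
    const int64_t n_ids = 2;   // experts used per token
    for (int64_t iid1 = 0; iid1 < 3; iid1++) {       // token index
        for (int64_t id = 0; id < n_ids; id++) {     // expert slot
            const int64_t i11 = id % ne11;           // src1 row within batch
            const int64_t i12 = iid1;                // src1 batch = token
            const int64_t i1 = id, i2 = i12;         // dst row / dst batch
            printf("token %lld slot %lld -> src1(%lld,%lld) dst(%lld,%lld)\n",
                   (long long)iid1, (long long)id, (long long)i11,
                   (long long)i12, (long long)i1, (long long)i2);
        }
    }
}
```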
@@ -15540,64 +15672,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
dst_row_extra.data_device[g_main_device] = dst_contiguous.get();

- for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+ for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

- if (row_id_i != row_id) {
- continue;
- }
+ GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);

- GGML_ASSERT(row_id >= 0 && row_id < n_as);
+ if (row_id_i != i02) {
+ continue;
+ }

- SYCL_CHECK(CHECK_TRY_ERROR(
- stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
- src1_original + i01 * nb11, nb11)));
- num_src1_rows++;
+ num_src1_rows++;
+ }
}

if (num_src1_rows == 0) {
continue;
}

- src0_row_extra.data_device[g_main_device] =
- src0_original + row_id * src0->nb[2];

+ sycl_pool_alloc<int> dev_cur_src1_row(1);
+ sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
+
+ {
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+ sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 0> src1_row_acc(cgh);
+
+ char *__restrict src1_contiguous_get =
+ src1_contiguous.get();
+ int *__restrict dev_cur_src1_row_get =
+ dev_cur_src1_row.get();
+ mmid_row_mapping *__restrict dev_row_mapping_get =
+ dev_row_mapping.get();
+ size_t ids_nb_ct6 = ids->nb[1];
+ size_t ids_nb_ct7 = ids->nb[0];
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_copy_src1_to_contiguous(
+ src1_original, src1_contiguous_get,
+ dev_cur_src1_row_get,
+ dev_row_mapping_get, ids_dev, i02,
+ ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
+ item_ct1, src1_row_acc);
+ });
+ });
+ }
+
+ src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
+
+ GGML_ASSERT(nb11 == sizeof(float)*ne10);
+ GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.ne[1] = num_src1_rows;
- dst_row.ne[1] = num_src1_rows;

src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;

+ dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;

ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);

- num_src1_rows = 0;
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
- if (row_id_i != row_id) {
- continue;
- }
-
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
- SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
- dst_original + i01 * nb1,
- dst_contiguous.get() + num_src1_rows * nb1, nb1)));
- num_src1_rows++;
+ {
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+ sycl::range<3> grid_dims(1, 1, num_src1_rows);
+ stream->submit([&](sycl::handler &cgh) {
+ const char *__restrict dst_contiguous_get =
+ dst_contiguous.get();
+ const mmid_row_mapping *__restrict dev_row_mapping_get =
+ dev_row_mapping.get();
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_copy_dst_from_contiguous(dst_original,
+ dst_contiguous_get,
+ dev_row_mapping_get,
+ ne0, nb1, nb2, item_ct1);
+ });
+ });
}
}
}
-
- if (dst->backend == GGML_BACKEND_TYPE_CPU) {
- SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
- }
}
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
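Not part of the patch: the submits above use a zero-dimensional sycl::local_accessor, so the whole work-group shares a single int, bound to the kernel's int &src1_row parameter. A stripped-down sketch of that launch pattern with an invented kernel body:

```cpp
// Hedged sketch of the 0-dim local_accessor launch pattern used above.
#include <sycl/sycl.hpp>

int main() {
    sycl::queue q;
    sycl::range<3> block_dims(1, 1, 64);
    sycl::range<3> grid_dims(1, 1, 4);
    q.submit([&](sycl::handler &cgh) {
        // one int of work-group local memory, bound by reference in the kernel
        sycl::local_accessor<int, 0> shared_val(cgh);
        cgh.parallel_for(
            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
            [=](sycl::nd_item<3> it) {
                int &v = shared_val;           // same storage for the whole group
                if (it.get_local_id(2) == 0) {
                    v = (int) it.get_group(2); // group leader writes
                }
                it.barrier();                  // everyone reads after the barrier
                (void) v;
            });
    }).wait();
}
```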
@@ -16580,10 +16746,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
UNUSED(buffer);
}

- // unused at the moment
- //static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
- // return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
- //}
+ static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
+ }

GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
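Not part of the patch: the now-enabled helper tags split buffers by comparing the interface's get_name function pointer rather than the returned string, so the check is a single pointer comparison. A self-contained sketch of that design choice with invented stand-ins for the ggml types:

```cpp
// Hedged sketch: type-tagging an object by its interface function pointer,
// as ggml_backend_buffer_is_sycl_split does above (names here are invented).
#include <cstdio>

struct buffer_iface { const char *(*get_name)(void *); };
struct buffer { buffer_iface iface; void *context; };

static const char *split_get_name(void *) { return "SYCL_Split"; }
static const char *plain_get_name(void *) { return "SYCL"; }

static bool buffer_is_split(const buffer &b) {
    // pointer identity: true only for buffers created with the split iface
    return b.iface.get_name == split_get_name;
}

int main() {
    buffer a{{split_get_name}, nullptr}, b{{plain_get_name}, nullptr};
    printf("%d %d\n", buffer_is_split(a), buffer_is_split(b)); // prints: 1 0
}
```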