Skip to content

Commit e2b0650

Browse files
authored
[SYCL] Fix ggml_sycl_mul_mat_id() to match the changed API (#7436)
* Fix mul_mat_id to match the API change. * Remove a stale comment. * Remove unused/duplicated code and rename per review comments.
1 parent 0548a41 commit e2b0650

File tree

1 file changed

+221
-56
lines changed

1 file changed

+221
-56
lines changed

ggml-sycl.cpp

Lines changed: 221 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2944,6 +2944,57 @@ namespace dpct
29442944
using shared_memory = detail::device_memory<T, shared, Dimension>;
29452945

29462946

2947+
template <typename T,
2948+
sycl::access::address_space addressSpace =
2949+
sycl::access::address_space::global_space,
2950+
sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2951+
sycl::memory_scope memoryScope = sycl::memory_scope::device>
2952+
inline T atomic_fetch_add(T *addr, T operand) {
2953+
auto atm =
2954+
sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
2955+
return atm.fetch_add(operand);
2956+
}
2957+
2958+
template <sycl::access::address_space addressSpace =
2959+
sycl::access::address_space::global_space,
2960+
sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2961+
sycl::memory_scope memoryScope = sycl::memory_scope::device,
2962+
typename T1, typename T2>
2963+
inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
2964+
auto atm =
2965+
sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
2966+
return atm.fetch_add(operand);
2967+
}
2968+
2969+
template <typename T, sycl::access::address_space addressSpace =
2970+
sycl::access::address_space::global_space>
2971+
inline T atomic_fetch_add(T *addr, T operand,
2972+
sycl::memory_order memoryOrder) {
2973+
switch (memoryOrder) {
2974+
case sycl::memory_order::relaxed:
2975+
return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
2976+
sycl::memory_scope::device>(addr, operand);
2977+
case sycl::memory_order::acq_rel:
2978+
return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
2979+
sycl::memory_scope::device>(addr, operand);
2980+
case sycl::memory_order::seq_cst:
2981+
return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
2982+
sycl::memory_scope::device>(addr, operand);
2983+
default:
2984+
assert(false && "Invalid memory_order for atomics. Valid memory_order for "
2985+
"atomics are: sycl::memory_order::relaxed, "
2986+
"sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
2987+
}
2988+
}
2989+
2990+
template <sycl::access::address_space addressSpace =
2991+
sycl::access::address_space::global_space,
2992+
typename T1, typename T2>
2993+
inline T1 atomic_fetch_add(T1 *addr, T2 operand,
2994+
sycl::memory_order memoryOrder) {
2995+
atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
2996+
}
2997+
29472998
} // COPY from DPCT head files
29482999

29493000
#define GGML_COMMON_DECL_SYCL
@@ -3060,6 +3111,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
30603111
bool ggml_backend_is_sycl(ggml_backend_t backend);
30613112
int ggml_backend_sycl_get_device(ggml_backend_t backend);
30623113
int get_main_device();
3114+
static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
30633115
void print_ggml_tensor(const char*name, struct ggml_tensor *src);
30643116
void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);
30653117

@@ -15459,22 +15511,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
1545915511
}
1546015512
#endif
1546115513

15514+
// Maps one gathered (contiguous) src1 row back to its original position:
// written by k_copy_src1_to_contiguous, consumed by k_copy_dst_from_contiguous.
struct mmid_row_mapping {
    int32_t i1;  // expert slot index within the token (id), i.e. dst row index
    int32_t i2;  // token index (iid1), i.e. dst matrix index along dim 2
};
15518+
15519+
// Device kernel: for expert `i02`, gathers every src1 row routed to that
// expert (per the `ids` matrix) into the dense buffer `src1_contiguous`, and
// records where each gathered row came from in `row_mapping` so the result
// can later be scattered back by k_copy_dst_from_contiguous.
// Launch shape: one work-group per (token iid1, expert-slot id) pair;
// `cur_src1_row` is a global counter shared by all groups, `src1_row` is a
// work-group-local (local_accessor) slot that broadcasts the claimed row index
// to the whole group.
__dpct_inline__ static void k_copy_src1_to_contiguous(
    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
    const sycl::nd_item<3> &item_ct1, int &src1_row) {
    int32_t iid1 = item_ct1.get_group(2);   // token index
    int32_t id = item_ct1.get_group(1);     // expert slot within the token

    // Expert selected for this (token, slot) entry of the ids matrix.
    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);

    // Not routed to expert i02: the whole work-group exits together (group
    // indices are uniform across the group, so the barrier below is safe).
    if (row_id_i != i02) {
        return;
    }

    const int64_t i11 = id % ne11;
    const int64_t i12 = iid1;

    // One work-item claims the next free contiguous row via a global atomic
    // and records the reverse mapping; `src1_row` is group-local storage.
    if (item_ct1.get_local_id(2) == 0) {
        src1_row =
            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
                cur_src1_row, 1);
        row_mapping[src1_row] = {id, iid1};
    }
    /*
    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
    performance if there is no access to global memory.
    */
    // Make the claimed row index visible to every work-item in the group.
    item_ct1.barrier();

    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);

    // Cooperative copy of one row: work-items stride across the ne10 floats.
#pragma unroll
    for (int i = item_ct1.get_local_id(2); i < ne10;
         i += item_ct1.get_local_range(2)) {
        src1_row_contiguous[i] = src1_row_original[i];
    }
}
15559+
15560+
// Device kernel: scatters the rows of the dense per-expert result buffer
// `dst_contiguous` back to their original positions in `dst_original`, using
// the (i1, i2) coordinates recorded by k_copy_src1_to_contiguous.
// Launch shape: one work-group per gathered row.
__dpct_inline__ static void k_copy_dst_from_contiguous(
    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
    size_t nb2, const sycl::nd_item<3> &item_ct1) {
    int32_t i = item_ct1.get_group(2);  // index of the contiguous row

    // Original destination coordinates for this row.
    const int32_t i1 = row_mapping[i].i1;
    const int32_t i2 = row_mapping[i].i2;

    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);

    // Cooperative copy of one row: work-items stride across the ne0 floats.
#pragma unroll
    for (int j = item_ct1.get_local_id(2); j < ne0;
         j += item_ct1.get_local_range(2)) {
        dst_row_original[j] = dst_row_contiguous[j];
    }
}
15578+
1546215579
static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
1546315580
const ggml_tensor *src1,
1546415581
ggml_tensor *dst) try {
15465-
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
15466-
"mul_mat_id does not support split buffers");
15582+
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
15583+
1546715584
const ggml_tensor *ids = dst->src[2];
15468-
const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15585+
GGML_TENSOR_BINARY_OP_LOCALS
1546915586

15470-
const size_t nb11 = src1->nb[1];
15471-
const size_t nb1 = dst->nb[1];
15587+
const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
1547215588

15473-
const int32_t id = ((int32_t *)dst->op_params)[0];
15474-
const int32_t n_as = src0->ne[2];
15589+
const int64_t n_as = ne02;
15590+
const int64_t n_ids = ids->ne[0];
1547515591

1547615592
std::vector<char> ids_host(ggml_nbytes(ids));
15477-
const char *ids_dev = (const char *)ids->data;
15593+
const char * ids_dev = (const char *) ids->data;
1547815594

1547915595
SYCL_CHECK(CHECK_TRY_ERROR(
1548015596
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
@@ -15514,24 +15630,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
1551415630

1551515631
src0_row.ne[2] = 1;
1551615632
src0_row.ne[3] = 1;
15517-
src0_row.nb[3] = src0->nb[2];
15518-
15519-
if (src1->ne[1] == 1) {
15520-
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15521-
const int32_t row_id =
15522-
*(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
15523-
id * ids->nb[0]);
15524-
15525-
GGML_ASSERT(row_id >= 0 && row_id < n_as);
15633+
src0_row.nb[3] = nb02;
15634+
15635+
src1_row.ne[1] = 1;
15636+
src1_row.ne[2] = 1;
15637+
src1_row.ne[3] = 1;
15638+
src1_row.nb[2] = nb11;
15639+
src1_row.nb[3] = nb11;
15640+
15641+
dst_row.ne[1] = 1;
15642+
dst_row.ne[2] = 1;
15643+
dst_row.ne[3] = 1;
15644+
dst_row.nb[2] = nb1;
15645+
dst_row.nb[3] = nb1;
15646+
if (ne12 == 1) {
15647+
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
15648+
for (int64_t id = 0; id < n_ids; id++) {
15649+
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
15650+
GGML_ASSERT(i02 >= 0 && i02 < n_as);
15651+
15652+
const int64_t i11 = id % ne11;
15653+
const int64_t i12 = iid1;
15654+
15655+
const int64_t i1 = id;
15656+
const int64_t i2 = i12;
1552615657

1552715658
src0_row_extra.data_device[g_main_device] =
15528-
src0_original + row_id * src0->nb[2];
15659+
src0_original + i02*nb02;
1552915660
src1_row_extra.data_device[g_main_device] =
15530-
src1_original + i01 * src1->nb[1];
15661+
src1_original + + i11*nb11 + i12*nb12;
1553115662
dst_row_extra.data_device[g_main_device] =
15532-
dst_original + i01 * dst->nb[1];
15663+
dst_original + i1*nb1 + i2*nb2;
1553315664

1553415665
ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
15666+
}
1553515667
}
1553615668
} else {
1553715669
sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
@@ -15540,64 +15672,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
1554015672
src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
1554115673
dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
1554215674

15543-
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
15675+
for (int64_t i02 = 0; i02 < n_as; i02++) {
1554415676
int64_t num_src1_rows = 0;
15545-
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15546-
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
15677+
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
15678+
for (int64_t id = 0; id < n_ids; id++) {
15679+
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
1554715680

15548-
if (row_id_i != row_id) {
15549-
continue;
15550-
}
15681+
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
1555115682

15552-
GGML_ASSERT(row_id >= 0 && row_id < n_as);
15683+
if (row_id_i != i02) {
15684+
continue;
15685+
}
1555315686

15554-
SYCL_CHECK(CHECK_TRY_ERROR(
15555-
stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
15556-
src1_original + i01 * nb11, nb11)));
15557-
num_src1_rows++;
15687+
num_src1_rows++;
15688+
}
1555815689
}
1555915690

1556015691
if (num_src1_rows == 0) {
1556115692
continue;
1556215693
}
1556315694

15564-
src0_row_extra.data_device[g_main_device] =
15565-
src0_original + row_id * src0->nb[2];
1556615695

15696+
sycl_pool_alloc<int> dev_cur_src1_row(1);
15697+
sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
15698+
SYCL_CHECK(CHECK_TRY_ERROR(
15699+
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
15700+
15701+
{
15702+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
15703+
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
15704+
stream->submit([&](sycl::handler &cgh) {
15705+
sycl::local_accessor<int, 0> src1_row_acc(cgh);
15706+
15707+
char *__restrict src1_contiguous_get =
15708+
src1_contiguous.get();
15709+
int *__restrict dev_cur_src1_row_get =
15710+
dev_cur_src1_row.get();
15711+
mmid_row_mapping *__restrict dev_row_mapping_get =
15712+
dev_row_mapping.get();
15713+
size_t ids_nb_ct6 = ids->nb[1];
15714+
size_t ids_nb_ct7 = ids->nb[0];
15715+
15716+
cgh.parallel_for(
15717+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
15718+
[=](sycl::nd_item<3> item_ct1) {
15719+
k_copy_src1_to_contiguous(
15720+
src1_original, src1_contiguous_get,
15721+
dev_cur_src1_row_get,
15722+
dev_row_mapping_get, ids_dev, i02,
15723+
ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
15724+
item_ct1, src1_row_acc);
15725+
});
15726+
});
15727+
}
15728+
15729+
src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
15730+
15731+
GGML_ASSERT(nb11 == sizeof(float)*ne10);
15732+
GGML_ASSERT(nb1 == sizeof(float)*ne0);
1556715733
src1_row.ne[1] = num_src1_rows;
15568-
dst_row.ne[1] = num_src1_rows;
1556915734

1557015735
src1_row.nb[1] = nb11;
1557115736
src1_row.nb[2] = num_src1_rows*nb11;
1557215737
src1_row.nb[3] = num_src1_rows*nb11;
1557315738

15739+
dst_row.ne[1] = num_src1_rows;
1557415740
dst_row.nb[1] = nb1;
1557515741
dst_row.nb[2] = num_src1_rows*nb1;
1557615742
dst_row.nb[3] = num_src1_rows*nb1;
1557715743

1557815744
ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
1557915745

15580-
num_src1_rows = 0;
15581-
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15582-
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
15583-
15584-
if (row_id_i != row_id) {
15585-
continue;
15586-
}
15587-
15588-
GGML_ASSERT(row_id >= 0 && row_id < n_as);
15589-
15590-
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
15591-
dst_original + i01 * nb1,
15592-
dst_contiguous.get() + num_src1_rows * nb1, nb1)));
15593-
num_src1_rows++;
15746+
{
15747+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
15748+
sycl::range<3> grid_dims(1, 1, num_src1_rows);
15749+
stream->submit([&](sycl::handler &cgh) {
15750+
const char *__restrict dst_contiguous_get =
15751+
dst_contiguous.get();
15752+
const mmid_row_mapping *__restrict dev_row_mapping_get =
15753+
dev_row_mapping.get();
15754+
15755+
cgh.parallel_for(
15756+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
15757+
[=](sycl::nd_item<3> item_ct1) {
15758+
k_copy_dst_from_contiguous(dst_original,
15759+
dst_contiguous_get,
15760+
dev_row_mapping_get,
15761+
ne0, nb1, nb2, item_ct1);
15762+
});
15763+
});
1559415764
}
1559515765
}
1559615766
}
15597-
15598-
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
15599-
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
15600-
}
1560115767
}
1560215768
catch (sycl::exception const &exc) {
1560315769
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -16580,10 +16746,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
1658016746
UNUSED(buffer);
1658116747
}
1658216748

16583-
// unused at the moment
16584-
//static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
16585-
// return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
16586-
//}
16749+
// Returns true if `buffer` is a SYCL split buffer, identified by its
// get_name vtable entry pointing at ggml_backend_sycl_split_buffer_get_name.
static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
}
1658716752

1658816753
GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1658916754
ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;

0 commit comments

Comments
 (0)