
[SYCL] Use batched mul_mat pathway #5591

Merged Mar 1, 2024 (3 commits)
Changes from all commits
107 changes: 44 additions & 63 deletions ggml-sycl.cpp
@@ -12726,6 +12726,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,

GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));

GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
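Editorial note, not part of the diff: the assertion added above lets the generic ggml_sycl_op_mul_mat path keep handling batched (ne[2]/ne[3] > 1) inputs only when src1 is F32; a src1 of any other type (e.g. F16) must be effectively 2D here, since batched F16 src1 is now routed to the batched pathway below. A condensed sketch of the accepted shapes, under that reading:

// Illustrative sketch only, not from the PR: the condition the new assert enforces.
static bool op_mul_mat_accepts_src1(bool src1_is_f32, long ne12, long ne13) {
    // F32 src1 of any batch shape is fine; other types must have no batch dims.
    return src1_is_f32 || (ne12 == 1 && ne13 == 1);
}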

@@ -13269,31 +13270,23 @@ static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
int64_t i03 = i13 / r3;
int64_t i02 = i12 / r2;

ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}

static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) try {
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));

GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);

GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);

GGML_TENSOR_LOCALS(int64_t, nb0, src0, nb);

GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);

GGML_TENSOR_LOCALS(int64_t, nb1, src1, nb);
GGML_TENSOR_BINARY_OP_LOCALS

const int64_t ne1 = ggml_nelements(src1);
const int64_t ne = ggml_nelements(dst);
const int64_t ne_dst = ggml_nelements(dst);

SYCL_CHECK(ggml_sycl_set_device(g_main_device));
dpct::queue_ptr main_stream = g_syclStreams[g_main_device_index][0];
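Editorial note, not part of the diff: in the old k_compute_batched_ptrs, the nb12/nb13 byte strides were halved inside the kernel because src1 was always converted from F32 to an F16 copy, whose byte strides are half as large. The hunk above drops that division, and a later hunk passes pre-scaled strides (nb12_scaled/nb13_scaled) instead, so the same kernel works whether src1 arrives as F32 or already as F16. A minimal sketch of the stride arithmetic, assuming a contiguous src1 with made-up sizes:

// Illustrative sketch only, not from the PR: byte strides of a contiguous F32
// src1 versus the F16 copy produced by the to-fp16 conversion.
#include <cstddef>
#include <cstdint>

int main() {
    const int64_t ne10 = 128, ne11 = 32, ne12 = 8;                 // assumed shape
    const size_t nb12_f32 = size_t(ne10) * ne11 * sizeof(float);   // 128*32*4 = 16384 bytes
    const size_t nb13_f32 = nb12_f32 * size_t(ne12);               // stride between i13 slices
    // The F16 copy keeps the same element layout, so every byte stride halves;
    // this is the /2 that moved out of the kernel and onto the host side.
    const size_t nb12_f16 = nb12_f32 / 2;
    const size_t nb13_f16 = nb13_f32 / 2;
    (void)nb12_f16; (void)nb13_f16;
    return 0;
}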
@@ -13312,11 +13305,16 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
float * dst_ddf = (float *) dst_extra->data_device[g_main_device_index];

// convert src1 to fp16
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
GGML_ASSERT(to_fp16_sycl != nullptr);

sycl_pool_alloc<sycl::half> src1_as_f16(ne1);
to_fp16_sycl(src1_ddf, src1_as_f16.get(), ne1, main_stream);
sycl_pool_alloc<sycl::half> src1_f16_alloc;
if (src1->type != GGML_TYPE_F16) {
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
GGML_ASSERT(to_fp16_sycl != nullptr);
to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
}
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
: src1_f16_alloc.get();

sycl_pool_alloc<sycl::half> dst_f16;
char * dst_t;
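Editorial note, not part of the diff: the hunk above replaces the unconditional F32-to-F16 conversion of src1 with convert-only-if-needed: an F16 src1 is used through its existing device buffer, while anything else is converted into a freshly allocated scratch buffer. A generic sketch of the same idiom, with hypothetical names (as_f16, to_f16 and half_t are placeholders, not ggml or SYCL API):

// Illustrative sketch only, not from the PR: alias the source buffer when the
// type already matches, otherwise convert into scratch storage.
#include <cstddef>
#include <cstdint>
#include <vector>

enum class Type { F16, F32 };
using half_t = uint16_t;   // stand-in for a real half-precision type

static const half_t * as_f16(const void * data, Type type, size_t n,
                             std::vector<half_t> & scratch,
                             void (*to_f16)(const float *, half_t *, size_t)) {
    if (type == Type::F16) {
        return static_cast<const half_t *>(data);   // already F16: no copy
    }
    scratch.resize(n);                              // temporary F16 buffer
    to_f16(static_cast<const float *>(data), scratch.data(), n);
    return scratch.data();
}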
@@ -13337,20 +13335,12 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
const void * alpha = &alpha_f16;
const void * beta = &beta_f16;

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
dst_t = (char *) dst_f16.alloc(ne);

nbd2 /= sizeof(float) / sizeof(sycl::half);
nbd3 /= sizeof(float) / sizeof(sycl::half);
} else {
dst_t = (char *) dst_ddf;

cu_compute_type = dpct::library_data_t::real_float;
cu_data_type = dpct::library_data_t::real_float;
// TODO: Re-enable the (dst->op_params[0] != GGML_PREC_DEFAULT) pathway
// once oneMKL open source supports half, half, float, float datatypes
dst_t = (char *) dst_f16.alloc(ne_dst);

alpha = &alpha_f32;
beta = &beta_f32;
}
nbd2 /= sizeof(float) / sizeof(sycl::half);
nbd3 /= sizeof(float) / sizeof(sycl::half);

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@@ -13386,10 +13376,10 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
*g_sycl_handles[g_main_device_index], oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
(const char *)src0_as_f16, dpct::library_data_t::real_half,
nb01 / sizeof(sycl::half), src0->nb[2] / sizeof(sycl::half),
(const char *)src1_as_f16.get(), dpct::library_data_t::real_half,
nb11 / sizeof(float), src1->nb[2] / sizeof(float), beta,
(char *)dst_t, cu_data_type, ne01, dst->nb[2] / sizeof(float),
nb01 / nb00, nb02 / nb00,
(const char *)src1_f16, dpct::library_data_t::real_half,
nb11 / nb10, nb12 / nb10, beta,
(char *)dst_t, cu_data_type, ne01, nb2 / nb0,
ne12 * ne13, cu_compute_type)));
} else {
// use syclGemmBatchedEx
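Editorial note, not part of the diff: the gemm arguments above switch from byte strides divided by a hard-coded element size (sizeof(sycl::half), sizeof(float)) to ratios like nb01 / nb00 and nb2 / nb0, i.e. strides expressed in elements of whatever type the tensor holds, which stays correct for both F16 and F32 inputs. Rough arithmetic under assumed sizes, as a sketch:

// Illustrative sketch only, not from the PR: a leading dimension in elements
// computed as a byte-stride ratio, independent of the element type.
#include <cstddef>
#include <cstdint>

int main() {
    const int64_t ne00 = 4096;                 // assumed row length of src0
    const size_t  nb00 = sizeof(uint16_t);     // F16 element size: 2 bytes
    const size_t  nb01 = size_t(ne00) * nb00;  // contiguous row stride: 8192 bytes
    const size_t  lda  = nb01 / nb00;          // 4096 elements, same value for any type
    return lda == size_t(ne00) ? 0 : 1;
}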
@@ -13409,44 +13399,35 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
{sycl::aspect::fp16});

main_stream->submit([&](sycl::handler &cgh) {
const sycl::half *src1_as_f16_get_ct1 = src1_as_f16.get();
const void **ptrs_src_get_ct3 = ptrs_src.get();
void **ptrs_dst_get_ct4 = ptrs_dst.get();

const void **ptrs_src_get = ptrs_src.get();
void **ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(
src0_as_f16, src1_as_f16_get_ct1,
dst_t, ptrs_src_get_ct3,
ptrs_dst_get_ct4, ne12, ne13, ne23,
nb02, nb03, nb12, nb13, nbd2, nbd3, r2,
r3, item_ct1);
src0_as_f16, src1_f16,
dst_t, ptrs_src_get,
ptrs_dst_get, ne12, ne13, ne23,
nb02, nb03, nb12_scaled, nb13_scaled,
nbd2, nbd3, r2, r3, item_ct1);
});
});
}
/*
DPCT1010:95: SYCL uses exceptions to report errors and does not use the
error codes. The call was replaced with 0. You need to rewrite this
code.
*/
SYCL_CHECK(0);

SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*g_sycl_handles[g_main_device_index], oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
(const void **)(ptrs_src.get() + 0 * ne23),
dpct::library_data_t::real_half, nb01 / sizeof(sycl::half),
dpct::library_data_t::real_half, nb01 / nb00,
(const void **)(ptrs_src.get() + 1 * ne23),
dpct::library_data_t::real_half, nb11 / sizeof(float), beta,
dpct::library_data_t::real_half, nb11 / nb10, beta,
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
cu_compute_type)));
}
#endif

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_ddf, ne, main_stream);
}
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
}
catch (sycl::exception const &exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -13491,10 +13472,10 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
// KQV single-batch
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
} else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
} else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
// KQ + KQV multi-batch
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat_mat_batched_sycl\n");
ggml_sycl_mul_mat_mat_batched_sycl(src0, src1, dst);
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
} else if (src0->type == GGML_TYPE_F32) {
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
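Editorial note, not part of the diff: the dispatch condition above no longer requires src1->type == GGML_TYPE_F32, so an F16 src1 also takes the batched mul_mat pathway (the KQ + KQV multi-batch case) instead of falling through to the generic ggml_sycl_op_mul_mat path. A condensed sketch of the predicate, with the surrounding state reduced to plain booleans:

// Illustrative sketch only, not from the PR: which calls take the batched
// pathway after this change; src1's type no longer appears in the condition.
static bool use_batched_pathway(bool split, bool all_on_device, bool use_xmx,
                                bool src0_is_f16, bool src0_transposed,
                                bool src1_transposed) {
    return !split && all_on_device && use_xmx &&
           src0_is_f16 && !src0_transposed && !src1_transposed;
}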