Skip to content

Commit 421da29

Browse files
author
Aidan
committed
Fix batched impl
1 parent b80cf3b commit 421da29

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

ggml-sycl.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15272,8 +15272,8 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
1527215272
sycl_pool_alloc<sycl::half> dst_f16;
1527315273
char * dst_t;
1527415274

15275-
dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
15276-
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
15275+
dpct::library_data_t cu_compute_type = dpct::library_data_t::real_half;
15276+
dpct::library_data_t cu_data_type = dpct::library_data_t::real_half;
1527715277

1527815278
// dst strides
1527915279
size_t nbd2 = dst->nb[2];
@@ -15282,16 +15282,16 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
1528215282
const sycl::half alpha_f16 = 1.0f;
1528315283
const sycl::half beta_f16 = 0.0f;
1528415284

15285-
const float alpha_f32 = 1.0f;
15286-
const float beta_f32 = 0.0f;
15287-
15288-
const void * alpha = &alpha_f32;
15289-
const void * beta = &beta_f32;
15285+
const void * alpha = &alpha_f16;
15286+
const void * beta = &beta_f16;
1529015287

1529115288
// TODO: Re-enable (dst->op_params[0] != GGML_PREC_DEFAULT) pathway
15292-
// oneMKL open source supports half, half, float, float: datatypes
15289+
// when oneMKL open source supports half, half, float, float: datatypes
15290+
15291+
dst_t = (char *) dst_f16.alloc(ne_dst);
1529315292

15294-
dst_t = (char *) dst_ddf;
15293+
nbd2 /= sizeof(float) / sizeof(sycl::half);
15294+
nbd3 /= sizeof(float) / sizeof(sycl::half);
1529515295

1529615296
GGML_ASSERT(ne12 % ne02 == 0);
1529715297
GGML_ASSERT(ne13 % ne03 == 0);
@@ -15377,6 +15377,8 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
1537715377
}
1537815378
#endif
1537915379

15380+
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
15381+
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
1538015382
}
1538115383
catch (sycl::exception const &exc) {
1538215384
std::cerr << exc.what() << "Exception caught at file:" << __FILE__

0 commit comments

Comments
 (0)