@@ -9414,7 +9414,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con


 template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -9428,7 +9428,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
     const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
     const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;

-    float slope = 0.0f;
+    float slope = 1.0f;

     // ALiBi
     if (max_bias > 0.0f) {
@@ -9453,7 +9453,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;

-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);

         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -13015,7 +13015,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
 }

 template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, dpct::queue_ptr stream) {
@@ -13025,15 +13025,15 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
+                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
                                                                              m1, n_head_log2, item_ct1,
                                                                              local_buf_acc.get_pointer());
             });
     });
 }

-static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
+static void soft_max_f32_sycl(const float * x, const float * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               dpct::queue_ptr stream) {
@@ -13055,60 +13055,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
     if (n_local_scratch*sizeof(float) < local_mem_size) {
         if (ncols_x > max_block_size) {
-            soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                                max_bias, m0, m1, n_head_log2, block_nums,
                                                block_dims, n_local_scratch, stream);
             return;
         }
         switch (ncols_x) {
             case 32:
-                soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 64:
-                soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 128:
-                soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 256:
-                soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 512:
-                soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 1024:
-                soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 2048:
-                soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 4096:
-                soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             default:
-                soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                                    max_bias, m0, m1, n_head_log2, block_nums,
                                                    block_dims, n_local_scratch, stream);
                 break;
         }
     } else {
-        soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                             max_bias, m0, m1, n_head_log2, block_nums,
                                             block_dims, WARP_SIZE, stream);
     }
@@ -14673,12 +14673,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    const ggml_tensor * src2 = dst->src[2];
-
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
 #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional

     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -14690,25 +14687,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));

-    // positions tensor
-    float * src2_dd = nullptr;
-    sycl_pool_alloc<float> src2_f;
-
-    const bool use_src2 = src2 != nullptr;
-
-    if (use_src2) {
-        const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
-
-        if (src2_on_device) {
-            ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
-            src2_dd = (float *) src2_extra->data_device[g_main_device];
-        } else {
-            src2_dd = src2_f.alloc(ggml_nelements(src2));
-            SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
-        }
-    }
-
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
+    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
                       nrows_x, nrows_y, scale, max_bias, main_stream);
 }
