@@ -9416,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
9416
9416
9417
9417
9418
9418
template <bool vals_smem, int ncols_template, int block_size_template>
9419
- static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
9419
+ static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
9420
9420
const int nrows_y, const float scale, const float max_bias, const float m0,
9421
9421
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
9422
9422
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -9430,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
9430
9430
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
9431
9431
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
9432
9432
9433
- float slope = 0.0f;
9433
+ float slope = 1.0f;
9434
9434
9435
9435
// ALiBi
9436
9436
if (max_bias > 0.0f) {
@@ -9455,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
9455
9455
const int ix = rowx*ncols + col;
9456
9456
const int iy = rowy*ncols + col;
9457
9457
9458
- const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
9458
+ const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
9459
9459
9460
9460
vals[col] = val;
9461
9461
max_val = sycl::max(max_val, val);
@@ -13017,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
13017
13017
}
13018
13018
13019
13019
template <bool vals_smem, int ncols_template, int block_size_template>
13020
- static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
13020
+ static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
13021
13021
const int nrows_y, const float scale, const float max_bias, const float m0,
13022
13022
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
13023
13023
const size_t n_local_scratch, dpct::queue_ptr stream) {
@@ -13027,15 +13027,15 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
13027
13027
cgh.parallel_for(
13028
13028
sycl::nd_range<3>(block_nums * block_dims, block_dims),
13029
13029
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
13030
- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
13030
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
13031
13031
nrows_y, scale, max_bias, m0,
13032
13032
m1, n_head_log2, item_ct1,
13033
13033
local_buf_acc.get_pointer());
13034
13034
});
13035
13035
});
13036
13036
}
13037
13037
13038
- static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
13038
+ static void soft_max_f32_sycl(const float * x, const float * mask,
13039
13039
float * dst, const int ncols_x, const int nrows_x,
13040
13040
const int nrows_y, const float scale, const float max_bias,
13041
13041
dpct::queue_ptr stream) {
@@ -13057,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
13057
13057
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
13058
13058
if (n_local_scratch*sizeof(float) < local_mem_size) {
13059
13059
if (ncols_x > max_block_size) {
13060
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13060
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13061
13061
max_bias, m0, m1, n_head_log2, block_nums,
13062
13062
block_dims, n_local_scratch, stream);
13063
13063
return;
13064
13064
}
13065
13065
switch (ncols_x) {
13066
13066
case 32:
13067
- soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13067
+ soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
13068
13068
max_bias, m0, m1, n_head_log2, block_nums,
13069
13069
block_dims, n_local_scratch, stream);
13070
13070
break;
13071
13071
case 64:
13072
- soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13072
+ soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
13073
13073
max_bias, m0, m1, n_head_log2, block_nums,
13074
13074
block_dims, n_local_scratch, stream);
13075
13075
break;
13076
13076
case 128:
13077
- soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13077
+ soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
13078
13078
max_bias, m0, m1, n_head_log2, block_nums,
13079
13079
block_dims, n_local_scratch, stream);
13080
13080
break;
13081
13081
case 256:
13082
- soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13082
+ soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
13083
13083
max_bias, m0, m1, n_head_log2, block_nums,
13084
13084
block_dims, n_local_scratch, stream);
13085
13085
break;
13086
13086
case 512:
13087
- soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13087
+ soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
13088
13088
max_bias, m0, m1, n_head_log2, block_nums,
13089
13089
block_dims, n_local_scratch, stream);
13090
13090
break;
13091
13091
case 1024:
13092
- soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13092
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13093
13093
max_bias, m0, m1, n_head_log2, block_nums,
13094
13094
block_dims, n_local_scratch, stream);
13095
13095
break;
13096
13096
case 2048:
13097
- soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13097
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13098
13098
max_bias, m0, m1, n_head_log2, block_nums,
13099
13099
block_dims, n_local_scratch, stream);
13100
13100
break;
13101
13101
case 4096:
13102
- soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13102
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13103
13103
max_bias, m0, m1, n_head_log2, block_nums,
13104
13104
block_dims, n_local_scratch, stream);
13105
13105
break;
13106
13106
default:
13107
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13107
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13108
13108
max_bias, m0, m1, n_head_log2, block_nums,
13109
13109
block_dims, n_local_scratch, stream);
13110
13110
break;
13111
13111
}
13112
13112
} else {
13113
- soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13113
+ soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13114
13114
max_bias, m0, m1, n_head_log2, block_nums,
13115
13115
block_dims, WARP_SIZE, stream);
13116
13116
}
@@ -14675,12 +14675,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14675
14675
GGML_ASSERT(src0->type == GGML_TYPE_F32);
14676
14676
GGML_ASSERT( dst->type == GGML_TYPE_F32);
14677
14677
14678
- const ggml_tensor * src2 = dst->src[2];
14679
-
14680
- #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
14678
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
14681
14679
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
14682
14680
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
14683
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
14684
14681
14685
14682
const int64_t ne00 = src0->ne[0];
14686
14683
const int64_t nrows_x = ggml_nrows(src0);
@@ -14692,25 +14689,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14692
14689
memcpy(&scale, dst->op_params + 0, sizeof(float));
14693
14690
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
14694
14691
14695
- // positions tensor
14696
- float * src2_dd = nullptr;
14697
- sycl_pool_alloc<float> src2_f;
14698
-
14699
- const bool use_src2 = src2 != nullptr;
14700
-
14701
- if (use_src2) {
14702
- const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
14703
-
14704
- if (src2_on_device) {
14705
- ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
14706
- src2_dd = (float *) src2_extra->data_device[g_main_device];
14707
- } else {
14708
- src2_dd = src2_f.alloc(ggml_nelements(src2));
14709
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
14710
- }
14711
- }
14712
-
14713
- soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
14692
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
14714
14693
nrows_x, nrows_y, scale, max_bias, main_stream);
14715
14694
}
14716
14695
0 commit comments