@@ -9414,7 +9414,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
9414
9414
9415
9415
9416
9416
template <bool vals_smem, int ncols_template, int block_size_template>
9417
- static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
9417
+ static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
9418
9418
const int nrows_y, const float scale, const float max_bias, const float m0,
9419
9419
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
9420
9420
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -9428,7 +9428,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
9428
9428
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
9429
9429
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
9430
9430
9431
- float slope = 0.0f;
9431
+ float slope = 1.0f;
9432
9432
9433
9433
// ALiBi
9434
9434
if (max_bias > 0.0f) {
@@ -9453,7 +9453,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
9453
9453
const int ix = rowx*ncols + col;
9454
9454
const int iy = rowy*ncols + col;
9455
9455
9456
- const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
9456
+ const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
9457
9457
9458
9458
vals[col] = val;
9459
9459
max_val = sycl::max(max_val, val);
@@ -13015,7 +13015,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
13015
13015
}
13016
13016
13017
13017
template <bool vals_smem, int ncols_template, int block_size_template>
13018
- static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
13018
+ static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
13019
13019
const int nrows_y, const float scale, const float max_bias, const float m0,
13020
13020
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
13021
13021
const size_t n_local_scratch, dpct::queue_ptr stream) {
@@ -13025,15 +13025,15 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
13025
13025
cgh.parallel_for(
13026
13026
sycl::nd_range<3>(block_nums * block_dims, block_dims),
13027
13027
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
13028
- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
13028
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
13029
13029
nrows_y, scale, max_bias, m0,
13030
13030
m1, n_head_log2, item_ct1,
13031
13031
local_buf_acc.get_pointer());
13032
13032
});
13033
13033
});
13034
13034
}
13035
13035
13036
- static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
13036
+ static void soft_max_f32_sycl(const float * x, const float * mask,
13037
13037
float * dst, const int ncols_x, const int nrows_x,
13038
13038
const int nrows_y, const float scale, const float max_bias,
13039
13039
dpct::queue_ptr stream) {
@@ -13055,60 +13055,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
13055
13055
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
13056
13056
if (n_local_scratch*sizeof(float) < local_mem_size) {
13057
13057
if (ncols_x > max_block_size) {
13058
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13058
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13059
13059
max_bias, m0, m1, n_head_log2, block_nums,
13060
13060
block_dims, n_local_scratch, stream);
13061
13061
return;
13062
13062
}
13063
13063
switch (ncols_x) {
13064
13064
case 32:
13065
- soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13065
+ soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
13066
13066
max_bias, m0, m1, n_head_log2, block_nums,
13067
13067
block_dims, n_local_scratch, stream);
13068
13068
break;
13069
13069
case 64:
13070
- soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13070
+ soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
13071
13071
max_bias, m0, m1, n_head_log2, block_nums,
13072
13072
block_dims, n_local_scratch, stream);
13073
13073
break;
13074
13074
case 128:
13075
- soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13075
+ soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
13076
13076
max_bias, m0, m1, n_head_log2, block_nums,
13077
13077
block_dims, n_local_scratch, stream);
13078
13078
break;
13079
13079
case 256:
13080
- soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13080
+ soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
13081
13081
max_bias, m0, m1, n_head_log2, block_nums,
13082
13082
block_dims, n_local_scratch, stream);
13083
13083
break;
13084
13084
case 512:
13085
- soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13085
+ soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
13086
13086
max_bias, m0, m1, n_head_log2, block_nums,
13087
13087
block_dims, n_local_scratch, stream);
13088
13088
break;
13089
13089
case 1024:
13090
- soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13090
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13091
13091
max_bias, m0, m1, n_head_log2, block_nums,
13092
13092
block_dims, n_local_scratch, stream);
13093
13093
break;
13094
13094
case 2048:
13095
- soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13095
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13096
13096
max_bias, m0, m1, n_head_log2, block_nums,
13097
13097
block_dims, n_local_scratch, stream);
13098
13098
break;
13099
13099
case 4096:
13100
- soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13100
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13101
13101
max_bias, m0, m1, n_head_log2, block_nums,
13102
13102
block_dims, n_local_scratch, stream);
13103
13103
break;
13104
13104
default:
13105
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13105
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13106
13106
max_bias, m0, m1, n_head_log2, block_nums,
13107
13107
block_dims, n_local_scratch, stream);
13108
13108
break;
13109
13109
}
13110
13110
} else {
13111
- soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13111
+ soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13112
13112
max_bias, m0, m1, n_head_log2, block_nums,
13113
13113
block_dims, WARP_SIZE, stream);
13114
13114
}
@@ -14673,12 +14673,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14673
14673
GGML_ASSERT(src0->type == GGML_TYPE_F32);
14674
14674
GGML_ASSERT( dst->type == GGML_TYPE_F32);
14675
14675
14676
- const ggml_tensor * src2 = dst->src[2];
14677
-
14678
- #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
14676
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
14679
14677
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
14680
14678
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
14681
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
14682
14679
14683
14680
const int64_t ne00 = src0->ne[0];
14684
14681
const int64_t nrows_x = ggml_nrows(src0);
@@ -14690,25 +14687,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14690
14687
memcpy(&scale, dst->op_params + 0, sizeof(float));
14691
14688
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
14692
14689
14693
- // positions tensor
14694
- float * src2_dd = nullptr;
14695
- sycl_pool_alloc<float> src2_f;
14696
-
14697
- const bool use_src2 = src2 != nullptr;
14698
-
14699
- if (use_src2) {
14700
- const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
14701
-
14702
- if (src2_on_device) {
14703
- ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
14704
- src2_dd = (float *) src2_extra->data_device[g_main_device];
14705
- } else {
14706
- src2_dd = src2_f.alloc(ggml_nelements(src2));
14707
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
14708
- }
14709
- }
14710
-
14711
- soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
14690
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
14712
14691
nrows_x, nrows_y, scale, max_bias, main_stream);
14713
14692
}
14714
14693
0 commit comments