Skip to content

Commit 922a5b3

Browse files
committed
ggml : update ggml_soft_max_ext() CUDA, SYCL
1 parent c7a107b commit 922a5b3

File tree

4 files changed

+39
-75
lines changed

4 files changed

+39
-75
lines changed

ggml-cuda/softmax.cu

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
1111
}
1212

1313
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
14-
static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
14+
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
1515
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
1616

1717
const int tid = threadIdx.x;
@@ -23,7 +23,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
2323
const int warp_id = threadIdx.x / WARP_SIZE;
2424
const int lane_id = threadIdx.x % WARP_SIZE;
2525

26-
float slope = 0.0f;
26+
float slope = 1.0f;
2727

2828
// ALiBi
2929
if (max_bias > 0.0f) {
@@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
5353
const int64_t ix = (int64_t)rowx*ncols + col;
5454
const int64_t iy = (int64_t)rowy*ncols + col;
5555

56-
const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
56+
const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
5757

5858
vals[col] = val;
5959
max_val = max(max_val, val);
@@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
125125
}
126126

127127
template<typename T>
128-
static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
128+
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
129129
int nth = WARP_SIZE;
130130
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
131131
const dim3 block_dims(nth, 1, 1);
@@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
142142
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
143143
switch (ncols_x) {
144144
case 32:
145-
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
145+
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
146146
break;
147147
case 64:
148-
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
148+
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
149149
break;
150150
case 128:
151-
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
151+
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
152152
break;
153153
case 256:
154-
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
154+
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
155155
break;
156156
case 512:
157-
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
157+
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
158158
break;
159159
case 1024:
160-
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
160+
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
161161
break;
162162
case 2048:
163-
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
163+
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
164164
break;
165165
case 4096:
166-
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
166+
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
167167
break;
168168
default:
169-
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
169+
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
170170
break;
171171
}
172172
} else {
173173
const size_t shmem_low = WARP_SIZE*sizeof(float);
174-
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
174+
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
175175
}
176176
}
177177

178178
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
179179
const ggml_tensor * src0 = dst->src[0];
180180
const ggml_tensor * src1 = dst->src[1];
181-
const ggml_tensor * src2 = dst->src[2];
182181

183182
const float * src0_d = (const float *)src0->data;
184183
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
@@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
190189
GGML_ASSERT( dst->type == GGML_TYPE_F32);
191190

192191
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
193-
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
194192

195193
const int64_t ne00 = src0->ne[0];
196194
const int64_t nrows_x = ggml_nrows(src0);
@@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
202200
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
203201
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
204202

205-
// positions tensor
206-
void * src2_d = nullptr;
207-
208-
const bool use_src2 = src2 != nullptr;
209-
210-
if (use_src2) {
211-
src2_d = (void *)src2->data;
212-
}
213-
214-
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
203+
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
215204

216205
if (use_f16) {
217206
const half * src1_dd = (const half *)src1_d;
218-
const half * src2_dd = (const half *)src2_d;
219207

220-
soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
208+
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
221209
} else {
222210
const float * src1_dd = (const float *)src1_d;
223-
const float * src2_dd = (const float *)src2_d;
224211

225-
soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
212+
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
226213
}
227214
}

ggml-kompute.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,10 +1561,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
15611561
float scale;
15621562
memcpy(&scale, dst->op_params, sizeof(float));
15631563

1564-
#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
1564+
#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
15651565
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
15661566
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
1567-
GGML_ASSERT(src2 == nullptr);
15681567

15691568
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
15701569
} break;

ggml-sycl.cpp

Lines changed: 19 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9414,7 +9414,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
94149414

94159415

94169416
template <bool vals_smem, int ncols_template, int block_size_template>
9417-
static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
9417+
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
94189418
const int nrows_y, const float scale, const float max_bias, const float m0,
94199419
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
94209420
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -9428,7 +9428,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
94289428
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
94299429
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
94309430

9431-
float slope = 0.0f;
9431+
float slope = 1.0f;
94329432

94339433
// ALiBi
94349434
if (max_bias > 0.0f) {
@@ -9453,7 +9453,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
94539453
const int ix = rowx*ncols + col;
94549454
const int iy = rowy*ncols + col;
94559455

9456-
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
9456+
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
94579457

94589458
vals[col] = val;
94599459
max_val = sycl::max(max_val, val);
@@ -13015,7 +13015,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
1301513015
}
1301613016

1301713017
template <bool vals_smem, int ncols_template, int block_size_template>
13018-
static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
13018+
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
1301913019
const int nrows_y, const float scale, const float max_bias, const float m0,
1302013020
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
1302113021
const size_t n_local_scratch, dpct::queue_ptr stream) {
@@ -13025,15 +13025,15 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
1302513025
cgh.parallel_for(
1302613026
sycl::nd_range<3>(block_nums * block_dims, block_dims),
1302713027
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
13028-
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
13028+
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
1302913029
nrows_y, scale, max_bias, m0,
1303013030
m1, n_head_log2, item_ct1,
1303113031
local_buf_acc.get_pointer());
1303213032
});
1303313033
});
1303413034
}
1303513035

13036-
static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
13036+
static void soft_max_f32_sycl(const float * x, const float * mask,
1303713037
float * dst, const int ncols_x, const int nrows_x,
1303813038
const int nrows_y, const float scale, const float max_bias,
1303913039
dpct::queue_ptr stream) {
@@ -13055,60 +13055,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
1305513055
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
1305613056
if (n_local_scratch*sizeof(float) < local_mem_size) {
1305713057
if (ncols_x > max_block_size) {
13058-
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13058+
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
1305913059
max_bias, m0, m1, n_head_log2, block_nums,
1306013060
block_dims, n_local_scratch, stream);
1306113061
return;
1306213062
}
1306313063
switch (ncols_x) {
1306413064
case 32:
13065-
soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13065+
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
1306613066
max_bias, m0, m1, n_head_log2, block_nums,
1306713067
block_dims, n_local_scratch, stream);
1306813068
break;
1306913069
case 64:
13070-
soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13070+
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
1307113071
max_bias, m0, m1, n_head_log2, block_nums,
1307213072
block_dims, n_local_scratch, stream);
1307313073
break;
1307413074
case 128:
13075-
soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13075+
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
1307613076
max_bias, m0, m1, n_head_log2, block_nums,
1307713077
block_dims, n_local_scratch, stream);
1307813078
break;
1307913079
case 256:
13080-
soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13080+
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
1308113081
max_bias, m0, m1, n_head_log2, block_nums,
1308213082
block_dims, n_local_scratch, stream);
1308313083
break;
1308413084
case 512:
13085-
soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13085+
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
1308613086
max_bias, m0, m1, n_head_log2, block_nums,
1308713087
block_dims, n_local_scratch, stream);
1308813088
break;
1308913089
case 1024:
13090-
soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13090+
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
1309113091
max_bias, m0, m1, n_head_log2, block_nums,
1309213092
block_dims, n_local_scratch, stream);
1309313093
break;
1309413094
case 2048:
13095-
soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13095+
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
1309613096
max_bias, m0, m1, n_head_log2, block_nums,
1309713097
block_dims, n_local_scratch, stream);
1309813098
break;
1309913099
case 4096:
13100-
soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13100+
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
1310113101
max_bias, m0, m1, n_head_log2, block_nums,
1310213102
block_dims, n_local_scratch, stream);
1310313103
break;
1310413104
default:
13105-
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13105+
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
1310613106
max_bias, m0, m1, n_head_log2, block_nums,
1310713107
block_dims, n_local_scratch, stream);
1310813108
break;
1310913109
}
1311013110
} else {
13111-
soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13111+
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
1311213112
max_bias, m0, m1, n_head_log2, block_nums,
1311313113
block_dims, WARP_SIZE, stream);
1311413114
}
@@ -14673,12 +14673,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
1467314673
GGML_ASSERT(src0->type == GGML_TYPE_F32);
1467414674
GGML_ASSERT( dst->type == GGML_TYPE_F32);
1467514675

14676-
const ggml_tensor * src2 = dst->src[2];
14677-
14678-
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
14676+
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
1467914677
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
1468014678
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
14681-
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
1468214679

1468314680
const int64_t ne00 = src0->ne[0];
1468414681
const int64_t nrows_x = ggml_nrows(src0);
@@ -14690,25 +14687,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
1469014687
memcpy(&scale, dst->op_params + 0, sizeof(float));
1469114688
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
1469214689

14693-
// positions tensor
14694-
float * src2_dd = nullptr;
14695-
sycl_pool_alloc<float> src2_f;
14696-
14697-
const bool use_src2 = src2 != nullptr;
14698-
14699-
if (use_src2) {
14700-
const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
14701-
14702-
if (src2_on_device) {
14703-
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
14704-
src2_dd = (float *) src2_extra->data_device[g_main_device];
14705-
} else {
14706-
src2_dd = src2_f.alloc(ggml_nelements(src2));
14707-
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
14708-
}
14709-
}
14710-
14711-
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
14690+
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
1471214691
nrows_x, nrows_y, scale, max_bias, main_stream);
1471314692
}
1471414693

ggml-vulkan.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3178,12 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
31783178
}
31793179
return nullptr;
31803180
case GGML_OP_SOFT_MAX:
3181-
#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
3181+
#pragma message("TODO: add ggml_vk_soft_max() F16 src1")
31823182
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
31833183
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
3184-
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
31853184

3186-
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3185+
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
31873186
return ctx->device->pipeline_soft_max_f32;
31883187
}
31893188
return nullptr;

0 commit comments

Comments
 (0)