@@ -1,7 +1,17 @@
 #include "softmax.cuh"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid = threadIdx.x;
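As a side note, here is a minimal standalone sketch (not part of this diff) of how the t2f32 overload set resolves: the primary template falls back to a plain cast, while the half specialization goes through the __half2float intrinsic explicitly. The demo_kernel name and its buffers are hypothetical.

    #include <cuda_fp16.h>

    // Same helper as introduced above: generic cast plus a half specialization.
    template <typename T>
    static __device__ __forceinline__ float t2f32(T val) {
        return (float) val;
    }

    template <>
    __device__ float __forceinline__ t2f32<half>(half val) {
        return __half2float(val);
    }

    // Hypothetical kernel: both reads go through the same helper, which is
    // what lets soft_max_f32 treat the mask/pos pointers generically as T *.
    static __global__ void demo_kernel(const half * h, const float * f, float * out) {
        out[0] = t2f32(h[0]); // resolves to the half specialization
        out[1] = t2f32(f[0]); // resolves to the primary template
    }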
@@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
+        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
     }
 }
 
-static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+template <typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
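For context, the nth loop above picks the launch width by doubling from one warp until it covers the row or hits the block-size cap. A host-side sketch, assuming WARP_SIZE = 32 and CUDA_SOFT_MAX_BLOCK_SIZE = 1024 (the values ggml-cuda uses, hard-coded here so the snippet stands alone):

    #include <cstdio>

    // Assumed constants; in ggml-cuda these come from the common headers.
    #define WARP_SIZE 32
    #define CUDA_SOFT_MAX_BLOCK_SIZE 1024

    // Mirrors the nth computation in soft_max_f32_cuda.
    static int soft_max_block_width(int ncols_x) {
        int nth = WARP_SIZE;
        while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
        return nth;
    }

    int main() {
        printf("%d %d %d\n",
            soft_max_block_width(200),   // 256
            soft_max_block_width(1024),  // 1024
            soft_max_block_width(5000)); // capped at 1024
        return 0;
    }

So a 200-column row gets a 256-thread block, and rows wider than 1024 columns still launch with the 1024-thread maximum.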
@@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const half * mask, const half * p
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
-    const half  * src1_d = src1 ? (const half *)src1->data : nullptr;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // positions tensor
-    half * src2_dd = nullptr;
+    void * src2_d = nullptr;
 
-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
-        src2_dd = (half *)src2->data;
+        src2_d = (void *)src2->data;
     }
 
-    soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+        const half * src2_dd = (const half *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+        const float * src2_dd = (const float *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
 }
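Finally, a hypothetical call site (not from this PR; sizes and parameter values are illustrative, and the call is assumed to sit in the same translation unit, since the launcher is static) showing what the templated launcher newly allows, namely passing an F32 mask straight through:

    // Hypothetical usage: feed an F32 mask directly, no F16 conversion needed.
    const int ncols = 512, nrows = 4;
    float *x_d, *mask_d, *dst_d;
    cudaMalloc((void **)&x_d,    ncols*nrows*sizeof(float));
    cudaMalloc((void **)&mask_d, ncols*nrows*sizeof(float));
    cudaMalloc((void **)&dst_d,  ncols*nrows*sizeof(float));

    // T = float is spelled out because pos == nullptr would otherwise
    // leave nothing to deduce it from.
    soft_max_f32_cuda<float>(x_d, mask_d, /*pos=*/nullptr, dst_d,
                             ncols, /*nrows_x=*/nrows, /*nrows_y=*/nrows,
                             /*scale=*/1.0f, /*max_bias=*/0.0f, /*stream=*/0);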