Commit cd9807c

slaren authored and olexiyb committed
ggml-cuda : compute ptrs for cublasGemmBatchedEx in a kernel (ggml-org#3891)
* ggml-cuda : compute ptrs for cublasGemmBatchedEx in a kernel
* fix warnings
1 parent e8a2ed7 commit cd9807c
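
For readers skimming the diff below: the change is that the three per-batch pointer tables consumed by cublasGemmBatchedEx are now filled by a small device kernel on the compute stream, instead of being built on the host and copied over with cudaMemcpy. The following is a minimal, self-contained sketch of that pattern only, not the ggml-cuda code; the sizes, the names (fill_batched_ptrs, dA/dB/dC) and the use of F32 data with CUBLAS_COMPUTE_32F are illustrative assumptions, and it presumes a CUDA 11+ toolkit.

// Sketch only (not ggml code): fill the A/B/C pointer tables for a batched GEMM
// on the device, then hand the three sub-arrays straight to cublasGemmBatchedEx.
#include <cstdio>
#include <cuda_runtime.h>
#include <cublas_v2.h>

// one thread per batch entry i: ptrs[0*n+i] -> A_i, ptrs[1*n+i] -> B_i, ptrs[2*n+i] -> C_i
__global__ void fill_batched_ptrs(float * A, float * B, float * C, void ** ptrs,
                                  int n, size_t strideA, size_t strideB, size_t strideC) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    ptrs[0*n + i] = A + i*strideA;
    ptrs[1*n + i] = B + i*strideB;
    ptrs[2*n + i] = C + i*strideC;
}

int main() {
    const int m = 4, n = 4, k = 4, batch = 8;

    float *dA, *dB, *dC;
    cudaMalloc(&dA, sizeof(float)*m*k*batch);
    cudaMalloc(&dB, sizeof(float)*k*n*batch);
    cudaMalloc(&dC, sizeof(float)*m*n*batch);
    cudaMemset(dA, 0, sizeof(float)*m*k*batch);
    cudaMemset(dB, 0, sizeof(float)*k*n*batch);

    // 3*batch device pointers, filled by the kernel instead of a host loop + cudaMemcpy
    void ** ptrs;
    cudaMalloc(&ptrs, 3*batch*sizeof(void *));
    fill_batched_ptrs<<<1, batch>>>(dA, dB, dC, ptrs, batch,
                                    (size_t)m*k, (size_t)k*n, (size_t)m*n);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f, beta = 0.0f;
    cublasStatus_t st = cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, (const void * const *) (ptrs + 0*batch), CUDA_R_32F, m,
                    (const void * const *) (ptrs + 1*batch), CUDA_R_32F, k,
            &beta,  (void * const *)       (ptrs + 2*batch), CUDA_R_32F, m,
            batch,
            CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    cudaDeviceSynchronize();
    printf("cublasGemmBatchedEx status: %d\n", (int) st);

    cublasDestroy(handle);
    cudaFree(ptrs); cudaFree(dC); cudaFree(dB); cudaFree(dA);
    return 0;
}

The point of the design is that the pointer table never exists on the host: there is no per-matmul malloc/cudaMemcpy/free, and the fill kernel runs asynchronously on the same stream as the GEMM (build with something like nvcc sketch.cu -lcublas).
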

File tree

1 file changed: +46 −34 lines changed


ggml-cuda.cu

Lines changed: 46 additions & 34 deletions
@@ -6696,8 +6696,10 @@ inline void ggml_cuda_op_clamp(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    const float min = ((float *) dst->op_params)[0];
-    const float max = ((float *) dst->op_params)[1];
+    float min;
+    float max;
+    memcpy(&min, dst->op_params, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));

     clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
     CUDA_CHECK(cudaGetLastError());
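
The hunk above replaces the float * cast of dst->op_params with memcpy reads; this is the usual type-punning-safe idiom and is presumably the "fix warnings" part of the commit message. A tiny illustration, where the int32_t parameter array is a hypothetical stand-in rather than the ggml definition:

#include <cstdint>
#include <cstring>

// hypothetical op_params-style block: raw int32 slots that actually hold two floats
static int32_t params[2];

static float read_float_param(int i) {
    float v;
    // aliasing-safe: copy the bytes instead of dereferencing a cast pointer
    memcpy(&v, (const char *) params + i*sizeof(float), sizeof(float));
    return v;
}
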
@@ -7221,6 +7223,30 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

+__global__ void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+        void ** ptrs,
+        int ne12, int ne13,
+        int ne23,
+        int nb02, int nb03,
+        int nb12, int nb13,
+        int nb2, int nb3,
+        int r2, int r3) {
+    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int i03 = i13 / r3;
+    int i02 = i12 / r2;
+
+    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02   + i03*nb03;
+    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+}
+
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
@@ -7322,49 +7348,35 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
-        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
         const int ne23 = ne12*ne13;

-        // TODO: avoid this alloc
-        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
-
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
-                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
-                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
-            }
-        }
-
-        // allocate device memory for pointers
         void ** ptrs_as = nullptr;
-        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
-
-        // TODO: this does not work for some reason -- not sure why?
-        //size_t ptrs_s = 0;
-        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
-
-        // copy pointers to device
-        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
-
-        free(ptrs);
+        size_t ptrs_s = 0;
+        ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+        dim3 block_dims(ne13, ne12);
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_as_f16, src1_as_f16, dst_f16,
+                ptrs_as,
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                nb12, nb13,
+                dst->nb[2], dst->nb[3],
+                r2, r3);
+        CUDA_CHECK(cudaGetLastError());

         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **       ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));

-        // free device memory for pointers
-        CUDA_CHECK(cudaFree(ptrs_as));
-        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        ggml_cuda_pool_free(ptrs_as, ptrs_s);
     }
 #endif

0 commit comments
