Skip to content

Commit 2b4ea35

Browse files
cuda : add batched cuBLAS GEMM for faster attention (#3749)
* cmake : add helper for faster CUDA builds * batched : add NGL arg * ggml : skip nops in compute_forward * cuda : minor indentation * cuda : batched cuBLAS GEMMs for src0 F16 and src1 F32 (attention ops) * Apply suggestions from code review These changes plus: ```c++ #define cublasGemmBatchedEx hipblasGemmBatchedEx ``` are needed to compile with ROCM. I haven't done performance testing, but it seems to work. I couldn't figure out how to propose a change for lines outside what the pull changed, also this is the first time trying to create a multi-part review so please forgive me if I mess something up. * cuda : add ROCm / hipBLAS cublasGemmBatchedEx define * cuda : add cublasGemmStridedBatchedEx for non-broadcasted cases * cuda : reduce mallocs in cublasGemmBatchedEx branch * cuda : add TODO for calling cublas from kernel + using mem pool --------- Co-authored-by: Kerfuffle <[email protected]>
1 parent daab3d7 commit 2b4ea35

File tree

4 files changed

+193
-13
lines changed

4 files changed

+193
-13
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ if (LLAMA_CUBLAS)
331331
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
332332
else()
333333
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
334+
#set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
334335
endif()
335336
endif()
336337
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

examples/batched/batched.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
1111
gpt_params params;
1212

1313
if (argc == 1 || argv[1][0] == '-') {
14-
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]);
14+
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]);
1515
return 1 ;
1616
}
1717

@@ -21,6 +21,9 @@ int main(int argc, char ** argv) {
2121
// total length of the sequences including the prompt
2222
int n_len = 32;
2323

24+
// number of layers to offload to the GPU
25+
int n_gpu_layers = 0;
26+
2427
if (argc >= 2) {
2528
params.model = argv[1];
2629
}
@@ -37,6 +40,10 @@ int main(int argc, char ** argv) {
3740
n_len = std::atoi(argv[4]);
3841
}
3942

43+
if (argc >= 6) {
44+
n_gpu_layers = std::atoi(argv[5]);
45+
}
46+
4047
if (params.prompt.empty()) {
4148
params.prompt = "Hello my name is";
4249
}
@@ -49,7 +56,7 @@ int main(int argc, char ** argv) {
4956

5057
llama_model_params model_params = llama_model_default_params();
5158

52-
// model_params.n_gpu_layers = 99; // offload all layers to the GPU
59+
model_params.n_gpu_layers = n_gpu_layers;
5360

5461
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
5562

ggml-cuda.cu

Lines changed: 179 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
3030
#define cublasCreate hipblasCreate
3131
#define cublasGemmEx hipblasGemmEx
32+
#define cublasGemmBatchedEx hipblasGemmBatchedEx
33+
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
3234
#define cublasHandle_t hipblasHandle_t
3335
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
3436
#define cublasSetStream hipblasSetStream
@@ -4326,13 +4328,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
43264328

43274329
const half * x = (const half *) vx;
43284330

4329-
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
4330-
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
4331+
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
4332+
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
43314333
const int channel_x = channel / channel_x_divisor;
43324334

4333-
const int nrows_y = ncols_x;
4335+
const int nrows_y = ncols_x;
43344336
const int nrows_dst = nrows_x;
4335-
const int row_dst = row_x;
4337+
const int row_dst = row_x;
43364338

43374339
const int idst = channel*nrows_dst + row_dst;
43384340

@@ -4345,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
43454347
break;
43464348
}
43474349

4348-
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
4349-
const float xi = __half2float(x[ix]);
4350-
43514350
const int row_y = col_x;
43524351

4352+
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
43534353
const int iy = channel*nrows_y + row_y;
43544354

4355+
const float xi = __half2float(x[ix]);
4356+
43554357
tmp += xi * y[iy];
43564358
}
43574359

@@ -7013,7 +7015,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
70137015
}
70147016

70157017
static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
7016-
GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
7018+
GGML_ASSERT(!ggml_is_transposed(src0));
7019+
GGML_ASSERT(!ggml_is_transposed(src1));
70177020
GGML_ASSERT(!ggml_is_permuted(src0));
70187021
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
70197022
GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -7023,11 +7026,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
70237026
const int64_t ne01 = src0->ne[1];
70247027
const int64_t ne02 = src0->ne[2];
70257028

7026-
const int64_t ne12 = src1->ne[2];
7027-
70287029
const int64_t nb01 = src0->nb[1];
70297030
const int64_t nb02 = src0->nb[2];
70307031

7032+
const int64_t ne12 = src1->ne[2];
7033+
70317034
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
70327035
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
70337036

@@ -7046,6 +7049,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
70467049
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
70477050
}
70487051

7052+
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
7053+
GGML_ASSERT(!ggml_is_transposed(src0));
7054+
GGML_ASSERT(!ggml_is_transposed(src1));
7055+
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
7056+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
7057+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
7058+
7059+
const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
7060+
const int64_t ne01 = src0->ne[1];
7061+
const int64_t ne02 = src0->ne[2];
7062+
const int64_t ne03 = src0->ne[3];
7063+
7064+
const int64_t nb01 = src0->nb[1];
7065+
const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
7066+
const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
7067+
7068+
const int64_t ne10 = src1->ne[0];
7069+
const int64_t ne11 = src1->ne[1];
7070+
const int64_t ne12 = src1->ne[2];
7071+
const int64_t ne13 = src1->ne[3];
7072+
7073+
const int64_t nb11 = src1->nb[1];
7074+
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
7075+
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
7076+
7077+
const int64_t ne1 = ggml_nelements(src1);
7078+
const int64_t ne = ggml_nelements(dst);
7079+
7080+
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
7081+
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
7082+
7083+
int id;
7084+
CUDA_CHECK(cudaGetDevice(&id));
7085+
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
7086+
7087+
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
7088+
void * src0_ddq = src0_extra->data_device[g_main_device];
7089+
half * src0_as_f16 = (half *) src0_ddq;
7090+
7091+
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
7092+
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
7093+
7094+
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
7095+
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
7096+
7097+
// convert src1 to fp16
7098+
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
7099+
GGML_ASSERT(to_fp16_cuda != nullptr);
7100+
7101+
size_t src1_as = 0;
7102+
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
7103+
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
7104+
7105+
size_t dst_as = 0;
7106+
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
7107+
7108+
GGML_ASSERT(ne12 % ne02 == 0);
7109+
GGML_ASSERT(ne13 % ne03 == 0);
7110+
7111+
// broadcast factors
7112+
const int64_t r2 = ne12/ne02;
7113+
const int64_t r3 = ne13/ne03;
7114+
7115+
const half alpha_f16 = 1.0f;
7116+
const half beta_f16 = 0.0f;
7117+
7118+
#if 0
7119+
// use cublasGemmEx
7120+
{
7121+
for (int i13 = 0; i13 < ne13; ++i13) {
7122+
for (int i12 = 0; i12 < ne12; ++i12) {
7123+
int i03 = i13 / r3;
7124+
int i02 = i12 / r2;
7125+
7126+
CUBLAS_CHECK(
7127+
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7128+
ne01, ne11, ne10,
7129+
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
7130+
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
7131+
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
7132+
CUBLAS_COMPUTE_16F,
7133+
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
7134+
}
7135+
}
7136+
}
7137+
#else
7138+
if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
7139+
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
7140+
// use cublasGemmStridedBatchedEx
7141+
CUBLAS_CHECK(
7142+
cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7143+
ne01, ne11, ne10,
7144+
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
7145+
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
7146+
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
7147+
ne12*ne13,
7148+
CUBLAS_COMPUTE_16F,
7149+
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
7150+
} else {
7151+
// use cublasGemmBatchedEx
7152+
// TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
7153+
const int ne23 = ne12*ne13;
7154+
7155+
// TODO: avoid this alloc
7156+
void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
7157+
7158+
for (int i13 = 0; i13 < ne13; ++i13) {
7159+
for (int i12 = 0; i12 < ne12; ++i12) {
7160+
int i03 = i13 / r3;
7161+
int i02 = i12 / r2;
7162+
7163+
ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3];
7164+
ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
7165+
ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
7166+
}
7167+
}
7168+
7169+
// allocate device memory for pointers
7170+
void ** ptrs_as = nullptr;
7171+
CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
7172+
7173+
// TODO: this does not work for some reason -- not sure why?
7174+
//size_t ptrs_s = 0;
7175+
//ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
7176+
7177+
// copy pointers to device
7178+
CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
7179+
7180+
free(ptrs);
7181+
7182+
CUBLAS_CHECK(
7183+
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
7184+
ne01, ne11, ne10,
7185+
&alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
7186+
(const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
7187+
&beta_f16, ( void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
7188+
ne23,
7189+
CUBLAS_COMPUTE_16F,
7190+
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
7191+
7192+
// free device memory for pointers
7193+
CUDA_CHECK(cudaFree(ptrs_as));
7194+
//ggml_cuda_pool_free(ptrs_as, ptrs_s);
7195+
}
7196+
#endif
7197+
7198+
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
7199+
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
7200+
7201+
ggml_cuda_pool_free(src1_as_f16, src1_as);
7202+
ggml_cuda_pool_free(dst_f16, dst_as);
7203+
}
7204+
70497205
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
70507206
bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
70517207
src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
@@ -7058,10 +7214,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
70587214
}
70597215
}
70607216

7217+
// debug helpers
7218+
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
7219+
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
7220+
//printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
7221+
//printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
7222+
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
7223+
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
7224+
70617225
if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
7226+
// KQ
70627227
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
7063-
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
7228+
} else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
7229+
// KQV
70647230
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
7231+
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
7232+
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
70657233
} else if (src0->type == GGML_TYPE_F32) {
70667234
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
70677235
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {

ggml.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16602,6 +16602,10 @@ static void ggml_compute_forward_cross_entropy_loss_back(
1660216602
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
1660316603
GGML_ASSERT(params);
1660416604

16605+
if (tensor->op == GGML_OP_NONE) {
16606+
return;
16607+
}
16608+
1660516609
#ifdef GGML_USE_CUBLAS
1660616610
bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
1660716611
if (skip_cpu) {

0 commit comments

Comments
 (0)