diff --git a/examples/common.cpp b/examples/common.cpp
index 23d69e7d55a80..e8107399db734 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -271,6 +271,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "--mlock") {
             params.use_mlock = true;
+        } else if (arg == "--gpu_layers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.gpu_layers = std::stoi(argv[i]);
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
         } else if (arg == "--mtest") {
@@ -406,6 +412,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     if (llama_mmap_supported()) {
         fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
+    fprintf(stderr, "  --gpu_layers          number of layers to store in VRAM\n");
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -454,6 +461,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.f16_kv     = params.memory_f16;
     lparams.use_mmap   = params.use_mmap;
     lparams.use_mlock  = params.use_mlock;
+    lparams.gpu_layers = params.gpu_layers;
     lparams.logits_all = params.perplexity;
     lparams.embedding  = params.embedding;

diff --git a/examples/common.h b/examples/common.h
index 43f1cc9ef09d5..5c40c4b54d744 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -68,6 +68,7 @@ struct gpt_params {
     bool perplexity        = false; // compute perplexity over the prompt
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
+    int  gpu_layers        = 0;     // number of layers to store in VRAM
     bool mem_test          = false; // compute maximum memory usage
     bool verbose_prompt    = false; // print prompt tokens before generation
 };
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 127b352a0f2c9..599281cdce2d6 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -225,6 +225,45 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }

+template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int row = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    float partial_sum = 0; // separate sum for each thread
+
+    for (int i = 0; i < ncols/block_size; i += 2) {
+        const int col = i*block_size + 2*tid;
+
+        // dequantize
+        const float d = x[(row*ncols + col)/QK4_0].d;
+
+        const uint8_t * pp = x[(row*ncols + col)/QK4_0].qs;
+
+        const uint8_t vui = pp[((row*ncols + col)%QK4_0)/2];
+
+        const int8_t vi0 = vui & 0xF;
+        const int8_t vi1 = vui >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        // matrix multiplication
+        partial_sum += v0 * y[col + 0];
+        partial_sum += v1 * y[col + 1];
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask=16; mask > 0; mask >>= 1) {
+        partial_sum += __shfl_xor_sync(0xffffffff, partial_sum, mask, 32);
+    }
+    if (tid == 0) {
+        dst[row] = partial_sum;
+    }
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -255,6 +294,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }

+static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    // static int block_size = -1;
+    // if (block_size == -1) {
+    //     int min_grid_size, max_block_size = 1;
+    //     CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
+    //     max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+    //     block_size = 1;
+    //     while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
+    //         block_size *= 2;
+    //     }
+    // }
+    // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+    const int block_size = 32;
+    GGML_ASSERT(ncols % block_size == 0);
+    dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+}
+
 // TODO: optimize
 static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
     const half * x = (const half *) vx;
@@ -290,7 +346,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }

 // buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
+#define MAX_CUDA_BUFFERS 256

 struct scoped_spin_lock {
     std::atomic_flag& lock;
@@ -597,7 +653,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);

     size_t x_size, y_size, d_size, q_size;
-    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    float * d_X;
+    if (ne11 > 1) {
+        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    }
     float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
     float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
     char  * d_Q = (char  *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
@@ -612,31 +671,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
             cudaEvent_t  cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];

-            float * c_X = d_X + i * x_ne;
             float * c_Y = d_Y + i * y_ne;
             float * c_D = d_D + i * d_ne;
             char  * c_Q = d_Q + i * q_sz;

-            // copy src0 and convert to fp32 on device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
-            to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
-            CUDA_CHECK(cudaGetLastError());
-            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+            // copy src0 to device if necessary
+            if (src0->backend == GGML_BACKEND_CPU) {
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+            } else if (src0->backend == GGML_BACKEND_CUDA) {
+                c_Q = ((char *) src0->data) + i * q_sz;
+            } else {
+                GGML_ASSERT(false);
+            }
+            if (ne11 == 1) {
+                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

-            // copy src1 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+                // copy src1 to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));

-            // wait for conversion
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+                // wait for data
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, ne00,
-                                c_Y, ne10,
-                        &beta,  c_D, ne01));
+                // compute
+                dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+
+            } else {
+                float * c_X = d_X + i * x_ne;
+
+                // convert src0 to fp32 on device
+                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
+                CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+                // copy src1 to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+                // wait for conversion
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+                // compute
+                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+                CUBLAS_CHECK(
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, c_X, ne00,
+                                    c_Y, ne10,
+                            &beta,  c_D, ne01));
+            }

             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -645,7 +727,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     }

     CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
+    if (ne11 > 1) {
+        ggml_cuda_pool_free(d_X, x_size);
+    }
     ggml_cuda_pool_free(d_Y, y_size);
     ggml_cuda_pool_free(d_D, d_size);
     ggml_cuda_pool_free(d_Q, q_size);
@@ -661,8 +745,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
         return true;
     }

@@ -714,3 +797,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
         return 0;
     }
 }
+
+void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+    const int64_t ne0 = tensor->ne[0];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne3 = tensor->ne[3];
+
+    const ggml_type type = tensor->type;
+    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+    size_t q_size;
+    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+
+    cudaStream_t cudaStream2 = g_cudaStreams2[0];
+
+    // copy tensor to device
+    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    tensor->data = d_Q;
+    tensor->backend = GGML_BACKEND_CUDA;
+}
diff --git a/ggml-cuda.h b/ggml-cuda.h
index f7d6a8bc1842a..4e2c24283ccf4 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -14,6 +14,8 @@ void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);

+void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml.c b/ggml.c
index 1b89bdd894489..2f771f39a532c 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4711,6 +4711,7 @@ struct ggml_tensor * ggml_new_tensor_impl(

     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
+        /*.backend      =*/ GGML_BACKEND_CPU,
         /*.n_dims       =*/ n_dims,
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
diff --git a/ggml.h b/ggml.h
index 508dd69b41713..01f43c3bfef27 100644
--- a/ggml.h
+++ b/ggml.h
@@ -243,6 +243,11 @@ extern "C" {
         GGML_TYPE_COUNT,
     };

+    enum ggml_backend {
+        GGML_BACKEND_CPU = 0,
+        GGML_BACKEND_CUDA = 1,
+    };
+
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN = -1,
@@ -323,6 +328,7 @@ extern "C" {
     // n-dimensional tensor
     struct ggml_tensor {
         enum ggml_type type;
+        enum ggml_backend backend;

         int n_dims;
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -353,7 +359,7 @@ extern "C" {

         char name[32];

-        char padding[8]; // TODO: remove and add padding to name?
+        char padding[9]; // TODO: remove and add padding to name?
     };

     // computation graph
diff --git a/llama.cpp b/llama.cpp
index 4bba93a111ae4..14784b4cf2974 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9,6 +9,9 @@
 #include "llama.h"

 #include "ggml.h"
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif

 #include <array>
 #include <ctime>
@@ -815,6 +818,7 @@ struct llama_context_params llama_context_default_params() {
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
+        /*.gpu_layers                  =*/ 0,
         /*.embedding                   =*/ false,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
@@ -877,6 +881,7 @@ static void llama_model_load_internal(
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
+        int gpu_layers,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void * progress_callback_user_data) {
@@ -1011,6 +1016,18 @@ static void llama_model_load_internal(
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);

     model.mapping = std::move(ml->mapping);
+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) {
+        auto & layer = model.layers[i];
+        ggml_cuda_transform_tensor(layer.wq);
+        ggml_cuda_transform_tensor(layer.wk);
+        ggml_cuda_transform_tensor(layer.wv);
+        ggml_cuda_transform_tensor(layer.wo);
+        ggml_cuda_transform_tensor(layer.w1);
+        ggml_cuda_transform_tensor(layer.w2);
+        ggml_cuda_transform_tensor(layer.w3);
+    }
+#endif

     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -1024,11 +1041,12 @@ static bool llama_model_load(
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
+        int gpu_layers,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
+        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, gpu_layers,
                                   vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::string & err) {
@@ -2088,7 +2106,7 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

     if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
-                          params.use_mmap, params.use_mlock, params.vocab_only,
+                          params.use_mmap, params.use_mlock, params.gpu_layers, params.vocab_only,
                           params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
diff --git a/llama.h b/llama.h
index 58c6e0699a999..db3c62da30d61 100644
--- a/llama.h
+++ b/llama.h
@@ -63,6 +63,7 @@ extern "C" {
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap;   // use mmap if possible
         bool use_mlock;  // force system to keep model in RAM
+        int  gpu_layers; // number of layers to store in VRAM
         bool embedding;  // embedding mode only

         // called with a progress value between 0 and 1, pass NULL to disable
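The GPU path in `ggml_cuda_mul_mat_q_f32` now splits into two cases. When `src0` already lives in VRAM (`GGML_BACKEND_CUDA`) the per-call host-to-device copy of the weights is skipped and `c_Q` points straight into `src0->data`. When `ne11 == 1` (single-token generation, i.e. a matrix–vector product) the new `dequantize_mul_mat_q4_0` kernel works on the quantized weights directly, so the fp32 scratch buffer `d_X` is neither allocated nor freed; for larger batches the existing dequantize-to-fp32-then-cuBLAS-SGEMM path is kept. `MAX_CUDA_BUFFERS` goes from 16 to 256, raising the number of scratch buffers the device pool can hold at once, and `ggml_cuda_can_mul_mat` accepts any shape once the weights are GPU-resident.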
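A self-contained sketch may make the kernel's strategy concrete: one CUDA block (a single warp of 32 threads) per output row, each thread dequantizing and multiplying a strided pair of columns per iteration, then a butterfly reduction with `__shfl_xor_sync` so that lane 0 can write the row's dot product. The `block_q4` struct below only mirrors the Q4_0 layout used by the diff (a float scale plus 16 bytes of packed 4-bit values); the struct name, the host driver, and the test values are illustrative stand-ins, not ggml code.

```cuda
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define QK 32                        // values per quantization block (like QK4_0)

struct block_q4 {                    // assumed layout, mirroring ggml's block_q4_0
    float   d;                       // scale
    uint8_t qs[QK / 2];              // 32 weights packed as 16 bytes of 4-bit nibbles
};

// One block (= one warp) per output row; ncols is assumed to be a multiple of 64.
__global__ void matvec_q4(const block_q4 * W, const float * x, float * y, int ncols) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;     // lane 0..31

    float partial = 0.0f;            // separate running sum per thread
    for (int i = 0; i < ncols / 32; i += 2) {
        const int col = i * 32 + 2 * tid;                 // this thread's column pair
        const block_q4 & b = W[(row * ncols + col) / QK];
        const uint8_t packed = b.qs[((row * ncols + col) % QK) / 2];

        const float v0 = ((packed & 0x0F) - 8) * b.d;     // low nibble  -> column col
        const float v1 = ((packed >>   4) - 8) * b.d;     // high nibble -> column col + 1

        partial += v0 * x[col + 0];
        partial += v1 * x[col + 1];
    }

    // butterfly reduction: after 5 XOR steps every lane holds the row's full sum
    for (int mask = 16; mask > 0; mask >>= 1) {
        partial += __shfl_xor_sync(0xffffffff, partial, mask, 32);
    }
    if (tid == 0) {
        y[row] = partial;
    }
}

int main() {
    const int nrows = 4, ncols = 64;
    std::vector<block_q4> W(nrows * ncols / QK);
    for (auto & b : W) {             // every nibble decodes to 9 - 8 = 1, scaled by 0.5
        b.d = 0.5f;
        for (auto & q : b.qs) q = 0x99;
    }
    std::vector<float> x(ncols, 1.0f), y(nrows, 0.0f);

    block_q4 * dW; float * dx; float * dy;
    cudaMalloc((void **) &dW, W.size() * sizeof(block_q4));
    cudaMalloc((void **) &dx, x.size() * sizeof(float));
    cudaMalloc((void **) &dy, y.size() * sizeof(float));
    cudaMemcpy(dW, W.data(), W.size() * sizeof(block_q4), cudaMemcpyHostToDevice);
    cudaMemcpy(dx, x.data(), x.size() * sizeof(float), cudaMemcpyHostToDevice);

    matvec_q4<<<nrows, 32>>>(dW, dx, dy, ncols);          // grid = rows, block = one warp
    cudaMemcpy(y.data(), dy, y.size() * sizeof(float), cudaMemcpyDeviceToHost);

    for (int r = 0; r < nrows; ++r) {
        printf("row %d: %.1f (expected %.1f)\n", r, y[r], 0.5f * ncols);
    }
    cudaFree(dW); cudaFree(dx); cudaFree(dy);
    return 0;
}
```

Like the loop in the diff, each iteration consumes 64 columns across the warp, so column counts are effectively assumed to be multiples of 64, which holds for the LLaMA weight shapes this path targets.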
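`ggml_cuda_transform_tensor` is the load-time half of the change: copy a tensor's still-quantized bytes into a device buffer once, repoint `tensor->data` at that copy, and mark the tensor as `GGML_BACKEND_CUDA` so the matrix-multiplication path above knows no upload is needed. A rough sketch of the same pattern with stand-in types (ggml itself goes through its buffer pool and `ggml_cuda_h2d_tensor_2d` rather than raw `cudaMalloc`/`cudaMemcpy`):

```cuda
#include <cstddef>
#include <cuda_runtime.h>

enum backend_kind { BACKEND_CPU, BACKEND_GPU };

struct tensor {
    void *       data;      // host pointer before the offload, device pointer after
    size_t       nbytes;    // size of the (possibly quantized) weight data
    backend_kind backend;
};

// Copy the weights to VRAM once at load time and repoint the tensor at the copy;
// later mat-mul calls can check `backend` and skip the per-call host-to-device copy.
static void offload_to_gpu(tensor * t) {
    void * d_ptr = nullptr;
    if (cudaMalloc(&d_ptr, t->nbytes) != cudaSuccess) {
        return;                             // out of VRAM: keep using the CPU copy
    }
    cudaMemcpy(d_ptr, t->data, t->nbytes, cudaMemcpyHostToDevice);

    t->data    = d_ptr;                     // the original host buffer stays with the loader
    t->backend = BACKEND_GPU;
}

int main() {
    float weights[16] = {0};                // stand-in for a quantized weight blob
    tensor t = { weights, sizeof(weights), BACKEND_CPU };
    offload_to_gpu(&t);
    return t.backend == BACKEND_GPU ? 0 : 1;
}
```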
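On the llama.cpp side, `gpu_layers` only decides how many of the leading transformer layers get their seven weight matrices (`wq`, `wk`, `wv`, `wo`, `w1`, `w2`, `w3`) offloaded at load time. From the command line that is `--gpu_layers N`; through the C API it is the new field of `llama_context_params`, roughly as in this hypothetical snippet (model path and layer count are placeholders):

```cpp
#include <cstdio>

#include "llama.h"

int main() {
    llama_context_params params = llama_context_default_params();
    params.gpu_layers = 20;   // keep the weights of the first 20 layers in VRAM

    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (ctx == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // ... tokenize the prompt, call llama_eval(), sample tokens as usual ...

    llama_free(ctx);
    return 0;
}
```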