Skip to content

Commit e729965

Browse files
authored
falcon : add CUDA offloading (#2739)
1 parent 854ae5d commit e729965

File tree

1 file changed

+101
-11
lines changed

1 file changed

+101
-11
lines changed

llama.cpp

Lines changed: 101 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,31 +1860,54 @@ static void llm_load_tensors(
18601860

18611861
// output
18621862
{
1863-
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU);
1864-
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU);
1865-
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
1863+
ggml_backend backend_norm;
1864+
ggml_backend backend_output;
1865+
1866+
if (n_gpu_layers > int(n_layer)) {
1867+
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
1868+
// on Windows however this is detrimental unless everything is on the GPU
1869+
#ifndef _WIN32
1870+
backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
1871+
#else
1872+
backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
1873+
#endif // _WIN32
1874+
1875+
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
1876+
} else {
1877+
backend_norm = GGML_BACKEND_CPU;
1878+
backend_output = GGML_BACKEND_CPU;
1879+
}
1880+
1881+
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
1882+
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
1883+
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
18661884
}
18671885

18681886
const uint32_t n_ff = hparams.n_ff;
18691887

1888+
const int i_gpu_start = n_layer - n_gpu_layers;
1889+
18701890
model.layers.resize(n_layer);
18711891

18721892
for (uint32_t i = 0; i < n_layer; ++i) {
1893+
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
1894+
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
1895+
18731896
auto & layer = model.layers[i];
18741897

1875-
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, GGML_BACKEND_CPU);
1876-
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, GGML_BACKEND_CPU);
1898+
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
1899+
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
18771900

18781901
if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) {
1879-
layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, GGML_BACKEND_CPU);
1880-
layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, GGML_BACKEND_CPU);
1902+
layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend);
1903+
layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend);
18811904
}
18821905

1883-
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, GGML_BACKEND_CPU);
1884-
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, GGML_BACKEND_CPU);
1906+
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
1907+
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
18851908

1886-
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, GGML_BACKEND_CPU);
1887-
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, GGML_BACKEND_CPU);
1909+
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
1910+
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
18881911
}
18891912
} break;
18901913
default:
@@ -2390,6 +2413,8 @@ static struct ggml_cgraph * llm_build_falcon(
23902413
const float freq_scale = hparams.rope_freq_scale;
23912414
const float norm_eps = hparams.f_norm_eps;
23922415

2416+
const int n_gpu_layers = model.n_gpu_layers;
2417+
23932418
auto & buf_compute = lctx.buf_compute;
23942419

23952420
struct ggml_init_params params = {
@@ -2430,6 +2455,30 @@ static struct ggml_cgraph * llm_build_falcon(
24302455
}
24312456
}
24322457

2458+
const int i_gpu_start = n_layer - n_gpu_layers;
2459+
(void) i_gpu_start;
2460+
2461+
// offload functions set the tensor output backend to GPU
2462+
// tensors are GPU-accelerated if any input or the output has been offloaded
2463+
//
2464+
// with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
2465+
// in that case ggml_cuda_assign_buffers has no effect
2466+
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
2467+
offload_func_t offload_func_kq = llama_nop;
2468+
offload_func_t offload_func_v = llama_nop;
2469+
2470+
#ifdef GGML_USE_CUBLAS
2471+
if (n_gpu_layers > n_layer) {
2472+
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
2473+
}
2474+
if (n_gpu_layers > n_layer + 1) {
2475+
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
2476+
}
2477+
if (n_gpu_layers > n_layer + 2) {
2478+
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
2479+
}
2480+
#endif // GGML_USE_CUBLAS
2481+
24332482
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
24342483
ggml_allocr_alloc(lctx.alloc, KQ_scale);
24352484
if (!ggml_allocr_is_measure(lctx.alloc)) {
@@ -2440,28 +2489,43 @@ static struct ggml_cgraph * llm_build_falcon(
24402489
for (int il = 0; il < n_layer; ++il) {
24412490
struct ggml_tensor * attn_norm;
24422491

2492+
offload_func_t offload_func = llama_nop;
2493+
2494+
#ifdef GGML_USE_CUBLAS
2495+
if (il >= i_gpu_start) {
2496+
offload_func = ggml_cuda_assign_buffers_no_alloc;
2497+
}
2498+
#endif // GGML_USE_CUBLAS
2499+
24432500
// self-attention
24442501
// TODO: refactor into common function (shared with LLaMA)
24452502
{
24462503
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
2504+
offload_func(attn_norm);
24472505

24482506
attn_norm = ggml_add(ctx0,
24492507
ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
24502508
model.layers[il].attn_norm_b);
2509+
offload_func(attn_norm->src[0]);
2510+
offload_func(attn_norm);
24512511

24522512
if (model.layers[il].attn_norm_2) { // Falcon-40B
24532513
cur = ggml_norm(ctx0, inpL, norm_eps);
2514+
offload_func(cur);
24542515

24552516
cur = ggml_add(ctx0,
24562517
ggml_mul(ctx0, cur, model.layers[il].attn_norm_2),
24572518
model.layers[il].attn_norm_2_b);
2519+
offload_func(cur->src[0]);
2520+
offload_func(cur);
24582521
} else { // Falcon 7B
24592522
cur = attn_norm;
24602523
}
24612524

24622525
// compute QKV
24632526

24642527
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
2528+
offload_func_kq(cur);
24652529

24662530
// Note that the strides for Kcur, Vcur are set up so that the
24672531
// resulting views are misaligned with the tensor's storage
@@ -2479,39 +2543,49 @@ static struct ggml_cgraph * llm_build_falcon(
24792543
wsize * n_embd_head,
24802544
wsize * n_embd_head * (n_head + 2 * n_head_kv),
24812545
0);
2546+
offload_func_kq(tmpq);
24822547

24832548
struct ggml_tensor * tmpk = ggml_view_3d(
24842549
ctx0, cur, n_embd_head, n_head_kv, N,
24852550
wsize * n_embd_head,
24862551
wsize * n_embd_head * (n_head + 2 * n_head_kv),
24872552
wsize * n_embd_head * n_head);
2553+
offload_func_kq(tmpk);
24882554

24892555
struct ggml_tensor * tmpv = ggml_view_3d(
24902556
ctx0, cur, n_embd_head, n_head_kv, N,
24912557
wsize * n_embd_head,
24922558
wsize * n_embd_head * (n_head + 2 * n_head_kv),
24932559
wsize * n_embd_head * (n_head + n_head_kv));
2560+
offload_func_v(tmpv);
24942561

24952562
// using mode = 2 for neox mode
24962563
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
2564+
offload_func_kq(Qcur);
24972565
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
2566+
offload_func_kq(Kcur);
24982567

24992568
{
25002569
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
2570+
offload_func_v(Vcur);
2571+
offload_func_v(Vcur->src[0]->src[0]);
25012572
ggml_set_name(Vcur, "Vcur");
25022573

25032574
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
2575+
offload_func_kq(k);
25042576
ggml_set_name(k, "k");
25052577

25062578
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
25072579
( n_ctx)*ggml_element_size(kv_self.v),
25082580
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
2581+
offload_func_v(v);
25092582

25102583
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
25112584
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
25122585
}
25132586

25142587
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
2588+
offload_func_kq(Q);
25152589
ggml_set_name(Q, "Q");
25162590

25172591
struct ggml_tensor * K =
@@ -2520,18 +2594,23 @@ static struct ggml_cgraph * llm_build_falcon(
25202594
ggml_element_size(kv_self.k)*n_embd_gqa,
25212595
ggml_element_size(kv_self.k)*n_embd_head,
25222596
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
2597+
offload_func_kq(K);
25232598
ggml_set_name(K, "K");
25242599

25252600
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2601+
offload_func_kq(KQ);
25262602
ggml_set_name(KQ, "KQ");
25272603

25282604
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
2605+
offload_func_kq(KQ_scaled);
25292606
ggml_set_name(KQ_scaled, "KQ_scaled");
25302607

25312608
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
2609+
offload_func_kq(KQ_masked);
25322610
ggml_set_name(KQ_masked, "KQ_masked");
25332611

25342612
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
2613+
offload_func_v(KQ_soft_max);
25352614
ggml_set_name(KQ_soft_max, "KQ_soft_max");
25362615

25372616
struct ggml_tensor * V =
@@ -2540,18 +2619,23 @@ static struct ggml_cgraph * llm_build_falcon(
25402619
ggml_element_size(kv_self.v)*n_ctx,
25412620
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
25422621
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
2622+
offload_func_v(V);
25432623
ggml_set_name(V, "V");
25442624

25452625
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2626+
offload_func_v(KQV);
25462627
ggml_set_name(KQV, "KQV");
25472628

25482629
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2630+
offload_func_v(KQV_merged);
25492631
ggml_set_name(KQV_merged, "KQV_merged");
25502632

25512633
cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
2634+
offload_func_v(cur);
25522635
ggml_set_name(cur, "KQV_merged_contiguous");
25532636

25542637
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
2638+
offload_func(cur);
25552639
ggml_set_name(cur, "result_wo");
25562640
}
25572641

@@ -2567,13 +2651,18 @@ static struct ggml_cgraph * llm_build_falcon(
25672651
// adding this, because there seems to be a bug in the Metal concurrency optimization
25682652
// without this line, the results are non-deterministic and wrong
25692653
cur->src[2] = attn_out;
2654+
offload_func(cur);
25702655

25712656
cur = ggml_gelu(ctx0, cur);
2657+
offload_func(cur);
25722658
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
2659+
offload_func(cur);
25732660
}
25742661

25752662
cur = ggml_add(ctx0, cur, attn_out);
2663+
offload_func(cur);
25762664
cur = ggml_add(ctx0, cur, inpL);
2665+
offload_func(cur);
25772666

25782667
// input for next layer
25792668
inpL = cur;
@@ -2584,6 +2673,7 @@ static struct ggml_cgraph * llm_build_falcon(
25842673
// norm
25852674
{
25862675
cur = ggml_norm(ctx0, cur, norm_eps);
2676+
offload_func_nr(cur);
25872677

25882678
cur = ggml_add(ctx0,
25892679
ggml_mul(ctx0, cur, model.output_norm),

0 commit comments

Comments
 (0)