#include "ggml.h"
#ifdef GGML_USE_CUBLAS
+ #include <cuda_runtime.h>
#include "ggml-cuda.h"
#elif defined(GGML_USE_CLBLAST)
#include "ggml-opencl.h"
@@ -1010,6 +1011,33 @@ static const char *falcon_model_type_name(e_model type) {
    }
}

+ // dynamically gets all tensors from a layer
+ std::vector<ggml_tensor*> get_tensors_from_layer(falcon_layer & layer) {
+     std::vector<ggml_tensor*> tensors;
+     ggml_tensor** tensor_ptr = reinterpret_cast<ggml_tensor**>(&layer); // cast to a pointer to the first ggml_tensor pointer
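+     // note: this assumes falcon_layer consists solely of ggml_tensor* members laid out contiguously (no other fields, no padding)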
+
+     // Iterate through the members and store their addresses in the vector
+     for (std::size_t i = 0; i < sizeof(falcon_layer) / sizeof(ggml_tensor*); ++i) {
+         tensors.push_back(tensor_ptr[i]);
+     }
+
+     return tensors;
+ }
+ // get vram size of all tensors in a layer (todo: split handling)
+ size_t calculate_layer_vram_bytes(const falcon_layer & layer) {
+     size_t size = 0;
+     auto tensors = get_tensors_from_layer(const_cast<falcon_layer &>(layer));
+
+     // Add the size of each member with GPU backend
+     for (const auto & tensor : tensors) {
+         if (tensor != nullptr && tensor->backend != GGML_BACKEND_CPU) {
+             size += ggml_nbytes(tensor);
+         }
+     }
+
+     return size;
+ }
+
static void falcon_model_load_internal(
        const std::string & fname,
        falcon_context & lctx,
@@ -1033,6 +1061,7 @@ static void falcon_model_load_internal(
    auto & model = lctx.model;
    model.hparams = ml->file_loaders.at(0)->hparams;
    model.n_gpu_layers = n_gpu_layers;
+
    llama_file_version file_version = ml->file_loaders.at(0)->file_version;
    auto & hparams = model.hparams;
@@ -1123,6 +1152,7 @@ static void falcon_model_load_internal(

    (void) main_gpu;
#if defined(GGML_USE_CUBLAS)
+     if (n_gpu_layers > 0)
    fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
    ggml_cuda_set_main_device(main_gpu);
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
@@ -1136,9 +1166,31 @@ static void falcon_model_load_internal(
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
#endif

+     size_t vram_total = 0;
+     size_t vram_free = 0;
+     size_t vram_reserved = 1024*1024*512; // will be adapted by model
+ #if defined(GGML_USE_CUBLAS)
+     cudaMemGetInfo(&vram_free, &vram_total); // this should go in ggml-cuda.cu, but I don't want to make Johannes' life harder by modifying that yet
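+     // cudaMemGetInfo reports the device's currently free and total memory in bytes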
+     fprintf(stderr, "%s: VRAM free: %7.2f MB of %7.2f MB (already used: %7.2f MB)\n", __func__, vram_free/MB*1.0, vram_total/MB*1.0, (vram_total-vram_free)/MB*1.0);
+ #endif
+
    // prepare memory for the weights
    size_t vram_weights = 0;
    size_t vram_scratch = 0;
+     size_t vram_overhead = 0;
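+     // accounting: vram_weights = bytes of offloaded tensors, vram_scratch = temporary compute buffer,
+     // vram_overhead = extra allocations observed at first inference, vram_reserved = headroom kept free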
+     (void) vram_scratch;
+     (void) n_batch;
+     // calculate scratch buffer size and allocate it
+ #ifdef GGML_USE_CUBLAS
+     vram_scratch = n_batch * MB;
+     ggml_cuda_set_scratch_size(vram_scratch);
+     if (n_gpu_layers > 0) {
+
+         fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
+                 __func__, vram_scratch / MB);
+     }
+ #endif // GGML_USE_CUBLAS
+
    {
        const uint32_t n_embd = hparams.n_embd;
        const uint32_t n_head = hparams.n_head;
@@ -1152,11 +1204,25 @@ static void falcon_model_load_internal(

        model.tok_embeddings = ml->get_tensor("transformer.word_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);

+         // I did not analyze the cause, but this is the overhead that gets dynamically added to the VRAM at first inference;
+         // the same goes for the reserved amount - most likely we can skip both for a proper size calculation.
+         // If the values below are not correct, GPU memory will fill up to 100%, resulting in an extreme slowdown of inference.
+         if (model.type == FALCON_40B)
+         {
+             vram_reserved = 1900*MB;
+             vram_overhead += 2700*MB;
+         }
+         else
+         {
+             vram_reserved = 768*MB;
+             vram_overhead += 1200*MB;
+         }


        ggml_backend backend_norm;
        ggml_backend backend_output;
-         if (n_gpu_layers > int(n_layer)) { // NOLINT
+         // norm/output offloading disabled until further tests; it currently causes a silent crash
+         if (n_gpu_layers > int(n_layer) && false) { // NOLINT
            backend_norm = LLAMA_BACKEND_OFFLOAD;
            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
        } else {
@@ -1172,12 +1238,26 @@ static void falcon_model_load_internal(
            model.lm_head = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, backend_output);
        }

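+         // note: with norm/output offloading disabled above, these two branches are currently never taken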
+         if (backend_norm != GGML_BACKEND_CPU)
+         {
+             vram_weights += ggml_nbytes(model.output_norm);
+             vram_weights += ggml_nbytes(model.output_norm_b);
+             vram_free -= ggml_nbytes(model.output_norm);
+             vram_free -= ggml_nbytes(model.output_norm_b);
+         }
+         if (backend_output != GGML_BACKEND_CPU)
+         {
+             vram_weights += ggml_nbytes(model.lm_head);
+             vram_free -= ggml_nbytes(model.lm_head);
+         }
+
        const int i_gpu_start = n_layer - n_gpu_layers;
+         int i_gpu_end = n_layer; // allows offloading to be terminated earlier. TODO: instead do a proper calculation run and determine the start before the loop

        model.layers.resize(n_layer);
        for (uint32_t i = 0; i < n_layer; ++i) {
-             const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-             const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+             const ggml_backend backend = (int(i) < i_gpu_start || int(i) >= i_gpu_end) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+             const ggml_backend backend_split = (int(i) < i_gpu_start || int(i) >= i_gpu_end) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT

            auto & layer = model.layers[i];

@@ -1201,31 +1281,26 @@ static void falcon_model_load_internal(

            layer.ffn_up = ml->get_tensor("transformer.h." + str_i + ".mlp.dense_h_to_4h.weight", {n_embd, n_ff}, backend_split); // before gelu
            layer.ffn_down = ml->get_tensor("transformer.h." + str_i + ".mlp.dense_4h_to_h.weight", {n_ff, n_embd}, backend_split); // after gelu
+
+             if (backend != GGML_BACKEND_CPU)
+             {
+                 size_t vram_layer = 0;
+                 vram_layer = calculate_layer_vram_bytes(layer);
+                 vram_weights += vram_layer;
+                 vram_free = (vram_layer > vram_free) ? 0 : vram_free - vram_layer; // simulate the layer being loaded in VRAM

-             if (backend == GGML_BACKEND_GPU) {
-                 // llama:
-                 // vram_weights +=
-                 //     ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
-                 //     ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
-                 //     ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
-                 // falcon:
-                 if (model.type == FALCON_40B)
-                 {
-                     vram_weights +=
-                         ggml_nbytes(layer.input_layernorm) + ggml_nbytes(layer.input_layernorm_b) +
-                         ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.attention_norm_b) +
-                         ggml_nbytes(layer.wo) + ggml_nbytes(layer.wo) +
-                         ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down) +
-                         ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up);
-                 } else // FALCON_7B
+                 if (vram_free <= (vram_overhead + vram_scratch + vram_reserved))
                {
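+                     // stop offloading further layers once the simulated free VRAM no longer covers the scratch buffer, the first-inference overhead and the reserve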
-                     vram_weights +=
-                         ggml_nbytes(layer.input_layernorm) + ggml_nbytes(layer.input_layernorm_b) +
-                         ggml_nbytes(layer.wo) + ggml_nbytes(layer.wo) +
-                         ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down) +
-                         ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up);
+                     // this needs some polishing (instead of fiddling with --ngl, I'd like an option to auto-fill the VRAM with as many layers as possible as an alternative)
+                     fprintf(stderr, "WARNING: Not enough VRAM to load the model as configured - at layer %d of %d\n", i, n_layer);
+                     n_gpu_layers = i + 1;
+                     model.n_gpu_layers = n_gpu_layers;
+                     i_gpu_end = i;
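+                     // subsequent layers will be assigned GGML_BACKEND_CPU by the backend selection at the top of the loop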
                }
            }
+
+
+
        }
    }

@@ -1251,25 +1326,17 @@ static void falcon_model_load_internal(
    fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
            mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);

-     (void) vram_scratch;
-     (void) n_batch;
- #ifdef GGML_USE_CUBLAS
-     vram_scratch = n_batch * MB;
-     ggml_cuda_set_scratch_size(vram_scratch);
-     if (n_gpu_layers > 0) {
-         fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
-                 __func__, vram_scratch / MB);
-     }
- #endif // GGML_USE_CUBLAS
+     // the VRAM scratch allocation was moved further up
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
    const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));

-     fprintf(stderr, "%s: offloading %d layers to GPU\n", __func__, n_gpu);
+     fprintf(stderr, "%s: offloading %d of %d layers to GPU, weights offloaded %7.2f MB\n",
+             __func__, n_gpu, hparams.n_layer, vram_weights / 1024.0 / 1024.0);
    if (n_gpu_layers > (int) hparams.n_layer) {
        fprintf(stderr, "%s: offloading output layer to GPU\n", __func__);
    }
    fprintf(stderr, "%s: total VRAM used: %zu MB\n",
-             __func__, (vram_weights + vram_scratch + MB - 1) / MB); // round up
+             __func__, (vram_weights + vram_scratch + vram_overhead + MB - 1) / MB); // round up
#else
    (void) n_gpu_layers;
#endif
@@ -1293,13 +1360,22 @@ static void falcon_model_load_internal(
        progress_callback(1.0f, progress_callback_user_data);
    }

+ #if defined(GGML_USE_CUBLAS)
+     // size_t vram_free_simulated = vram_free;
+     cudaMemGetInfo(&vram_free, &vram_total); // this should go in ggml-cuda.cu, but I don't want to make Johannes' life harder by modifying that yet
+     fprintf(stderr, "%s: VRAM free: %7.2f MB of %7.2f MB (used: %7.2f MB)\n", __func__, vram_free/MB*1.0, vram_total/MB*1.0, (vram_total-vram_free)/MB*1.0);
+
+ #endif
+
+
    model.mapping = std::move(ml->mapping);

    // loading time will be recalculated after the first eval, so
    // we take page faults deferred by mmap() into consideration
    lctx.t_load_us = ggml_time_us() - lctx.t_start_us;
+
}
-
+ #include <windows.h>
static bool falcon_model_load(
        const std::string & fname,
        falcon_context & lctx,
@@ -2591,7 +2667,7 @@ struct falcon_context * falcon_init_from_file(
    ggml_time_init();

    falcon_context * ctx = new falcon_context;
-
+
    if (params.seed < 0) {
        params.seed = time(NULL);
    }
@@ -2625,6 +2701,7 @@ struct falcon_context * falcon_init_from_file(
        llama_free(ctx);
        return nullptr;
    }
+     params.n_gpu_layers = ctx->model.n_gpu_layers; // model_load_internal() may change this

    // reserve memory for context buffers
    if (!params.vocab_only) {