Skip to content

Commit b1592ea

Browse files
committed
train : fix context size calculations
1 parent dc22db7 commit b1592ea

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

examples/finetune/finetune.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,11 +1742,9 @@ int main(int argc, char ** argv) {
         ggml_allocr_free(alloc);

         // context for compute tensors without their data
-        size_t estimated_compute_size_wo_data = (
-            ggml_tensor_overhead()*LLAMA_TRAIN_MAX_NODES*2
-          + (GGML_OBJECT_SIZE+ggml_graph_overhead())*(
-                params.common.use_checkpointing ? 3 : 2
-            )
+        const size_t estimated_compute_size_wo_data = (
+            2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+            (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
         );
         struct ggml_init_params ctx_compute_params = {
             estimated_compute_size_wo_data, // mem_size

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,11 +1109,9 @@ int main(int argc, char ** argv) {
         ggml_allocr_free(alloc);

         // context for compute tensors without their data
-        size_t estimated_compute_size_wo_data = (
-            ggml_tensor_overhead()*LLAMA_TRAIN_MAX_NODES*2
-          + (GGML_OBJECT_SIZE+ggml_graph_overhead())*(
-                params.common.use_checkpointing ? 3 : 2
-            )
+        const size_t estimated_compute_size_wo_data = (
+            2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+            (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
         );
         struct ggml_init_params ctx_compute_params = {
             estimated_compute_size_wo_data, // mem_size

0 commit comments

Comments
 (0)